Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-06 17:17:14 +08:00
[Serving] Support FastDeploy XPU Triton Server (#1994)
* [patchelf] fix patchelf error for inference xpu
* [serving] add xpu dockerfile and support fd server
* [Serving] support XPU + Triton
* [Dockerfile] update xpu triton docker file -> paddle 0.0.0
* [Dockerfile] add comments for xpu triton dockerfile
* [Doruntime] fix xpu infer error
* [XPU] update xpu dockerfile
* add xpu triton server docs
* update xpu triton server docs
.gitignore (vendored) · 6 changes
@@ -43,4 +43,8 @@ examples/vision/tests_quantize
 fastdeploy/LICENSE
 fastdeploy/ThirdPartyNotices.txt
 FastDeployCSharp.cmake
 python/fastdeploy/code_version.py
+*.pdmodel
+*.pdiparams
+*.pdiparams.info
+log.txt
@@ -10,7 +10,7 @@ input [
     name: "inputs"
     # input type such as TYPE_FP32、TYPE_UINT8、TYPE_INT8、TYPE_INT16、TYPE_INT32、TYPE_INT64、TYPE_FP16、TYPE_STRING
     data_type: TYPE_FP32
     # input shape, The batch dimension is omitted and the actual shape is [batch, c, h, w]
     dims: [ 3, 224, 224 ]
   }
 ]
@@ -31,6 +31,7 @@ instance_group [
     count: 1
     # Use GPU, CPU inference option is:KIND_CPU
     kind: KIND_GPU
+    # kind: KIND_CPU
     # The instance is deployed on the 0th GPU card
     gpus: [0]
   }
@@ -58,3 +59,32 @@ optimization {
     }
   ]
 }}
+
+# instance_group [
+#   {
+#     # The number of instances is 1
+#     count: 1
+#     # Use GPU, CPU inference option is:KIND_CPU
+#     # kind: KIND_GPU
+#     kind: KIND_CPU
+#     # The instance is deployed on the 0th GPU card
+#     # gpus: [0]
+#   }
+# ]
+
+# optimization {
+#   execution_accelerators {
+#     cpu_execution_accelerator: [{
+#       name: "paddle_xpu",
+#       parameters { key: "cpu_threads" value: "4" }
+#       parameters { key: "use_paddle_log" value: "1" }
+#       parameters { key: "kunlunxin_id" value: "0" }
+#       parameters { key: "l3_workspace_size" value: "62914560" }
+#       parameters { key: "locked" value: "0" }
+#       parameters { key: "autotune" value: "1" }
+#       parameters { key: "precision" value: "int16" }
+#       parameters { key: "adaptive_seqlen" value: "0" }
+#       parameters { key: "enable_multi_stream" value: "0" }
+#       parameters { key: "gm_default_size" value: "0" }
+#     }]
+#   }}
@@ -204,7 +204,15 @@ bool Runtime::Infer(std::vector<FDTensor>& input_tensors,
 }
 
 bool Runtime::Infer() {
-  bool result = backend_->Infer(input_tensors_, &output_tensors_, false);
+  bool result = false;
+  if (option.device == Device::KUNLUNXIN) {
+    // FDTensor SetExternalData is not supported for Device::KUNLUNXIN yet,
+    // so copy_to_fd must be set to 'true'.
+    result = backend_->Infer(input_tensors_, &output_tensors_, true);
+  } else {
+    result = backend_->Infer(input_tensors_, &output_tensors_, false);
+  }
+
   for (auto& tensor : output_tensors_) {
     tensor.device_id = option.device_id;
   }
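For reference, a minimal sketch of how this KunlunXin path is exercised from the FastDeploy Python API; the method names (`use_kunlunxin`, `use_paddle_infer_backend`, `get_input_info`, `infer`) are assumptions about the Python bindings and may differ between FastDeploy versions:

```python
# Hedged sketch: run a Paddle model on KunlunXin XPU with the FastDeploy Runtime.
import numpy as np
import fastdeploy as fd

option = fd.RuntimeOption()
option.set_model_path("model.pdmodel", "model.pdiparams")
option.use_paddle_infer_backend()   # Backend::PDINFER
option.use_kunlunxin(0)             # Device::KUNLUNXIN, card 0

runtime = fd.Runtime(option)
name = runtime.get_input_info(0).name
# Because copy_to_fd is true on KUNLUNXIN (see the hunk above), the outputs
# come back as host-side arrays that can be consumed directly on the CPU.
outputs = runtime.infer({name: np.zeros([1, 3, 224, 224], dtype=np.float32)})
print([o.shape for o in outputs])
```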
@@ -26,6 +26,7 @@ def process_paddle_inference(paddle_inference_so_file):
     rpaths = [
         "$ORIGIN", "$ORIGIN/../../third_party/install/mkldnn/lib/",
         "$ORIGIN/../../third_party/install/mklml/lib/",
+        "$ORIGIN/../../third_party/install/xpu/lib/",
         "$ORIGIN/../../../tensorrt/lib/"
     ]
 
serving/Dockerfile_xpu · 47 lines (new file)
@@ -0,0 +1,47 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ARG http_proxy
ARG https_proxy
ARG no_proxy

FROM paddlepaddle/fastdeploy:21.10-cpu-only-min

ENV TZ=Asia/Shanghai \
    DEBIAN_FRONTEND=noninteractive \
    http_proxy=$http_proxy \
    https_proxy=$https_proxy \
    no_proxy=$no_proxy

# Note: use the nightly build of Paddle for the XPU Triton Server image
# to avoid .so conflicts between paddle and fastdeploy-python.
RUN apt-get update && apt-get install -y --no-install-recommends apt-utils libgomp1 ffmpeg libsm6 libxext6 vim wget \
    && python3 -m pip install -U pip \
    && python3 -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html \
    && python3 -m pip install paddlenlp fast-tokenizer-python

COPY python/dist/*.whl /opt/fastdeploy/
RUN python3 -m pip install /opt/fastdeploy/*.whl \
    && rm -rf /opt/fastdeploy/*.whl

COPY serving/build/libtriton_fastdeploy.so /opt/tritonserver/backends/fastdeploy/
COPY build/fastdeploy_install /opt/fastdeploy/
COPY benchmark/cpp /opt/fastdeploy/benchmark/cpp

RUN mv /opt/tritonserver/bin/tritonserver /opt/tritonserver/bin/fastdeployserver
ENV LD_LIBRARY_PATH="/opt/fastdeploy/lib:/opt/fastdeploy/third_libs/install/opencv/lib64:/opt/fastdeploy/third_libs/install/paddle2onnx/lib:/opt/fastdeploy/third_libs/install/paddle_inference/paddle/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/mkldnn/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/mklml/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/xpu/lib:$LD_LIBRARY_PATH"
# unset proxy
ENV http_proxy=
ENV https_proxy=
ENV no_proxy=
@@ -65,6 +65,19 @@ cd ../
 docker build -t paddlepaddle/fastdeploy:x.y.z-ipu-only-21.10 -f serving/Dockerfile_ipu .
 ```
 
+### Build the XPU image
+
+```
+# Enter the serving directory and run the script that builds FastDeploy and the serving backend
+cd serving
+bash scripts/build_fd_xpu.sh
+
+# Go back to the FastDeploy root directory and build the image.
+# x.y.z is the FastDeploy version number; pick one as appropriate, e.g. 1.0.6
+cd ../
+docker build -t paddlepaddle/fastdeploy:x.y.z-xpu-21.10 -f serving/Dockerfile_xpu .
+```
+
 ## Building without an image
 
 - [FastDeploy Serving CentOS build tutorial](./compile_without_docker_centos.md)
@@ -112,6 +112,37 @@ optimization {
 }
 ```
 
+#### Configure the Paddle+XPU engine
+```
+optimization {
+  execution_accelerators {
+    # XPU inference is configured through the CPU execution accelerator and used together with KIND_CPU
+    cpu_execution_accelerator: [
+      {
+        name: "paddle_xpu",
+        # CPU-related settings
+        # cpu_threads: number of CPU compute threads
+        # use_paddle_log: enable Paddle log output
+        parameters { key: "cpu_threads" value: "4" }
+        parameters { key: "use_paddle_log" value: "0" }
+        # XPU-related settings
+        # kunlunxin_id: index of the XPU card to use
+        # l3_workspace_size: L3 cache size
+        parameters { key: "kunlunxin_id" value: "0" }
+        parameters { key: "l3_workspace_size" value: "0xfffc00" }
+        parameters { key: "locked" value: "0" }
+        parameters { key: "autotune" value: "1" }
+        parameters { key: "precision" value: "int16" }
+        parameters { key: "adaptive_seqlen" value: "0" }
+        parameters { key: "enable_multi_stream" value: "0" }
+        parameters { key: "gm_default_size" value: "0" }
+      }
+    ]
+  }
+}
+```
+
 ### Configure the ONNXRuntime engine
 Besides *Instance Groups*, which decide whether the model runs on CPU or GPU, the ONNXRuntime engine supports the following additional settings; see the [YOLOv5 Runtime config](../../../examples/vision/detection/yolov5/serving/models/runtime/config.pbtxt) for a concrete example:
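These `paddle_xpu` parameters are forwarded by the serving backend to FastDeploy's `RuntimeOption::UseKunlunXin(...)` (see the change to serving/src/fastdeploy_runtime.cc further down in this diff). As a rough illustration, here is a hedged Python sketch of the equivalent standalone configuration; the Python method names and positional argument order are assumptions mirroring the C++ call and may differ between FastDeploy versions:

```python
# Hedged sketch only: configure FastDeploy's RuntimeOption the way the serving
# backend does for "paddle_xpu". use_kunlunxin()/set_cpu_thread_num() and the
# argument order are assumptions based on the C++ UseKunlunXin(...) call in this PR.
import fastdeploy as fd

option = fd.RuntimeOption()
option.use_paddle_infer_backend()   # "paddle_xpu" selects the Paddle Inference backend
option.set_cpu_thread_num(4)        # cpu_threads
option.use_kunlunxin(
    0,          # kunlunxin_id: XPU card index
    0xfffc00,   # l3_workspace_size in bytes (62914560 = 60 MB is used elsewhere in this PR)
    False,      # locked
    True,       # autotune
    "",         # autotune_file
    "int16",    # precision
    False,      # adaptive_seqlen
    False,      # enable_multi_stream
)
```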
serving/docs/zh_CN/xpu.md · 190 lines (new file)
@@ -0,0 +1,190 @@

# FastDeploy XPU Triton Server User Guide

The FastDeploy XPU Triton Server runs inference on XPU through Paddle Inference and is already integrated into Triton Server. In the FastDeploy XPU Triton Server, XPU inference is configured and invoked through a CPU instance_group together with a cpu_execution_accelerator. Using PaddleClas as an example, this document explains how to turn a CPU/GPU Triton service into an XPU Triton service.

## 1. Prepare the serving image

- Pull the FastDeploy XPU Triton Server image
```bash
docker pull registry.baidubce.com/paddlepaddle/fastdeploy:1.0.7-xpu-21.10 # stable release
docker pull registry.baidubce.com/paddlepaddle/fastdeploy:0.0.0-xpu-21.10 # develop build
```

- Download the deployment example
```bash
# Download the deployment example code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/classification/paddleclas/serving

# Download the ResNet50_vd model files and a test image
wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz
tar -xvf ResNet50_vd_infer.tgz

# Put the config file into the preprocess directory
mv ResNet50_vd_infer/inference_cls.yaml models/preprocess/1/inference_cls.yaml

# Put the model into models/runtime/1 and rename it to model.pdmodel and model.pdiparams
mv ResNet50_vd_infer/inference.pdmodel models/runtime/1/model.pdmodel
mv ResNet50_vd_infer/inference.pdiparams models/runtime/1/model.pdiparams
```

## 2. Start the container
```bash
docker run -itd --name fd_xpu_server -v `pwd`/:/serving --net=host --privileged registry.baidubce.com/paddlepaddle/fastdeploy:1.0.7-xpu-21.10 /bin/bash
```

## 3. Verify XPU availability
```bash
docker exec -it fd_xpu_server /bin/bash
cd /opt/fastdeploy/benchmark/cpp/build
./benchmark --model ResNet50_infer --config_path ../config/config.xpu.paddle.fp32.txt --enable_log_info
cd /serving
```
The output is:
```
I0529 11:07:46.860354 222 memory_optimize_pass.cc:222] Cluster name : batch_norm_46.tmp_2_max size: 1
--- Running analysis [ir_graph_to_program_pass]
I0529 11:07:46.889616 222 analysis_predictor.cc:1705] ======= optimize end =======
I0529 11:07:46.890262 222 naive_executor.cc:160] --- skip [feed], feed -> inputs
I0529 11:07:46.890703 222 naive_executor.cc:160] --- skip [save_infer_model/scale_0.tmp_1], fetch -> fetch
[INFO] fastdeploy/runtime/runtime.cc(286)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::KUNLUNXIN.
[INFO] fastdeploy/runtime/backends/paddle/paddle_backend.cc(341)::Infer Running profiling for Runtime without H2D and D2H, Repeats: 1000, Warmup: 200
Runtime(ms): 0.706382ms.
```
The reported device type is Device::KUNLUNXIN. For the FastDeploy benchmark tool documentation, see [benchmark](https://github.com/PaddlePaddle/FastDeploy/tree/develop/benchmark/cpp).

## 4. Configure the Triton Model Config
```protobuf
# XPU serving example: examples/vision/classification/serving/models/runtime/config.pbtxt
# Uncomment the XPU part, comment out the original GPU settings, and change the config to:
# # Number of instances of the model
# instance_group [
#   {
#     # The number of instances is 1
#     count: 1
#     # Use GPU, CPU inference option is:KIND_CPU
#     kind: KIND_GPU
#     # kind: KIND_CPU
#     # The instance is deployed on the 0th GPU card
#     gpus: [0]
#   }
# ]

# optimization {
#   execution_accelerators {
#     gpu_execution_accelerator : [ {
#       # use TRT engine
#       name: "tensorrt",
#       # use fp16 on TRT engine
#       parameters { key: "precision" value: "trt_fp16" }
#     },
#     {
#       name: "min_shape"
#       parameters { key: "inputs" value: "1 3 224 224" }
#     },
#     {
#       name: "opt_shape"
#       parameters { key: "inputs" value: "1 3 224 224" }
#     },
#     {
#       name: "max_shape"
#       parameters { key: "inputs" value: "16 3 224 224" }
#     }
#   ]
# }}

instance_group [
  {
    # The number of instances is 1
    count: 1
    # Use GPU, CPU inference option is:KIND_CPU
    # kind: KIND_GPU
    kind: KIND_CPU
    # The instance is deployed on the 0th GPU card
    # gpus: [0]
  }
]

optimization {
  execution_accelerators {
    cpu_execution_accelerator: [{
      name: "paddle_xpu",
      parameters { key: "cpu_threads" value: "4" }
      parameters { key: "use_paddle_log" value: "1" }
      parameters { key: "kunlunxin_id" value: "0" }
      parameters { key: "l3_workspace_size" value: "62914560" }
      parameters { key: "locked" value: "0" }
      parameters { key: "autotune" value: "1" }
      parameters { key: "precision" value: "int16" }
      parameters { key: "adaptive_seqlen" value: "0" }
      parameters { key: "enable_multi_stream" value: "0" }
      parameters { key: "gm_default_size" value: "0" }
    }]
  }}
```

## 5. Start the Triton server
```bash
fastdeployserver --model-repository=/serving/models --backend-config=python,shm-default-byte-size=10485760
```
Output:
```
[INFO] fastdeploy/runtime/runtime.cc(286)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::KUNLUNXIN.
.....
I0529 03:54:40.585326 385 server.cc:592]
+-------------+---------+--------+
| Model       | Version | Status |
+-------------+---------+--------+
| paddlecls   | 1       | READY  |
| postprocess | 1       | READY  |
| preprocess  | 1       | READY  |
| runtime     | 1       | READY  |
+-------------+---------+--------+
......
I0529 03:54:40.586430 385 grpc_server.cc:4117] Started GRPCInferenceService at 0.0.0.0:8001
I0529 03:54:40.586657 385 http_server.cc:2815] Started HTTPService at 0.0.0.0:8000
I0529 03:54:40.627382 385 http_server.cc:167] Started Metrics Service at 0.0.0.0:8002
```
## 6. 客户端请求
|
||||||
|
在物理机器中执行以下命令,发送grpc请求并输出结果:
|
||||||
|
```bash
|
||||||
|
# 下载测试图片
|
||||||
|
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
|
||||||
|
|
||||||
|
# 安装客户端依赖
|
||||||
|
python3 -m pip install tritonclient\[all\]
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
python3 paddlecls_grpc_client.py
|
||||||
|
```
|
||||||
|
|
||||||
|
发送请求成功后,会返回json格式的检测结果并打印输出:
|
||||||
|
```bash
|
||||||
|
output_name: CLAS_RESULT
|
||||||
|
{'label_ids': [153], 'scores': [0.6858349442481995]}
|
||||||
|
```
|
||||||
|
以上测试结果为Paddle Inference Backend + XPU R200下的输出。
|
||||||
|
|
||||||
|
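The repository's `paddlecls_grpc_client.py` handles the image loading and tensor naming for this example; the sketch below only illustrates the general shape of such a Triton gRPC client. The model, input, and output names (`paddlecls`, `INPUT`, `CLAS_RESULT`) are assumptions and should be taken from the deployed ensemble's config.pbtxt:

```python
# Hedged sketch of a Triton gRPC client; tensor/model names are assumptions.
import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient("localhost:8001")

# Send the raw JPEG bytes; the preprocess model is assumed to decode them server-side.
with open("ILSVRC2012_val_00000010.jpeg", "rb") as f:
    data = np.array([f.read()], dtype=np.object_)

inputs = [grpcclient.InferInput("INPUT", list(data.shape), "BYTES")]
inputs[0].set_data_from_numpy(data)
outputs = [grpcclient.InferRequestedOutput("CLAS_RESULT")]

result = client.infer(model_name="paddlecls", inputs=inputs, outputs=outputs)
print(result.as_numpy("CLAS_RESULT"))
```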
## 7. Self-test inside the container
To test from inside the container, run:
```bash
cd /serving
# Run the server in the background
nohup fastdeployserver --model-repository=/serving/models --backend-config=python,shm-default-byte-size=10485760 > log.txt 2>&1 &
# Install the client dependencies
python3 -m pip install tritonclient\[all\]
# Send the request
unset http_proxy
unset https_proxy
python3 paddlecls_grpc_client.py
```

## 8. Changing the configuration

The default configuration runs the Paddle Inference engine on XPU. To run on CPU/GPU or with another inference engine, modify the settings in `models/runtime/config.pbtxt`; see the [configuration document](./model_configuration.md) for details.

## 9. FAQ
- [How to write client HTTP/GRPC requests](https://github.com/PaddlePaddle/FastDeploy/blob/develop/serving/docs/zh_CN/client.md)
- [How to build the serving deployment image](https://github.com/PaddlePaddle/FastDeploy/blob/develop/serving/docs/zh_CN/compile.md)
- [Serving deployment internals and dynamic batching](https://github.com/PaddlePaddle/FastDeploy/blob/develop/serving/docs/zh_CN/demo.md)
- [Model repository introduction](https://github.com/PaddlePaddle/FastDeploy/blob/develop/serving/docs/zh_CN/model_repository.md)
serving/scripts/build_fd_xpu.sh · 65 lines (new executable file)
@@ -0,0 +1,65 @@
#!/usr/bin/env bash
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


echo "start build FD XPU library"

docker run -i --rm --name build_fd_xpu \
  -v `pwd`/..:/workspace/fastdeploy \
  -e "http_proxy=${http_proxy}" \
  -e "https_proxy=${https_proxy}" \
  -e "no_proxy=${no_proxy}" \
  --network=host --privileged \
  paddlepaddle/fastdeploy:21.10-cpu-only-buildbase \
  bash -c \
  'export https_proxy_tmp=${https_proxy}
  export http_proxy_tmp=${http_proxy}
  cd /workspace/fastdeploy/python;
  rm -rf .setuptools-cmake-build dist build fastdeploy/libs/third_libs;
  ln -s /usr/bin/python3 /usr/bin/python;
  export WITH_GPU=OFF;
  export ENABLE_ORT_BACKEND=OFF;
  export ENABLE_PADDLE_BACKEND=OFF;
  export ENABLE_OPENVINO_BACKEND=OFF;
  export ENABLE_VISION=ON;
  export ENABLE_TEXT=ON;
  unset http_proxy
  unset https_proxy
  python setup.py build;
  python setup.py bdist_wheel;
  cd /workspace/fastdeploy;
  rm -rf build; mkdir build; cd build;
  cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=${PWD}/fastdeploy_install -DWITH_KUNLUNXIN=ON -DENABLE_PADDLE_BACKEND=ON -DENABLE_VISION=ON -DENABLE_BENCHMARK=ON -DLIBRARY_NAME=fastdeploy_runtime;
  make -j`nproc`;
  make install;
  cd /workspace/fastdeploy/serving;
  rm -rf build; mkdir build; cd build;
  export https_proxy=${https_proxy_tmp}
  export http_proxy=${http_proxy_tmp}
  cmake .. -DTRITON_ENABLE_GPU=OFF -DFASTDEPLOY_DIR=/workspace/fastdeploy/build/fastdeploy_install -DTRITON_COMMON_REPO_TAG=r21.10 -DTRITON_CORE_REPO_TAG=r21.10 -DTRITON_BACKEND_REPO_TAG=r21.10;
  make -j`nproc`;
  cd /workspace/fastdeploy/benchmark/cpp;
  rm -rf build; mkdir build; cd build;
  unset http_proxy
  unset https_proxy
  cmake .. -DFASTDEPLOY_INSTALL_DIR=/workspace/fastdeploy/build/fastdeploy_install;
  make -j`nproc`;
  wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_infer.tgz && tar -zxvf ResNet50_infer.tgz;
  wget https://bj.bcebos.com/paddlehub/fastdeploy/000000014439.jpg;
  rm -f ResNet50_infer.tgz;
  rm -rf CMakeFiles;
  '

echo "build FD XPU library done"
serving/src/fastdeploy_runtime.cc · 171 changes (Executable file → Normal file)
@@ -199,6 +199,9 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
         runtime_options_->UseOrtBackend();
       } else if (name == "paddle") {
         runtime_options_->UsePaddleBackend();
+      } else if (name == "paddle_xpu") {
+        // Note(qiuyanjun): use XPU via paddle inference backend.
+        runtime_options_->UsePaddleInferBackend();
       } else if (name == "openvino") {
         runtime_options_->UseOpenVINOBackend();
       } else if (name != "") {
@@ -212,44 +215,118 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
The single parameter-parsing loop is split into a `paddle_xpu` branch and the original CPU-only branch; the new code is:
      }

      triton::common::TritonJson::Value params;
      if (name == "paddle_xpu") {
        // parse parameters for cpu host + xpu device.
        if (ea.Find("parameters", &params)) {
          std::vector<std::string> param_keys;
          THROW_IF_BACKEND_MODEL_ERROR(params.Members(&param_keys));
          // default settings for XPU.
          int kunlunxin_id = 0;
          int l3_workspace_size = 0xfffc00;
          bool locked = false;
          bool autotune = true;
          std::string autotune_file = "";
          std::string precision = "int16";
          bool adaptive_seqlen = false;
          bool enable_multi_stream = false;
          // for future use (only support lite backend now).
          int gm_default_size = 0;
          // common settings for cpu host.
          int cpu_thread_num = -1;
          bool use_paddle_log = false;

          for (const auto& param_key : param_keys) {
            std::string value_string;
            THROW_IF_BACKEND_MODEL_ERROR(
                params.MemberAsString(param_key.c_str(), &value_string));
            // parse common settings for cpu host.
            if (param_key == "cpu_threads") {
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseIntValue(value_string, &cpu_thread_num));
              runtime_options_->SetCpuThreadNum(cpu_thread_num);
            } else if (param_key == "use_paddle_log") {
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseBoolValue(value_string, &use_paddle_log));
              runtime_options_->paddle_infer_option.enable_log_info =
                  use_paddle_log;
            } else if (param_key == "is_clone") {
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseBoolValue(value_string, &is_clone_));
            } else if (param_key == "encryption_key") {
              runtime_options_->SetEncryptionKey(value_string);
              // parse common settings for xpu device.
            } else if (param_key == "kunlunxin_id") {
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseIntValue(value_string, &kunlunxin_id));
            } else if (param_key == "l3_workspace_size") {
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseIntValue(value_string, &l3_workspace_size));
            } else if (param_key == "locked") {
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseBoolValue(value_string, &locked));
            } else if (param_key == "autotune") {
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseBoolValue(value_string, &autotune));
            } else if (param_key == "precision") {
              precision = value_string;
            } else if (param_key == "adaptive_seqlen") {
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseBoolValue(value_string, &adaptive_seqlen));
            } else if (param_key == "enable_multi_stream") {
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseBoolValue(value_string, &enable_multi_stream));
            } else if (param_key == "gm_default_size") {
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseIntValue(value_string, &gm_default_size));
            }
          }
          // initialize xpu device settings
          runtime_options_->UseKunlunXin(
              kunlunxin_id, l3_workspace_size, locked, autotune,
              autotune_file, precision, adaptive_seqlen, enable_multi_stream,
              int64_t(gm_default_size));
        }
      } else {
        // parse parameters for cpu only
        if (ea.Find("parameters", &params)) {
          std::vector<std::string> param_keys;
          THROW_IF_BACKEND_MODEL_ERROR(params.Members(&param_keys));
          for (const auto& param_key : param_keys) {
            std::string value_string;
            THROW_IF_BACKEND_MODEL_ERROR(
                params.MemberAsString(param_key.c_str(), &value_string));
            if (param_key == "cpu_threads") {
              int cpu_thread_num;
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseIntValue(value_string, &cpu_thread_num));
              runtime_options_->SetCpuThreadNum(cpu_thread_num);
            } else if (param_key == "use_mkldnn") {
              bool pd_enable_mkldnn;
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseBoolValue(value_string, &pd_enable_mkldnn));
              runtime_options_->SetPaddleMKLDNN(pd_enable_mkldnn);
            } else if (param_key == "use_paddle_log") {
              bool use_paddle_log;
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseBoolValue(value_string, &use_paddle_log));
              runtime_options_->paddle_infer_option.enable_log_info =
                  use_paddle_log;
            } else if (param_key == "num_streams") {
              int num_streams;
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseIntValue(value_string, &num_streams));
              runtime_options_->openvino_option.num_streams = num_streams;
            } else if (param_key == "is_clone") {
              THROW_IF_BACKEND_MODEL_ERROR(
                  ParseBoolValue(value_string, &is_clone_));
            } else if (param_key == "use_ipu") {
              // runtime_options_->UseIpu();
            } else if (param_key == "encryption_key") {
              runtime_options_->SetEncryptionKey(value_string);
            }
          }
        }
      }  // end 'name == "paddle_xpu"'
    }
  }
}
@@ -422,7 +499,7 @@ TRITONSERVER_Error* ModelState::LoadModel(
     }
   }

   // GPU
 #ifdef TRITON_ENABLE_GPU
   if ((instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
       (instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_AUTO)) {
@@ -432,8 +509,9 @@ TRITONSERVER_Error* ModelState::LoadModel(
     runtime_options_->UseCpu();
   }
 #else
-  if (runtime_options_->device != fastdeploy::Device::IPU) {
-    // If Device is set to IPU, just skip CPU setting.
+  if ((runtime_options_->device != fastdeploy::Device::IPU) &&
+      (runtime_options_->device != fastdeploy::Device::KUNLUNXIN)) {
+    // If Device is set to IPU/XPU, just skip CPU setting.
     runtime_options_->UseCpu();
   }
 #endif  // TRITON_ENABLE_GPU
@@ -972,7 +1050,7 @@ void ModelInstanceState::ProcessRequests(TRITONBACKEND_Request** requests,
       SetInputTensors(total_batch_size, requests, request_count, &responses,
                       &collector, &cuda_copy));

   // Wait for any in-flight input tensor copies to complete.
 #ifdef TRITON_ENABLE_GPU
   if (cuda_copy) {
     cudaStreamSynchronize(CudaStream());
@@ -1146,15 +1224,16 @@ TRITONSERVER_Error* ModelInstanceState::ReadOutputTensors(
     const uint32_t request_count,
     std::vector<TRITONBACKEND_Response*>* responses) {
   // r22.12
-  BackendOutputResponder responder(
-      requests, request_count, responses,
-      model_state_->TritonMemoryManager(), model_state_->MaxBatchSize() > 0,
-      model_state_->EnablePinnedOutput(), CudaStream());
-  // r21.10
   // BackendOutputResponder responder(
-  //     requests, request_count, responses, StateForModel()->MaxBatchSize(),
-  //     StateForModel()->TritonMemoryManager(),
-  //     StateForModel()->EnablePinnedOutput(), CudaStream());
+  //     requests, request_count, responses,
+  //     model_state_->TritonMemoryManager(), model_state_->MaxBatchSize() > 0,
+  //     model_state_->EnablePinnedOutput(), CudaStream());
+
+  // r21.10
+  BackendOutputResponder responder(
+      requests, request_count, responses, StateForModel()->MaxBatchSize(),
+      StateForModel()->TritonMemoryManager(),
+      StateForModel()->EnablePinnedOutput(), CudaStream());
 
   // Use to hold string output contents
   bool cuda_copy = false;