diff --git a/docs/api/runtime/runtime_option.md b/docs/api/runtime/runtime_option.md index a1dbf988f..11ed76bfa 100644 --- a/docs/api/runtime/runtime_option.md +++ b/docs/api/runtime/runtime_option.md @@ -66,8 +66,7 @@ use_openvino_backend() 使用OpenVINO后端进行推理,支持CPU, 支持Paddle/ONNX模型格式 ``` -enable_paddle_mkldnn() -disable_paddle_mkldnn() +set_paddle_mkldnn() ``` 当使用Paddle Inference后端时,通过此开关开启或关闭CPU上MKLDNN推理加速,后端默认为开启 @@ -180,8 +179,7 @@ void UseOpenVINOBackend() 使用OpenVINO后端进行推理,支持CPU, 支持Paddle/ONNX模型格式 ``` -void EnablePaddleMKLDNN() -void DisablePaddleMKLDNN() +void SetPaddleMKLDNN(bool pd_mkldnn = true) ``` 当使用Paddle Inference后端时,通过此开关开启或关闭CPU上MKLDNN推理加速,后端默认为开启 diff --git a/docs/docs_en/api/runtime/runtime_option.md b/docs/docs_en/api/runtime/runtime_option.md index 52907db2b..003ad1823 100644 --- a/docs/docs_en/api/runtime/runtime_option.md +++ b/docs/docs_en/api/runtime/runtime_option.md @@ -73,8 +73,7 @@ use_openvino_backend() Inference with OpenVINO backend (CPU supported, Paddle/ONNX model format supported) ``` -enable_paddle_mkldnn() -disable_paddle_mkldnn() +set_paddle_mkldnn() ``` When using the Paddle Inference backend, this parameter determines whether the MKLDNN inference acceleration on the CPU is on or off. It is on by default. @@ -204,8 +203,7 @@ void UseOpenVINOBackend() Inference with OpenVINO backend (CPU supported, Paddle/ONNX model format supported) ``` -void EnablePaddleMKLDNN() -void DisablePaddleMKLDNN() +void SetPaddleMKLDNN(bool pd_mkldnn = true) ``` When using the Paddle Inference backend, this parameter determines whether the MKLDNN inference acceleration on the CPU is on or off. It is on by default. diff --git a/examples/vision/detection/yolov5/README.md b/examples/vision/detection/yolov5/README.md index e1930ad3b..302e77a7b 100644 --- a/examples/vision/detection/yolov5/README.md +++ b/examples/vision/detection/yolov5/README.md @@ -23,7 +23,7 @@ - [Python部署](python) - [C++部署](cpp) - +- [服务化部署](serving) ## 版本说明 diff --git a/examples/vision/detection/yolov5/serving/README.md b/examples/vision/detection/yolov5/serving/README.md index 827d4fe59..2826f6a9e 100644 --- a/examples/vision/detection/yolov5/serving/README.md +++ b/examples/vision/detection/yolov5/serving/README.md @@ -1,19 +1,53 @@ -# YOLOv5 Serving部署示例 +# YOLOv5 服务化部署示例 + +## 启动服务 ```bash -#下载yolov5模型文件和测试图片 +#下载yolov5模型文件 wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s.onnx -wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg # 将模型放入 models/infer/1目录下, 并重命名为model.onnx mv yolov5s.onnx models/infer/1/ # 拉取fastdeploy镜像 -docker pull xxx +docker pull paddlepaddle/fastdeploy:0.3.0-gpu-cuda11.4-trt8.4-21.10 -# 启动镜像和服务 -docker run xx +# 运行容器.容器名字为 fd_serving, 并挂载当前目录为容器的 /yolov5_serving 目录 +nvidia-docker run -it --net=host --name fd_serving -v `pwd`/:/yolov5_serving paddlepaddle/fastdeploy:0.3.0-gpu-cuda11.4-trt8.4-21.10 bash -# 客户端请求 -python yolov5_grpc_client.py +# 启动服务(不设置CUDA_VISIBLE_DEVICES环境变量,会拥有所有GPU卡的调度权限) +CUDA_VISIBLE_DEVICES=0 fastdeployserver --model-repository=models --backend-config=python,shm-default-byte-size=10485760 ``` + +服务启动成功后, 会有以下输出: +``` +...... 
+I0928 04:51:15.784517 206 grpc_server.cc:4117] Started GRPCInferenceService at 0.0.0.0:8001 +I0928 04:51:15.785177 206 http_server.cc:2815] Started HTTPService at 0.0.0.0:8000 +I0928 04:51:15.826578 206 http_server.cc:167] Started Metrics Service at 0.0.0.0:8002 +``` + + +## 客户端请求 + +在物理机器中执行以下命令,发送grpc请求并输出结果 +``` +#下载测试图片 +wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg + +#安装客户端依赖 +python3 -m pip install tritonclient\[all\] + +# 发送请求 +python3 yolov5_grpc_client.py +``` + +发送请求成功后,会返回json格式的检测结果并打印输出: +``` +output_name: detction_result +{'boxes': [[268.48028564453125, 81.05305480957031, 298.69476318359375, 169.43902587890625], [104.73116302490234, 45.66197204589844, 127.58382415771484, 93.44938659667969], [378.9093933105469, 39.75013732910156, 395.6086120605469, 84.24342346191406], [158.552978515625, 80.36149597167969, 199.18576049804688, 168.18191528320312], [414.37530517578125, 90.94805908203125, 506.3218994140625, 280.40521240234375], [364.00341796875, 56.608917236328125, 381.97857666015625, 115.96823120117188], [351.7251281738281, 42.635345458984375, 366.9103088378906, 98.04837036132812], [505.8882751464844, 114.36674499511719, 593.1248779296875, 275.99530029296875], [327.7086181640625, 38.36369323730469, 346.84991455078125, 80.89302062988281], [583.493408203125, 114.53289794921875, 612.3546142578125, 175.87353515625], [186.4706573486328, 44.941375732421875, 199.6645050048828, 61.037628173828125], [169.6158905029297, 48.01460266113281, 178.1415557861328, 60.88859558105469], [25.81019401550293, 117.19969177246094, 59.88878631591797, 152.85012817382812], [352.1452941894531, 46.71272277832031, 381.9460754394531, 106.75212097167969], [1.875, 150.734375, 37.96875, 173.78125], [464.65728759765625, 15.901412963867188, 472.512939453125, 34.11640930175781], [64.625, 135.171875, 84.5, 154.40625], [57.8125, 151.234375, 103.0, 174.15625], [165.890625, 88.609375, 527.90625, 339.953125], [101.40625, 152.5625, 118.890625, 169.140625]], 'scores': [0.8965693116188049, 0.8695310950279236, 0.8684297800064087, 0.8429877758026123, 0.8358422517776489, 0.8151364326477051, 0.8089362382888794, 0.801361083984375, 0.7947245836257935, 0.7606497406959534, 0.6325908303260803, 0.6139386892318726, 0.5906146764755249, 0.505328893661499, 0.40457233786582947, 0.3460320234298706, 0.33283042907714844, 0.3325657248497009, 0.2594234347343445, 0.25389009714126587], 'label_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 0, 24, 24, 33, 24], 'masks': [], 'contain_masks': False} +``` + +## 配置修改 + +当前默认配置在GPU上运行ONNXRuntime引擎, 如果要在CPU或其他推理引擎上运行。 需要修改`models/runtime/config.pbtxt`中配置,详情请参考[配置文档](../../../../../serving/docs/zh_CN/model_configuration.md) diff --git a/examples/vision/detection/yolov5/serving/models/runtime/config.pbtxt b/examples/vision/detection/yolov5/serving/models/runtime/config.pbtxt index ffed1edf4..009717c2a 100644 --- a/examples/vision/detection/yolov5/serving/models/runtime/config.pbtxt +++ b/examples/vision/detection/yolov5/serving/models/runtime/config.pbtxt @@ -5,7 +5,6 @@ max_batch_size: 16 # Input configuration of the model input [ - # 第一个输入 { # input name name: "images" @@ -36,3 +35,12 @@ instance_group [ gpus: [0] } ] + +optimization { + execution_accelerators { + gpu_execution_accelerator : [ { + # use ONNXRuntime engine + name: "onnxruntime", + parameters { key: "cpu_threads" value: "2" } + }] +}} diff --git a/examples/vision/detection/yolov5/serving/yolov5_grpc_client.py b/examples/vision/detection/yolov5/serving/yolov5_grpc_client.py index 
f21991174..58b593f30 100644 --- a/examples/vision/detection/yolov5/serving/yolov5_grpc_client.py +++ b/examples/vision/detection/yolov5/serving/yolov5_grpc_client.py @@ -65,11 +65,6 @@ class SyncGRPCTritonRunner: """ infer_inputs = [] for idx, data in enumerate(inputs): - print("len(data):", len(data)) - print("name:", self._input_names[idx], " shape:", data.shape, - data.dtype) - #data = np.array([[x.encode('utf-8')] for x in data], - # dtype=np.object_) infer_input = InferInput(self._input_names[idx], data.shape, "UINT8") infer_input.set_data_from_numpy(data) @@ -106,7 +101,7 @@ if __name__ == "__main__": result = runner.Run([im, ]) for name, values in result.items(): print("output_name:", name) - for i in range(len(values)): - value = values[i][0] + for j in range(len(values)): + value = values[j][0] value = json.loads(value) print(value) diff --git a/fastdeploy/pybind/runtime.cc b/fastdeploy/pybind/runtime.cc index 8df80ed18..023daaf74 100644 --- a/fastdeploy/pybind/runtime.cc +++ b/fastdeploy/pybind/runtime.cc @@ -28,8 +28,7 @@ void BindRuntime(pybind11::module& m) { .def("use_trt_backend", &RuntimeOption::UseTrtBackend) .def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend) .def("use_lite_backend", &RuntimeOption::UseLiteBackend) - .def("enable_paddle_mkldnn", &RuntimeOption::EnablePaddleMKLDNN) - .def("disable_paddle_mkldnn", &RuntimeOption::DisablePaddleMKLDNN) + .def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN) .def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo) .def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo) .def("set_paddle_mkldnn_cache_size", diff --git a/fastdeploy/runtime.cc b/fastdeploy/runtime.cc index b4c4cd858..2bbe643ae 100644 --- a/fastdeploy/runtime.cc +++ b/fastdeploy/runtime.cc @@ -240,9 +240,9 @@ void RuntimeOption::UseLiteBackend() { #endif } -void RuntimeOption::EnablePaddleMKLDNN() { pd_enable_mkldnn = true; } - -void RuntimeOption::DisablePaddleMKLDNN() { pd_enable_mkldnn = false; } +void RuntimeOption::SetPaddleMKLDNN(bool pd_mkldnn) { + pd_enable_mkldnn = pd_mkldnn; +} void RuntimeOption::DeletePaddleBackendPass(const std::string& pass_name) { pd_delete_pass_names.push_back(pass_name); diff --git a/fastdeploy/runtime.h b/fastdeploy/runtime.h index 7804c2ac4..a4d857ac3 100644 --- a/fastdeploy/runtime.h +++ b/fastdeploy/runtime.h @@ -119,11 +119,8 @@ struct FASTDEPLOY_DECL RuntimeOption { /// Set Paddle Lite as inference backend, only support arm cpu void UseLiteBackend(); - /// Enable mkldnn while using Paddle Inference as inference backend - void EnablePaddleMKLDNN(); - - /// Disable mkldnn while using Paddle Inference as inference backend - void DisablePaddleMKLDNN(); + // set mkldnn switch while using Paddle Inference as inference backend + void SetPaddleMKLDNN(bool pd_mkldnn = true); /** * @brief Delete pass by name while using Paddle Inference as inference backend, this can be called multiple times to delete a set of passes diff --git a/python/fastdeploy/runtime.py b/python/fastdeploy/runtime.py index 7a16844a3..3d156ef62 100644 --- a/python/fastdeploy/runtime.py +++ b/python/fastdeploy/runtime.py @@ -85,11 +85,8 @@ class RuntimeOption: def use_lite_backend(self): return self._option.use_lite_backend() - def enable_paddle_mkldnn(self): - return self._option.enable_paddle_mkldnn() - - def disable_paddle_mkldnn(self): - return self._option.disable_paddle_mkldnn() + def set_paddle_mkldnn(self): + return self._option.set_paddle_mkldnn() def enable_paddle_log_info(self): return 
self._option.enable_paddle_log_info() diff --git a/serving/Dockerfile b/serving/Dockerfile new file mode 100644 index 000000000..71921dc8a --- /dev/null +++ b/serving/Dockerfile @@ -0,0 +1,48 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM nvcr.io/nvidia/tritonserver:21.10-py3 as full +FROM nvcr.io/nvidia/tritonserver:21.10-py3-min + +COPY --from=full /opt/tritonserver/bin/tritonserver /opt/tritonserver/bin/fastdeployserver +COPY --from=full /opt/tritonserver/lib /opt/tritonserver/lib +COPY --from=full /opt/tritonserver/include /opt/tritonserver/include +COPY --from=full /opt/tritonserver/backends/python /opt/tritonserver/backends/python + +COPY TensorRT-8.4.1.5 /opt/ + +ENV TZ=Asia/Shanghai \ + DEBIAN_FRONTEND=noninteractive \ + DCGM_VERSION=2.2.9 +RUN apt-get update \ + && apt-key del 7fa2af80 \ + && wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.0-1_all.deb \ + && dpkg -i cuda-keyring_1.0-1_all.deb \ + && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub \ + && apt-get update && apt-get install -y --no-install-recommends datacenter-gpu-manager=1:2.2.9 + +RUN apt-get update \ + && apt-get install -y --no-install-recommends libre2-5 libb64-0d python3 python3-pip libarchive-dev \ + && python3 -m pip install -U pip \ + && python3 -m pip install paddlepaddle-gpu paddlenlp faster_tokenizer + +COPY python/dist/*.whl /opt/fastdeploy/ +RUN python3 -m pip install /opt/fastdeploy/*.whl \ + && rm -rf /opt/fastdeploy/*.whl + +COPY serving/build/libtriton_fastdeploy.so /opt/tritonserver/backends/fastdeploy/ +COPY build/fastdeploy-0.0.3 /opt/fastdeploy/ + +ENV LD_LIBRARY_PATH="/opt/TensorRT-8.4.1.5/lib/:/opt/fastdeploy/lib:/opt/fastdeploy/third_libs/install/onnxruntime/lib:/opt/fastdeploy/third_libs/install/paddle2onnx/lib:/opt/fastdeploy/third_libs/install/tensorrt/lib:/opt/fastdeploy/third_libs/install/paddle_inference/paddle/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/mkldnn/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/mklml/lib:/opt/fastdeploy/third_libs/install/openvino/runtime/lib:$LD_LIBRARY_PATH" +ENV PATH="/opt/tritonserver/bin:$PATH" diff --git a/serving/README.md b/serving/README.md new file mode 120000 index 000000000..bacd3186b --- /dev/null +++ b/serving/README.md @@ -0,0 +1 @@ +README_CN.md \ No newline at end of file diff --git a/serving/README_CN.md b/serving/README_CN.md new file mode 100644 index 000000000..5849af03f --- /dev/null +++ b/serving/README_CN.md @@ -0,0 +1,19 @@ +简体中文 | [English](README_EN.md) + +# FastDeploy 服务化部署 + +## 简介 + +FastDeploy基于[Triton Inference Server](https://github.com/triton-inference-server/server)搭建了端到端的服务化部署。底层后端使用FastDeploy高性能Runtime模块,并串联FastDeploy前后处理模块实现端到端的服务化部署。具有快速部署、使用简单、性能卓越的特性。 + +## 端到端部署示例 + +- [YOLOV5 检测任务](../examples/vision/detection/yolov5/README.md) +- [OCR ]() +- [Erinie3.0 文本分类任务]() +- [UIE 
]() +- [Speech ]() + +## 高阶文档 +- [模型仓库](docs/zh_CN/model_repository.md) +- [模型配置](docs/zh_CN/model_configuration.md) diff --git a/serving/README_EN.md b/serving/README_EN.md new file mode 100644 index 000000000..b2ac1f61c --- /dev/null +++ b/serving/README_EN.md @@ -0,0 +1 @@ +English | [简体中文](README_CN.md) diff --git a/serving/docs/zh_CN/model_configuration.md b/serving/docs/zh_CN/model_configuration.md new file mode 100644 index 000000000..7a19aa8fa --- /dev/null +++ b/serving/docs/zh_CN/model_configuration.md @@ -0,0 +1,168 @@ +# 模型配置 +模型存储库中的每个模型都必须包含一个模型配置,该配置提供了关于模型的必要和可选信息。这些配置信息一般写在 *config.pbtxt* 文件中,[ModelConfig protobuf](https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto)格式。 + +## 模型通用最小配置 +详细的模型通用配置请看官网文档: [model_configuration](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md).Triton的最小模型配置必须包括: *platform* 或 *backend* 属性、*max_batch_size* 属性和模型的输入输出. + +例如一个Paddle模型,有两个输入*input0* 和 *input1*,一个输出*output0*,输入输出都是float32类型的tensor,最大batch为8.则最小的配置如下: + +``` + backend: "fastdeploy" + max_batch_size: 8 + input [ + { + name: "input0" + data_type: TYPE_FP32 + dims: [ 16 ] + }, + { + name: "input1" + data_type: TYPE_FP32 + dims: [ 16 ] + } + ] + output [ + { + name: "output0" + data_type: TYPE_FP32 + dims: [ 16 ] + } + ] +``` + +## CPU、GPU和实例个数配置 + +通过*instance_group*属性可以配置服务使用哪种硬件资源,分别部署多少个模型推理实例。 + +CPU部署例子: +``` + instance_group [ + { + # 创建两个CPU实例 + count: 2 + # 使用CPU部署 + kind: KIND_CPU + } + ] +``` + +在*GPU 0*上部署2个实例,在*GPU1*和*GPU*上分别部署1个实例 + +``` + instance_group [ + { + # 创建两个GPU实例 + count: 2 + # 使用GPU推理 + kind: KIND_GPU + # 部署在GPU卡0上 + gpus: [ 0 ] + }, + { + count: 1 + kind: KIND_GPU + # 在GPU卡1、2都部署 + gpus: [ 1, 2 ] + } + ] +``` + +### Name, Platform and Backend +模型配置中 *name* 属性是可选的。如果模型没有在配置中指定,则使用模型的目录名;如果指定了该属性,它必须要跟模型的目录名一致。 + +使用 *fastdeploy backend*,没有*platform*属性可以配置,必须配置*backend*属性为*fastdeploy*。 + +``` +backend: "fastdeploy" +``` + +### FastDeploy Backend配置 + +FastDeploy后端目前支持*cpu*和*gpu*推理,*cpu*上支持*paddle*、*onnxruntime*和*openvino*三个推理引擎,*gpu*上支持*paddle*、*onnxruntime*和*tensorrt*三个引擎。 + + +#### 配置使用Paddle引擎 +除去配置 *Instance Groups*,决定模型运行在CPU还是GPU上。Paddle引擎中,还可以进行如下配置: + +``` +optimization { + execution_accelerators { + # CPU推理配置, 配合KIND_CPU使用 + cpu_execution_accelerator : [ + { + name : "paddle" + # 设置推理并行计算线程数为4 + parameters { key: "cpu_threads" value: "4" } + # 开启mkldnn加速,设置为0关闭mkldnn + parameters { key: "use_mkldnn" value: "1" } + } + ], + # GPU推理配置, 配合KIND_GPU使用 + gpu_execution_accelerator : [ + { + name : "paddle" + # 设置推理并行计算线程数为4 + parameters { key: "cpu_threads" value: "4" } + # 开启mkldnn加速,设置为0关闭mkldnn + parameters { key: "use_mkldnn" value: "1" } + } + ] + } +} +``` + +### 配置使用ONNXRuntime引擎 +除去配置 *Instance Groups*,决定模型运行在CPU还是GPU上。ONNXRuntime引擎中,还可以进行如下配置: + +``` +optimization { + execution_accelerators { + cpu_execution_accelerator : [ + { + name : "onnxruntime" + # 设置推理并行计算线程数为4 + parameters { key: "cpu_threads" value: "4" } + } + ], + gpu_execution_accelerator : [ + { + name : "onnxruntime" + } + ] + } +} +``` + +### 配置使用OpenVINO引擎 +OpenVINO引擎只支持CPU推理,配置如下: + +``` +optimization { + execution_accelerators { + cpu_execution_accelerator : [ + { + name : "openvino" + # 设置推理并行计算线程数为4 + parameters { key: "cpu_threads" value: "4" } + } + ] + } +} +``` + +### 配置使用TensorRT引擎 +TensorRT引擎只支持GPU推理,配置如下: + +``` +optimization { + execution_accelerators { + gpu_execution_accelerator : [ + { + name : "tensorrt" + # 使用TensorRT的FP16推理,其他可选项为: trt_fp32、trt_int8 + parameters { key: "precision" value: 
"trt_fp16" } + } + ] + } +} +``` diff --git a/serving/docs/zh_CN/model_repository.md b/serving/docs/zh_CN/model_repository.md new file mode 100644 index 000000000..adff771ff --- /dev/null +++ b/serving/docs/zh_CN/model_repository.md @@ -0,0 +1,78 @@ +# 模型仓库(Model Repository) + +FastDeploy启动服务时指定模型仓库中一个或多个模型部署服务。当服务运行时,可以用[Model Management](https://github.com/triton-inference-server/server/blob/main/docs/model_management.md)中描述的方式修改服务中的模型。 +从服务器启动时指定的一个或多个模型存储库中为模型提供服务 + +## 仓库结构 +模型仓库路径通过FastDeploy启动时的*--model-repository*选项指定,可以多次指定*--model-repository*选项来加载多个仓库。例如: + +``` +$ fastdeploy --model-repository= +``` + +模型仓库的结构必须按以下的格式创建: +``` + / + / + [config.pbtxt] + [ ...] + / + + / + + ... + / + [config.pbtxt] + [ ...] + / + + / + + ... + ... +``` +在最顶层``模型仓库目录下,必须有0个或多个``模型名字的子目录。每个``模型名字子目录包含部署模型相应的信息,多个表示模型版本的数字子目录和一个描述模型配置的*config.pbtxt*文件。 + +Paddle模型存在版本号子目录中,必须为`model.pdmodel`文件和`model.pdiparams`文件。 + +## 模型版本 +每个模型在仓库中可以有一个或多个可用的版本,模型目录中以数字命名的子目录就是对应的版本,数字即版本号。没有以数字命名的子目录,或以*0*开头的子目录都会被忽略。模型配置文件中可以指定[版本策略](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md#version-policy),控制Triton启动模型目录中的哪个版本。 + +## 模型仓库示例 +部署Paddle模型时需要的模型必须是2.0版本以上导出的推理模型,模型包含`model.pdmodel`和`model.pdiparams`两个文件放在版本目录中。 + +部署Paddle模型的最小模型仓库目录示例: +``` + / + / + config.pbtxt + 1/ + model.pdmodel + model.pdiparams + + # 真实例子: + models + └── ResNet50 + ├── 1 + │   ├── model.pdiparams + │   └── model.pdmodel + └── config.pbtxt +``` + +部署ONNX模型,必须要在版本目录中包含`model.onnx`名字的模型。 + +部署ONNX模型的最小模型仓库目录示例: +``` + / + / + config.pbtxt + 1/ + model.onnx + + # 真实例子: + models + └── ResNet50 + ├── 1 + │   ├── model.onnx + └── config.pbtxt +``` diff --git a/serving/scripts/build.sh b/serving/scripts/build.sh new file mode 100644 index 000000000..f03ed7c90 --- /dev/null +++ b/serving/scripts/build.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +sh build_fd_vison.sh +sh build_fd_runtime.sh +sh build_fd_backend.sh diff --git a/serving/scripts/build_fd_backend.sh b/serving/scripts/build_fd_backend.sh new file mode 100644 index 000000000..7eb639af1 --- /dev/null +++ b/serving/scripts/build_fd_backend.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +if [ ! 
-d "./cmake-3.18.6-Linux-x86_64/" ]; then + wget https://github.com/Kitware/CMake/releases/download/v3.18.6/cmake-3.18.6-Linux-x86_64.tar.gz + tar -zxvf cmake-3.18.6-Linux-x86_64.tar.gz + rm -rf cmake-3.18.6-Linux-x86_64.tar.gz +fi + +docker run -it --rm --name build_fd_backend \ + -v`pwd`:/workspace/fastdeploy \ + nvcr.io/nvidia/tritonserver:21.10-py3 \ + bash -c \ + 'cd /workspace/fastdeploy/serving; + rm -rf build; mkdir build; cd build; + apt-get update; apt-get install -y --no-install-recommends rapidjson-dev; + export PATH=/workspace/fastdeploy/cmake-3.18.6-Linux-x86_64/bin:$PATH; + cmake .. -DFASTDEPLOY_DIR=/workspace/fastdeploy/build/fastdeploy-0.0.3 -DTRITON_COMMON_REPO_TAG=r21.10 -DTRITON_CORE_REPO_TAG=r21.10 -DTRITON_BACKEND_REPO_TAG=r21.10; make -j`nproc`' diff --git a/serving/scripts/build_fd_runtime.sh b/serving/scripts/build_fd_runtime.sh new file mode 100644 index 000000000..4f3df0b88 --- /dev/null +++ b/serving/scripts/build_fd_runtime.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +if [ ! -d "./cmake-3.18.6-Linux-x86_64/" ]; then + wget https://github.com/Kitware/CMake/releases/download/v3.18.6/cmake-3.18.6-Linux-x86_64.tar.gz + tar -zxvf cmake-3.18.6-Linux-x86_64.tar.gz + rm -rf cmake-3.18.6-Linux-x86_64.tar.gz +fi + +if [ ! -d "./TensorRT-8.4.1.5/" ]; then + wget https://fastdeploy.bj.bcebos.com/third_libs/TensorRT-8.4.1.5.Linux.x86_64-gnu.cuda-11.6.cudnn8.4.tar.gz + tar -zxvf TensorRT-8.4.1.5.Linux.x86_64-gnu.cuda-11.6.cudnn8.4.tar.gz + rm -rf TensorRT-8.4.1.5.Linux.x86_64-gnu.cuda-11.6.cudnn8.4.tar.gz +fi + +docker run -it --rm --name build_fd_runtime \ + -v`pwd`:/workspace/fastdeploy \ + nvcr.io/nvidia/tritonserver:21.10-py3-min \ + bash -c \ + 'cd /workspace/fastdeploy; + rm -rf build; mkdir build; cd build; + apt-get update; + apt-get install -y --no-install-recommends python3-dev python3-pip; + ln -s /usr/bin/python3 /usr/bin/python; + export PATH=/workspace/fastdeploy/cmake-3.18.6-Linux-x86_64/bin:$PATH; + cmake .. -DENABLE_TRT_BACKEND=ON -DCMAKE_INSTALL_PREFIX=${PWD}/fastdeploy-0.0.3 -DWITH_GPU=ON -DTRT_DIRECTORY=${PWD}/../TensorRT-8.4.1.5/ -DENABLE_PADDLE_BACKEND=ON -DENABLE_ORT_BACKEND=ON -DENABLE_OPENVINO_BACKEND=ON -DENABLE_VISION=OFF -DBUILD_FASTDEPLOY_PYTHON=OFF -DENABLE_PADDLE_FRONTEND=ON -DENABLE_TEXT=OFF -DLIBRARY_NAME=fastdeploy_runtime; + make -j`nproc`; + make install' diff --git a/serving/scripts/build_fd_vison.sh b/serving/scripts/build_fd_vison.sh new file mode 100644 index 000000000..fd55d255f --- /dev/null +++ b/serving/scripts/build_fd_vison.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +if [ ! -d "./cmake-3.18.6-Linux-x86_64/" ]; then + wget https://github.com/Kitware/CMake/releases/download/v3.18.6/cmake-3.18.6-Linux-x86_64.tar.gz + tar -zxvf cmake-3.18.6-Linux-x86_64.tar.gz + rm -rf cmake-3.18.6-Linux-x86_64.tar.gz +fi + +docker run -it --rm --name build_fd_vison \ + -v`pwd`:/workspace/fastdeploy \ + nvcr.io/nvidia/tritonserver:21.10-py3-min \ + bash -c \ + 'cd /workspace/fastdeploy/python; + rm -rf .setuptools-cmake-build dist; + apt-get update; + apt-get install -y --no-install-recommends patchelf python3-dev python3-pip; + ln -s /usr/bin/python3 /usr/bin/python; + export PATH=/workspace/fastdeploy/cmake-3.18.6-Linux-x86_64/bin:$PATH; + export WITH_GPU=ON; + export ENABLE_ORT_BACKEND=OFF; + export ENABLE_VISION=ON; + export ENABLE_TEXT=ON; + python setup.py build; + python setup.py bdist_wheel' diff --git a/serving/src/fastdeploy_backend_utils.cc b/serving/src/fastdeploy_backend_utils.cc index 2de0baf37..2d151253e 100644 --- a/serving/src/fastdeploy_backend_utils.cc +++ b/serving/src/fastdeploy_backend_utils.cc @@ -123,6 +123,27 @@ std::string FDTypeToModelConfigDataType(fastdeploy::FDDataType data_type) { return "TYPE_INVALID"; } +TRITONSERVER_Error* FDParseShape(triton::common::TritonJson::Value& io, + const std::string& name, + std::vector* shape) { + std::string shape_string; + RETURN_IF_ERROR(io.MemberAsString(name.c_str(), &shape_string)); + + std::vector str_shapes; + std::istringstream in(shape_string); + std::copy(std::istream_iterator(in), + std::istream_iterator(), + std::back_inserter(str_shapes)); + + std::transform(str_shapes.cbegin(), str_shapes.cend(), + std::back_inserter(*shape), + [](const std::string& str) -> int32_t { + return static_cast(std::stoll(str)); + }); + + return nullptr; // success +} + } // namespace fastdeploy_runtime } // namespace backend } // namespace triton \ No newline at end of file diff --git a/serving/src/fastdeploy_backend_utils.h b/serving/src/fastdeploy_backend_utils.h index 2a7cdd100..46cc516ac 100644 --- a/serving/src/fastdeploy_backend_utils.h +++ b/serving/src/fastdeploy_backend_utils.h @@ -33,6 +33,7 @@ #include #include "fastdeploy/core/fd_type.h" +#include "triton/backend/backend_common.h" #include "triton/core/tritonserver.h" namespace triton { @@ -67,6 +68,10 @@ fastdeploy::FDDataType ModelConfigDataTypeToFDType( std::string FDTypeToModelConfigDataType(fastdeploy::FDDataType data_type); +TRITONSERVER_Error* FDParseShape(triton::common::TritonJson::Value& io, + const std::string& name, + std::vector* shape); + } // namespace fastdeploy_runtime } // namespace backend } // namespace triton diff --git a/serving/src/fastdeploy_runtime.cc b/serving/src/fastdeploy_runtime.cc index c918d1e45..1051915ef 100644 --- a/serving/src/fastdeploy_runtime.cc +++ b/serving/src/fastdeploy_runtime.cc @@ -26,6 +26,7 @@ #include +#include #include #include @@ -169,83 +170,154 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) // instance when creating that instance's runtime. 
runtime_options_.reset(new fastdeploy::RuntimeOption()); + triton::common::TritonJson::Value optimization; + if (not ModelConfig().Find("optimization", &optimization)) { + return; + } + + triton::common::TritonJson::Value eas; + if (not optimization.Find("execution_accelerators", &eas)) { + return; + } + + // CPU execution providers { - triton::common::TritonJson::Value optimization; - if (ModelConfig().Find("optimization", &optimization)) { - triton::common::TritonJson::Value backend; - if (optimization.Find("onnxruntime", &backend)) { - runtime_options_->UseOrtBackend(); - std::vector param_keys; - THROW_IF_BACKEND_MODEL_ERROR(backend.Members(¶m_keys)); - for (const auto& param_key : param_keys) { - std::string value_string; - if (param_key == "graph_level") { + triton::common::TritonJson::Value cpu_eas; + if (eas.Find("cpu_execution_accelerator", &cpu_eas)) { + for (size_t idx = 0; idx < cpu_eas.ArraySize(); idx++) { + triton::common::TritonJson::Value ea; + THROW_IF_BACKEND_MODEL_ERROR(cpu_eas.IndexAsObject(idx, &ea)); + std::string name; + THROW_IF_BACKEND_MODEL_ERROR(ea.MemberAsString("name", &name)); + if (name == "onnxruntime") { + runtime_options_->UseOrtBackend(); + } else if (name == "paddle") { + runtime_options_->UsePaddleBackend(); + } else if (name == "openvino") { + runtime_options_->UseOpenVINOBackend(); + } else if (name != "") { + TRITONSERVER_Error* error = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string("unknown cpu_execution_accelerator name '" + name + + "' is provided. Available choices are [onnxruntime, " + "paddle, openvino]") + .c_str()); + THROW_IF_BACKEND_MODEL_ERROR(error); + } + + triton::common::TritonJson::Value params; + if (ea.Find("parameters", ¶ms)) { + std::vector param_keys; + THROW_IF_BACKEND_MODEL_ERROR(params.Members(¶m_keys)); + for (const auto& param_key : param_keys) { + std::string value_string; THROW_IF_BACKEND_MODEL_ERROR( - backend.MemberAsString(param_key.c_str(), &value_string)); - THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( - value_string, &runtime_options_->ort_graph_opt_level)); - } else if (param_key == "inter_op_num_threads") { - THROW_IF_BACKEND_MODEL_ERROR( - backend.MemberAsString(param_key.c_str(), &value_string)); - THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( - value_string, &runtime_options_->ort_inter_op_num_threads)); - } else if (param_key == "execution_mode") { - THROW_IF_BACKEND_MODEL_ERROR( - backend.MemberAsString(param_key.c_str(), &value_string)); - THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( - value_string, &runtime_options_->ort_execution_mode)); + params.MemberAsString(param_key.c_str(), &value_string)); + if (param_key == "cpu_threads") { + int cpu_thread_num; + THROW_IF_BACKEND_MODEL_ERROR( + ParseIntValue(value_string, &cpu_thread_num)); + runtime_options_->SetCpuThreadNum(cpu_thread_num); + // } else if (param_key == "graph_level") { + // THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( + // value_string, &runtime_options_->ort_graph_opt_level)); + // } else if (param_key == "inter_op_num_threads") { + // THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( + // value_string, + // &runtime_options_->ort_inter_op_num_threads)); + // } else if (param_key == "execution_mode") { + // THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( + // value_string, &runtime_options_->ort_execution_mode)); + // } else if (param_key == "capacity") { + // THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( + // value_string, &runtime_options_->pd_mkldnn_cache_size)); + } else if (param_key == "use_mkldnn") { + bool pd_enable_mkldnn; + 
THROW_IF_BACKEND_MODEL_ERROR( + ParseBoolValue(value_string, &pd_enable_mkldnn)); + runtime_options_->SetPaddleMKLDNN(pd_enable_mkldnn); + } } } - } else if (optimization.Find("tensorrt", &backend)) { - runtime_options_->UseTrtBackend(); - std::vector param_keys; - THROW_IF_BACKEND_MODEL_ERROR(backend.Members(¶m_keys)); - for (const auto& param_key : param_keys) { - std::string value_string; - if (param_key == "cpu_threads") { - THROW_IF_BACKEND_MODEL_ERROR( - backend.MemberAsString(param_key.c_str(), &value_string)); - THROW_IF_BACKEND_MODEL_ERROR( - ParseIntValue(value_string, &runtime_options_->cpu_thread_num)); - } - // TODO(liqi): add tensorrt + } + } + } + + // GPU execution providers + { + triton::common::TritonJson::Value gpu_eas; + if (eas.Find("gpu_execution_accelerator", &gpu_eas)) { + for (size_t idx = 0; idx < gpu_eas.ArraySize(); idx++) { + triton::common::TritonJson::Value ea; + THROW_IF_BACKEND_MODEL_ERROR(gpu_eas.IndexAsObject(idx, &ea)); + std::string name; + THROW_IF_BACKEND_MODEL_ERROR(ea.MemberAsString("name", &name)); + + if (name == "onnxruntime") { + runtime_options_->UseOrtBackend(); + } else if (name == "paddle") { + runtime_options_->UsePaddleBackend(); + } else if (name == "tensorrt") { + runtime_options_->UseTrtBackend(); } - } else if (optimization.Find("paddle", &backend)) { - runtime_options_->UsePaddleBackend(); - std::vector param_keys; - THROW_IF_BACKEND_MODEL_ERROR(backend.Members(¶m_keys)); - for (const auto& param_key : param_keys) { - std::string value_string; - if (param_key == "cpu_threads") { - THROW_IF_BACKEND_MODEL_ERROR( - backend.MemberAsString(param_key.c_str(), &value_string)); - THROW_IF_BACKEND_MODEL_ERROR( - ParseIntValue(value_string, &runtime_options_->cpu_thread_num)); - } else if (param_key == "capacity") { - THROW_IF_BACKEND_MODEL_ERROR( - backend.MemberAsString(param_key.c_str(), &value_string)); - THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( - value_string, &runtime_options_->pd_mkldnn_cache_size)); - } else if (param_key == "use_mkldnn") { - THROW_IF_BACKEND_MODEL_ERROR( - backend.MemberAsString(param_key.c_str(), &value_string)); - THROW_IF_BACKEND_MODEL_ERROR(ParseBoolValue( - value_string, &runtime_options_->pd_enable_mkldnn)); + if (name == "min_shape" or name == "max_shape" or name == "opt_shape") { + triton::common::TritonJson::Value params; + if (ea.Find("parameters", ¶ms)) { + std::vector input_names; + THROW_IF_BACKEND_MODEL_ERROR(params.Members(&input_names)); + for (const auto& input_name : input_names) { + std::vector shape; + FDParseShape(params, input_name, &shape); + if (name == "min_shape") { + runtime_options_->trt_min_shape[input_name] = shape; + } else if (name == "max_shape") { + runtime_options_->trt_max_shape[input_name] = shape; + } else { + runtime_options_->trt_opt_shape[input_name] = shape; + } + } } - } - } else if (optimization.Find("openvino", &backend)) { - runtime_options_->UseOpenVINOBackend(); - std::vector param_keys; - THROW_IF_BACKEND_MODEL_ERROR(backend.Members(¶m_keys)); - for (const auto& param_key : param_keys) { - std::string value_string; - if (param_key == "cpu_threads") { - THROW_IF_BACKEND_MODEL_ERROR( - backend.MemberAsString(param_key.c_str(), &value_string)); - THROW_IF_BACKEND_MODEL_ERROR( - ParseIntValue(value_string, &runtime_options_->cpu_thread_num)); + } else { + triton::common::TritonJson::Value params; + if (ea.Find("parameters", ¶ms)) { + std::vector param_keys; + THROW_IF_BACKEND_MODEL_ERROR(params.Members(¶m_keys)); + for (const auto& param_key : param_keys) { + std::string 
value_string; + THROW_IF_BACKEND_MODEL_ERROR( + params.MemberAsString(param_key.c_str(), &value_string)); + // if (param_key == "graph_level") { + // THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( + // value_string, &runtime_options_->ort_graph_opt_level)); + // } else if (param_key == "inter_op_num_threads") { + // THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( + // value_string, + // &runtime_options_->ort_inter_op_num_threads)); + // } else if (param_key == "execution_mode") { + // THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( + // value_string, &runtime_options_->ort_execution_mode)); + // } + if (param_key == "precision") { + std::transform(value_string.begin(), value_string.end(), + value_string.begin(), ::tolower); + if (value_string == "trt_fp16") { + runtime_options_->EnableTrtFP16(); + } else if (value_string == "trt_int8") { + // TODO(liqi): use EnableTrtINT8 + runtime_options_->trt_enable_int8 = true; + } + // } else if( param_key == "max_batch_size") { + // THROW_IF_BACKEND_MODEL_ERROR(ParseUnsignedLongLongValue( + // value_string, &runtime_options_->trt_max_batch_size)); + // } else if( param_key == "workspace_size") { + // THROW_IF_BACKEND_MODEL_ERROR(ParseUnsignedLongLongValue( + // value_string, + // &runtime_options_->trt_max_workspace_size)); + } else if (param_key == "cache_file") { + runtime_options_->SetTrtCacheFile(value_string); + } + } } - // TODO(liqi): add openvino } } } @@ -285,11 +357,11 @@ TRITONSERVER_Error* ModelState::LoadModel( "not provided.'") .c_str()); } - runtime_options_->model_format = fastdeploy::Frontend::PADDLE; + runtime_options_->model_format = fastdeploy::ModelFormat::PADDLE; runtime_options_->model_file = *model_path; runtime_options_->params_file = *params_path; } else { - runtime_options_->model_format = fastdeploy::Frontend::ONNX; + runtime_options_->model_format = fastdeploy::ModelFormat::ONNX; runtime_options_->model_file = *model_path; } }
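
For reference, a minimal sketch of an `optimization` section that exercises the new GPU accelerator parsing added in `serving/src/fastdeploy_runtime.cc`: one `tensorrt` entry carrying the `precision` and `cache_file` parameters handled above, plus `min_shape`/`opt_shape`/`max_shape` entries whose whitespace-separated dims the new `FDParseShape` helper copies into `trt_min_shape`/`trt_opt_shape`/`trt_max_shape`. The parameter keys and entry names come from the parsing code in this diff; the input name `images` matches the YOLOv5 `config.pbtxt`, while the concrete shape values and the cache file name are illustrative assumptions only, not part of this change.

```
optimization {
  execution_accelerators {
    gpu_execution_accelerator : [
      {
        name : "tensorrt"
        # "trt_fp16" -> EnableTrtFP16(), "trt_int8" -> trt_enable_int8 (see parsing above)
        parameters { key: "precision" value: "trt_fp16" }
        # illustrative cache path, forwarded to SetTrtCacheFile()
        parameters { key: "cache_file" value: "yolov5_trt.cache" }
      },
      {
        # dynamic-shape hint: whitespace-separated dims per input name,
        # parsed by FDParseShape into trt_min_shape["images"]
        name : "min_shape"
        parameters { key: "images" value: "1 3 640 640" }
      },
      {
        name : "opt_shape"
        parameters { key: "images" value: "4 3 640 640" }
      },
      {
        name : "max_shape"
        parameters { key: "images" value: "16 3 640 640" }
      }
    ]
  }
}
```

Entries named `min_shape`/`opt_shape`/`max_shape` are only consulted when the selected GPU engine is TensorRT; for `onnxruntime` or `paddle` entries they are simply stored on the RuntimeOption and ignored by those backends.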