diff --git a/examples/text/ernie-3.0/serving/README.md b/examples/text/ernie-3.0/serving/README.md
new file mode 100644
index 000000000..df969724a
--- /dev/null
+++ b/examples/text/ernie-3.0/serving/README.md
@@ -0,0 +1,171 @@
+# ERNIE 3.0 Serving Deployment Example
+
+## Prepare the Models
+
+Download the ERNIE 3.0 news classification model and sequence labeling model (skip this step if you already have trained models):
+```bash
+# Download and unzip the news classification model
+wget https://paddlenlp.bj.bcebos.com/models/transformers/ernie_3.0/tnews_pruned_infer_model.zip
+unzip tnews_pruned_infer_model.zip
+
+# Move the downloaded model into the model repository of the classification task
+mv tnews_pruned_infer_model/float32.pdmodel models/ernie_seqcls_model/1/model.pdmodel
+mv tnews_pruned_infer_model/float32.pdiparams models/ernie_seqcls_model/1/model.pdiparams
+
+# Download and unzip the sequence labeling model
+wget https://paddlenlp.bj.bcebos.com/models/transformers/ernie_3.0/msra_ner_pruned_infer_model.zip
+unzip msra_ner_pruned_infer_model.zip
+
+# Move the downloaded model into the model repository of the sequence labeling task
+mv msra_ner_pruned_infer_model/float32.pdmodel models/ernie_tokencls_model/1/model.pdmodel
+mv msra_ner_pruned_infer_model/float32.pdiparams models/ernie_tokencls_model/1/model.pdiparams
+```
+
+After the models have been downloaded and moved into place, the models directory of the classification task is structured as follows:
+```
+models
+├── ernie_seqcls                # Pipeline of the classification task
+│   ├── 1
+│   └── config.pbtxt            # Combines pre/post-processing and model inference
+├── ernie_seqcls_model          # Model inference of the classification task
+│   ├── 1
+│   │   ├── model.pdmodel
+│   │   └── model.pdiparams
+│   └── config.pbtxt
+├── ernie_seqcls_postprocess    # Post-processing of the classification task
+│   ├── 1
+│   │   └── model.py
+│   └── config.pbtxt
+└── ernie_tokenizer             # Pre-processing tokenization
+    ├── 1
+    │   └── model.py
+    └── config.pbtxt
+```
+
+## Pull and Run the Image
+```bash
+# CPU image: only supports serving Paddle/ONNX models on CPU; supported inference backends include OpenVINO, Paddle Inference and ONNX Runtime
+docker pull paddlepaddle/fastdeploy:0.3.0-cpu-only-21.10
+
+# GPU image: supports serving Paddle/ONNX models on GPU/CPU; supported inference backends include OpenVINO, TensorRT, Paddle Inference and ONNX Runtime
+docker pull paddlepaddle/fastdeploy:0.3.0-gpu-cuda11.4-trt8.4-21.10
+
+# Run
+docker run -it --net=host --name fastdeploy_server --shm-size="1g" -v /path/serving/models:/models paddlepaddle/fastdeploy:0.3.0-cpu-only-21.10 bash
+```
+
+## Deploy the Models
+The serving directory contains the configuration for starting the pipeline service and the code for sending prediction requests, including:
+
+```
+models                    # Model repository needed to start the service, containing the models and serving configuration files
+seq_cls_grpc_client.py    # Script that sends pipeline prediction requests for the news classification task
+token_cls_grpc_client.py  # Script that sends pipeline prediction requests for the sequence labeling task
+```
+
+*Note*: when the service starts, each Python backend process of the server requests `64 MB` of shared memory by default, so a container started with the default docker settings cannot run multiple Python backend nodes. There are two solutions:
+- 1. Set the `shm-size` option when starting the container, e.g.: `docker run -it --net=host --name fastdeploy_server --shm-size="1g" -v /path/serving/models:/models paddlepaddle/fastdeploy:0.3.0-gpu-cuda11.4-trt8.4-21.10 bash`
+- 2. Set the Python backend's `shm-default-byte-size` option when starting the service, e.g. lower the Python backend's default shared memory to 10 MB: `tritonserver --model-repository=/models --backend-config=python,shm-default-byte-size=10485760`
+
+### Classification Task
+Run the following command inside the container to start the service:
+```
+# By default, all models under models are started
+fastdeployserver --model-repository=/models
+
+# Alternatively, start only the classification task
+fastdeployserver --model-repository=/models --model-control-mode=explicit --load-model=ernie_seqcls
+```
+The output is printed as follows:
+```
+I1019 09:41:15.375496 2823 model_repository_manager.cc:1183] successfully loaded 'ernie_tokenizer' version 1
+I1019 09:41:15.375987 2823 model_repository_manager.cc:1022] loading: ernie_seqcls:1
+I1019 09:41:15.477147 2823 model_repository_manager.cc:1183] successfully loaded 'ernie_seqcls' version 1
+I1019 09:41:15.477325 2823 server.cc:522]
+...
+I0613 08:59:20.577820 10021 server.cc:592]
++----------------------------+---------+--------+
+| Model                      | Version | Status |
++----------------------------+---------+--------+
+| ernie_seqcls               | 1       | READY  |
+| ernie_seqcls_model         | 1       | READY  |
+| ernie_seqcls_postprocess   | 1       | READY  |
+| ernie_tokenizer            | 1       | READY  |
++----------------------------+---------+--------+
+...
+I0601 07:15:15.923270 8059 grpc_server.cc:4117] Started GRPCInferenceService at 0.0.0.0:8001
+I0601 07:15:15.923604 8059 http_server.cc:2815] Started HTTPService at 0.0.0.0:8000
+I0601 07:15:15.964984 8059 http_server.cc:167] Started Metrics Service at 0.0.0.0:8002
+```
+
+### Sequence Labeling Task
+Run the following command inside the container to start the sequence labeling service:
+```
+fastdeployserver --model-repository=/models --model-control-mode=explicit --load-model=ernie_tokencls --backend-config=python,shm-default-byte-size=10485760
+```
+The output is printed as follows:
+```
+I1019 09:41:15.375496 2823 model_repository_manager.cc:1183] successfully loaded 'ernie_tokenizer' version 1
+I1019 09:41:15.375987 2823 model_repository_manager.cc:1022] loading: ernie_seqcls:1
+I1019 09:41:15.477147 2823 model_repository_manager.cc:1183] successfully loaded 'ernie_seqcls' version 1
+I1019 09:41:15.477325 2823 server.cc:522]
+...
+I0613 08:59:20.577820 10021 server.cc:592]
++----------------------------+---------+--------+
+| Model                      | Version | Status |
++----------------------------+---------+--------+
+| ernie_tokencls             | 1       | READY  |
+| ernie_tokencls_model       | 1       | READY  |
+| ernie_tokencls_postprocess | 1       | READY  |
+| ernie_tokenizer            | 1       | READY  |
++----------------------------+---------+--------+
+...
+I0601 07:15:15.923270 8059 grpc_server.cc:4117] Started GRPCInferenceService at 0.0.0.0:8001
+I0601 07:15:15.923604 8059 http_server.cc:2815] Started HTTPService at 0.0.0.0:8000
+I0601 07:15:15.964984 8059 http_server.cc:167] Started Metrics Service at 0.0.0.0:8002
+```
+
+## Client Requests
+Client requests can be sent by running the scripts locally, or from inside the container.
+
+Running the scripts locally requires installing the dependencies first:
+```
+pip install grpcio
+pip install tritonclient[all]
+
+# If bash cannot parse the brackets, install with the following command instead:
+pip install tritonclient\[all\]
+```
+
+### Classification Task
+Note: disable any proxy before sending client requests, and change the IP address in the main function to the machine on which the service is running.
+```
+python seq_cls_grpc_client.py
+```
+The output is printed as follows:
+```
+{'label': array([5, 9]), 'confidence': array([0.6425664 , 0.66534853], dtype=float32)}
+{'label': array([4]), 'confidence': array([0.53198355], dtype=float32)}
+acc: 0.5731
+```
+
+### Sequence Labeling Task
+Note: disable any proxy before sending client requests, and change the IP address in the main function to the machine on which the service is running.
+```
+python token_cls_grpc_client.py
+```
+The output is printed as follows:
+```
+input data: 北京的涮肉,重庆的火锅,成都的小吃都是极具特色的美食。
+The model detects all entities:
+entity: 北京 label: LOC pos: [0, 1]
+entity: 重庆 label: LOC pos: [6, 7]
+entity: 成都 label: LOC pos: [12, 13]
+input data: 原产玛雅故国的玉米,早已成为华夏大地主要粮食作物之一。
+The model detects all entities:
+entity: 玛雅 label: LOC pos: [2, 3]
+entity: 华夏 label: LOC pos: [14, 15]
+```
+
+## Modifying the Configuration
+
+The classification task (ernie_seqcls_model/config.pbtxt) is configured by default to run the OpenVINO engine on CPU, while the sequence labeling task is configured by default to run the Paddle engine on GPU. To run on CPU/GPU or with other inference engines, modify the configuration; see the [configuration document](../../../../../serving/docs/zh_CN/model_configuration.md) for details.
diff --git a/examples/text/ernie-3.0/serving/models/ernie_seqcls/1/README.md b/examples/text/ernie-3.0/serving/models/ernie_seqcls/1/README.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/text/ernie-3.0/serving/models/ernie_seqcls/config.pbtxt b/examples/text/ernie-3.0/serving/models/ernie_seqcls/config.pbtxt
new file mode 100644
index 000000000..bea846e76
--- /dev/null
+++ b/examples/text/ernie-3.0/serving/models/ernie_seqcls/config.pbtxt
@@ -0,0 +1,75 @@
+name: "ernie_seqcls"
+platform: "ensemble"
+max_batch_size: 64
+input [
+  {
+    name: "INPUT"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  }
+]
+output [
+  {
+    name: "label"
+    data_type: TYPE_INT64
+    dims: [ 1 ]
+  },
+  {
+    name: "confidence"
+    data_type: TYPE_FP32
+    dims: [ 1 ]
+  }
+]
+ensemble_scheduling {
+  step [
+    {
+      model_name: "ernie_tokenizer"
+      model_version: 1
+      input_map {
+        key: "INPUT_0"
+        value: "INPUT"
+      }
+      output_map {
+        key:
"OUTPUT_0" + value: "tokenizer_input_ids" + } + output_map { + key: "OUTPUT_1" + value: "tokenizer_token_type_ids" + } + }, + { + model_name: "ernie_seqcls_model" + model_version: 1 + input_map { + key: "input_ids" + value: "tokenizer_input_ids" + } + input_map { + key: "token_type_ids" + value: "tokenizer_token_type_ids" + } + output_map { + key: "linear_113.tmp_1" + value: "OUTPUT_2" + } + }, + { + model_name: "ernie_seqcls_postprocess" + model_version: 1 + input_map { + key: "POST_INPUT" + value: "OUTPUT_2" + } + output_map { + key: "POST_label" + value: "label" + } + output_map { + key: "POST_confidence" + value: "confidence" + } + } + ] +} + diff --git a/examples/text/ernie-3.0/serving/models/ernie_seqcls_model/1/README.md b/examples/text/ernie-3.0/serving/models/ernie_seqcls_model/1/README.md new file mode 100644 index 000000000..aaca8a9ec --- /dev/null +++ b/examples/text/ernie-3.0/serving/models/ernie_seqcls_model/1/README.md @@ -0,0 +1 @@ +本目录存放Ernie-3.0模型 diff --git a/examples/text/ernie-3.0/serving/models/ernie_seqcls_model/config.pbtxt b/examples/text/ernie-3.0/serving/models/ernie_seqcls_model/config.pbtxt new file mode 100755 index 000000000..529262651 --- /dev/null +++ b/examples/text/ernie-3.0/serving/models/ernie_seqcls_model/config.pbtxt @@ -0,0 +1,42 @@ +backend: "fastdeploy" +max_batch_size: 64 +input [ + { + name: "input_ids" + data_type: TYPE_INT64 + dims: [ -1 ] + }, + { + name: "token_type_ids" + data_type: TYPE_INT64 + dims: [ -1 ] + } +] +output [ + { + name: "linear_113.tmp_1" + data_type: TYPE_FP32 + dims: [ 15 ] + } +] + +instance_group [ + { + # 创建1个实例 + count: 1 + # 使用CPU推理(KIND_CPU、KIND_GPU) + kind: KIND_CPU + } +] + +optimization { + execution_accelerators { + cpu_execution_accelerator : [ + { + # use openvino backend + name: "openvino" + parameters { key: "cpu_threads" value: "5" } + } + ] + } +} diff --git a/examples/text/ernie-3.0/serving/models/ernie_seqcls_postprocess/1/model.py b/examples/text/ernie-3.0/serving/models/ernie_seqcls_postprocess/1/model.py new file mode 100644 index 000000000..2adce7682 --- /dev/null +++ b/examples/text/ernie-3.0/serving/models/ernie_seqcls_postprocess/1/model.py @@ -0,0 +1,108 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import paddle +import numpy as np +import time + +# triton_python_backend_utils is available in every Triton Python model. You +# need to use this module to create inference requests and responses. It also +# contains some utility functions for extracting information from model_config +# and converting Triton input/output types to numpy types. +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. 
This function allows + the model to intialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + self.model_config = model_config = json.loads(args['model_config']) + print("model_config:", self.model_config) + + self.input_names = [] + for input_config in self.model_config["input"]: + self.input_names.append(input_config["name"]) + print("input:", self.input_names) + + self.output_names = [] + self.output_dtype = [] + for output_config in self.model_config["output"]: + self.output_names.append(output_config["name"]) + dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + self.output_dtype.append(dtype) + print("output:", self.output_names) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + responses = [] + # print("num:", len(requests), flush=True) + for request in requests: + data = pb_utils.get_input_tensor_by_name(request, + self.input_names[0]) + data = data.as_numpy() + # print("post data:", data) + max_value = np.max(data, axis=1, keepdims=True) + exp_data = np.exp(data - max_value) + probs = exp_data / np.sum(exp_data, axis=1, keepdims=True) + probs = probs.max(axis=-1) + # print("label:", data.argmax(axis=-1)) + # print("probs:", probs) + out_tensor1 = pb_utils.Tensor( + self.output_names[0], data.argmax(axis=-1)) + out_tensor2 = pb_utils.Tensor(self.output_names[1], probs) + inference_response = pb_utils.InferenceResponse( + output_tensors=[out_tensor1, out_tensor2]) + responses.append(inference_response) + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. 
+ """ + print('Cleaning up...') diff --git a/examples/text/ernie-3.0/serving/models/ernie_seqcls_postprocess/config.pbtxt b/examples/text/ernie-3.0/serving/models/ernie_seqcls_postprocess/config.pbtxt new file mode 100644 index 000000000..09d71da7e --- /dev/null +++ b/examples/text/ernie-3.0/serving/models/ernie_seqcls_postprocess/config.pbtxt @@ -0,0 +1,31 @@ +name: "ernie_seqcls_postprocess" +backend: "python" +max_batch_size: 64 + +input [ + { + name: "POST_INPUT" + data_type: TYPE_FP32 + dims: [ 15 ] + } +] + +output [ + { + name: "POST_label" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "POST_confidence" + data_type: TYPE_FP32 + dims: [ 1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/examples/text/ernie-3.0/serving/models/ernie_tokencls/1/README.md b/examples/text/ernie-3.0/serving/models/ernie_tokencls/1/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/examples/text/ernie-3.0/serving/models/ernie_tokencls/config.pbtxt b/examples/text/ernie-3.0/serving/models/ernie_tokencls/config.pbtxt new file mode 100644 index 000000000..d571c0cf0 --- /dev/null +++ b/examples/text/ernie-3.0/serving/models/ernie_tokencls/config.pbtxt @@ -0,0 +1,66 @@ +name: "ernie_tokencls" +platform: "ensemble" +max_batch_size: 64 +input [ + { + name: "INPUT" + data_type: TYPE_STRING + dims: [ 1 ] + } +] +output [ + { + name: "OUTPUT" + data_type: TYPE_STRING + dims: [ 1 ] + } +] +ensemble_scheduling { + step [ + { + model_name: "ernie_tokenizer" + model_version: 1 + input_map { + key: "INPUT_0" + value: "INPUT" + } + output_map { + key: "OUTPUT_0" + value: "tokenizer_input_ids" + } + output_map { + key: "OUTPUT_1" + value: "tokenizer_token_type_ids" + } + }, + { + model_name: "ernie_tokencls_model" + model_version: 1 + input_map { + key: "input_ids" + value: "tokenizer_input_ids" + } + input_map { + key: "token_type_ids" + value: "tokenizer_token_type_ids" + } + output_map { + key: "linear_113.tmp_1" + value: "OUTPUT_2" + } + }, + { + model_name: "ernie_tokencls_postprocess" + model_version: 1 + input_map { + key: "POST_INPUT" + value: "OUTPUT_2" + } + output_map { + key: "POST_OUTPUT" + value: "OUTPUT" + } + } + ] +} + diff --git a/examples/text/ernie-3.0/serving/models/ernie_tokencls_model/1/README.md b/examples/text/ernie-3.0/serving/models/ernie_tokencls_model/1/README.md new file mode 100644 index 000000000..aaca8a9ec --- /dev/null +++ b/examples/text/ernie-3.0/serving/models/ernie_tokencls_model/1/README.md @@ -0,0 +1 @@ +本目录存放Ernie-3.0模型 diff --git a/examples/text/ernie-3.0/serving/models/ernie_tokencls_model/config.pbtxt b/examples/text/ernie-3.0/serving/models/ernie_tokencls_model/config.pbtxt new file mode 100755 index 000000000..373e0197e --- /dev/null +++ b/examples/text/ernie-3.0/serving/models/ernie_tokencls_model/config.pbtxt @@ -0,0 +1,40 @@ +backend: "fastdeploy" +max_batch_size: 64 +input [ + { + name: "input_ids" + data_type: TYPE_INT64 + dims: [ -1 ] + }, + { + name: "token_type_ids" + data_type: TYPE_INT64 + dims: [ -1 ] + } +] +output [ + { + name: "linear_113.tmp_1" + data_type: TYPE_FP32 + dims: [ -1, 7 ] + } +] + +instance_group [ + { + # 创建1个实例 + count: 1 + # 使用GPU推理(KIND_CPU、KIND_GPU) + kind: KIND_GPU + } +] + +optimization { + execution_accelerators { + gpu_execution_accelerator : [ + { + name: "paddle" + } + ] + } +} diff --git a/examples/text/ernie-3.0/serving/models/ernie_tokencls_postprocess/1/model.py b/examples/text/ernie-3.0/serving/models/ernie_tokencls_postprocess/1/model.py new file mode 100644 index 
000000000..518a81f43 --- /dev/null +++ b/examples/text/ernie-3.0/serving/models/ernie_tokencls_postprocess/1/model.py @@ -0,0 +1,128 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import paddle +import numpy as np +import time + +# triton_python_backend_utils is available in every Triton Python model. You +# need to use this module to create inference requests and responses. It also +# contains some utility functions for extracting information from model_config +# and converting Triton input/output types to numpy types. +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to intialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + self.model_config = model_config = json.loads(args['model_config']) + print("model_config:", self.model_config) + + self.input_names = [] + for input_config in self.model_config["input"]: + self.input_names.append(input_config["name"]) + print("input:", self.input_names) + + self.output_names = [] + self.output_dtype = [] + for output_config in self.model_config["output"]: + self.output_names.append(output_config["name"]) + dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + self.output_dtype.append(dtype) + print("output:", self.output_names) + # The label names of NER models trained by different data sets may be different + self.label_names = [ + 'O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC' + ] + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. 
The length of this list must + be the same as `requests` + """ + responses = [] + # print("num:", len(requests), flush=True) + for request in requests: + data = pb_utils.get_input_tensor_by_name(request, + self.input_names[0]) + data = data.as_numpy() + # print("post data:", data) + tokens_label = data.argmax(axis=-1).tolist() + value = [] + for _, token_label in enumerate(tokens_label): + start = -1 + label_name = "" + items = [] + for i, label in enumerate(token_label): + if self.label_names[label] == "O" and start >= 0: + items.append({ + "pos": [start, i - 2], + "label": label_name, + }) + start = -1 + elif "B-" in self.label_names[label]: + start = i - 1 + label_name = self.label_names[label][2:] + if start >= 0: + items.append({ + "pos": [start, len(token_label) - 1], + "label": label_name, + }) + value.append(items) + out_result = np.array(value, dtype='object') + out_tensor = pb_utils.Tensor(self.output_names[0], out_result) + inference_response = pb_utils.InferenceResponse(output_tensors=[ + out_tensor, + ]) + responses.append(inference_response) + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print('Cleaning up...') diff --git a/examples/text/ernie-3.0/serving/models/ernie_tokencls_postprocess/config.pbtxt b/examples/text/ernie-3.0/serving/models/ernie_tokencls_postprocess/config.pbtxt new file mode 100644 index 000000000..776015827 --- /dev/null +++ b/examples/text/ernie-3.0/serving/models/ernie_tokencls_postprocess/config.pbtxt @@ -0,0 +1,26 @@ +name: "ernie_tokencls_postprocess" +backend: "python" +max_batch_size: 64 + +input [ + { + name: "POST_INPUT" + data_type: TYPE_FP32 + dims: [ -1, 7 ] + } +] + +output [ + { + name: "POST_OUTPUT" + data_type: TYPE_STRING + dims: [ 1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/examples/text/ernie-3.0/serving/models/ernie_tokenizer/1/model.py b/examples/text/ernie-3.0/serving/models/ernie_tokenizer/1/model.py new file mode 100644 index 000000000..2a4f6d42a --- /dev/null +++ b/examples/text/ernie-3.0/serving/models/ernie_tokenizer/1/model.py @@ -0,0 +1,115 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import paddle +import numpy as np +import time + +from paddlenlp.transformers import AutoTokenizer + +# triton_python_backend_utils is available in every Triton Python model. You +# need to use this module to create inference requests and responses. It also +# contains some utility functions for extracting information from model_config +# and converting Triton input/output types to numpy types. +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. 
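+
+    In this example the tokenizer model decodes the incoming UTF-8 strings,
+    tokenizes them with the PaddleNLP "ernie-3.0-medium-zh" tokenizer
+    (max_length=128, padding and truncation enabled), and returns the
+    resulting input_ids and token_type_ids tensors to the ensemble.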
+ """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to intialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + self.tokenizer = AutoTokenizer.from_pretrained( + "ernie-3.0-medium-zh", use_faster=True) + # You must parse model_config. JSON string is not parsed here + self.model_config = json.loads(args['model_config']) + print("model_config:", self.model_config) + + self.input_names = [] + for input_config in self.model_config["input"]: + self.input_names.append(input_config["name"]) + print("input:", self.input_names) + + self.output_names = [] + self.output_dtype = [] + for output_config in self.model_config["output"]: + self.output_names.append(output_config["name"]) + dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + self.output_dtype.append(dtype) + print("output:", self.output_names) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + responses = [] + # print("num:", len(requests), flush=True) + for request in requests: + data = pb_utils.get_input_tensor_by_name(request, + self.input_names[0]) + data = data.as_numpy() + data = [i[0].decode('utf-8') for i in data] + data = self.tokenizer( + data, max_length=128, padding=True, truncation=True) + input_ids = np.array(data["input_ids"], dtype=self.output_dtype[0]) + token_type_ids = np.array( + data["token_type_ids"], dtype=self.output_dtype[1]) + + # print("input_ids:", input_ids) + # print("token_type_ids:", token_type_ids) + + out_tensor1 = pb_utils.Tensor(self.output_names[0], input_ids) + out_tensor2 = pb_utils.Tensor(self.output_names[1], token_type_ids) + inference_response = pb_utils.InferenceResponse( + output_tensors=[out_tensor1, out_tensor2]) + responses.append(inference_response) + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. 
+ """ + print('Cleaning up...') diff --git a/examples/text/ernie-3.0/serving/models/ernie_tokenizer/config.pbtxt b/examples/text/ernie-3.0/serving/models/ernie_tokenizer/config.pbtxt new file mode 100644 index 000000000..078f8f3d5 --- /dev/null +++ b/examples/text/ernie-3.0/serving/models/ernie_tokenizer/config.pbtxt @@ -0,0 +1,31 @@ +name: "ernie_tokenizer" +backend: "python" +max_batch_size: 64 + +input [ + { + name: "INPUT_0" + data_type: TYPE_STRING + dims: [ 1 ] + } +] + +output [ + { + name: "OUTPUT_0" + data_type: TYPE_INT64 + dims: [ -1 ] + }, + { + name: "OUTPUT_1" + data_type: TYPE_INT64 + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/examples/text/ernie-3.0/serving/seq_cls_grpc_client.py b/examples/text/ernie-3.0/serving/seq_cls_grpc_client.py new file mode 100755 index 000000000..1a32c467c --- /dev/null +++ b/examples/text/ernie-3.0/serving/seq_cls_grpc_client.py @@ -0,0 +1,149 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import numpy as np +import time +from typing import Optional + +from tritonclient import utils as client_utils +from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput, service_pb2_grpc, service_pb2 + +LOGGER = logging.getLogger("run_inference_on_triton") + + +class SyncGRPCTritonRunner: + DEFAULT_MAX_RESP_WAIT_S = 120 + + def __init__( + self, + server_url: str, + model_name: str, + model_version: str, + *, + verbose=False, + resp_wait_s: Optional[float]=None, ): + self._server_url = server_url + self._model_name = model_name + self._model_version = model_version + self._verbose = verbose + self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s + + self._client = InferenceServerClient( + self._server_url, verbose=self._verbose) + error = self._verify_triton_state(self._client) + if error: + raise RuntimeError( + f"Could not communicate to Triton Server: {error}") + + LOGGER.debug( + f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} " + f"are up and ready!") + + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata(self._model_name, + self._model_version) + LOGGER.info(f"Model config {model_config}") + LOGGER.info(f"Model metadata {model_metadata}") + + self._inputs = {tm.name: tm for tm in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = {tm.name: tm for tm in model_metadata.outputs} + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def Run(self, inputs): + """ + Args: + inputs: list, Each value corresponds to an input name of self._input_names + Returns: + results: dict, {name : numpy.array} + """ + infer_inputs = [] + for idx, data in enumerate(inputs): + data = np.array( + [[x.encode('utf-8')] for x in data], dtype=np.object_) + 
infer_input = InferInput(self._input_names[idx], [len(data), 1], + "BYTES") + infer_input.set_data_from_numpy(data) + infer_inputs.append(infer_input) + + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=infer_inputs, + outputs=self._outputs_req, + client_timeout=self._response_wait_t, ) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + def _verify_triton_state(self, triton_client): + if not triton_client.is_server_live(): + return f"Triton server {self._server_url} is not live" + elif not triton_client.is_server_ready(): + return f"Triton server {self._server_url} is not ready" + elif not triton_client.is_model_ready(self._model_name, + self._model_version): + return f"Model {self._model_name}:{self._model_version} is not ready" + return None + + +def test_tnews_dataset(runner): + from paddlenlp.datasets import load_dataset + dev_ds = load_dataset('clue', "tnews", splits='dev') + + batches = [] + labels = [] + idx = 0 + batch_size = 32 + while idx < len(dev_ds): + data = [] + label = [] + for i in range(batch_size): + if idx + i >= len(dev_ds): + break + data.append(dev_ds[idx + i]["sentence"]) + label.append(dev_ds[idx + i]["label"]) + batches.append(data) + labels.append(np.array(label)) + idx += batch_size + + accuracy = 0 + for i, data in enumerate(batches): + ret = runner.Run([data]) + # print("ret:", ret) + accuracy += np.sum(labels[i] == ret["label"]) + print("acc:", 1.0 * accuracy / len(dev_ds)) + + +if __name__ == "__main__": + from paddlenlp.datasets import load_dataset + dev_ds = load_dataset('clue', "tnews", splits='dev') + model_name = "ernie_seqcls" + model_version = "1" + url = "localhost:8001" + runner = SyncGRPCTritonRunner(url, model_name, model_version) + texts = [["你家拆迁,要钱还是要房?答案一目了然", "军嫂探亲拧包入住,部队家属临时来队房标准有了规定,全面落实!"], [ + "区块链投资心得,能做到就不会亏钱", + ]] + + for text in texts: + # input format:[input1, input2 ... inputn], n = len(self._input_names) + result = runner.Run([text]) + print(result) + + test_tnews_dataset(runner) diff --git a/examples/text/ernie-3.0/serving/token_cls_grpc_client.py b/examples/text/ernie-3.0/serving/token_cls_grpc_client.py new file mode 100755 index 000000000..2fd0b6d5c --- /dev/null +++ b/examples/text/ernie-3.0/serving/token_cls_grpc_client.py @@ -0,0 +1,126 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
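+#
+# gRPC client for the ernie_tokencls pipeline: it sends batches of raw text to
+# the serving instance and prints the entities parsed from the serialized
+# output of the postprocess model.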
+ +import ast +import logging +import numpy as np +import time +from typing import Optional + +from tritonclient import utils as client_utils +from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput, service_pb2_grpc, service_pb2 + +LOGGER = logging.getLogger("run_inference_on_triton") + + +class SyncGRPCTritonRunner: + DEFAULT_MAX_RESP_WAIT_S = 120 + + def __init__( + self, + server_url: str, + model_name: str, + model_version: str, + *, + verbose=False, + resp_wait_s: Optional[float]=None, ): + self._server_url = server_url + self._model_name = model_name + self._model_version = model_version + self._verbose = verbose + self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s + + self._client = InferenceServerClient( + self._server_url, verbose=self._verbose) + error = self._verify_triton_state(self._client) + if error: + raise RuntimeError( + f"Could not communicate to Triton Server: {error}") + + LOGGER.debug( + f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} " + f"are up and ready!") + + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata(self._model_name, + self._model_version) + LOGGER.info(f"Model config {model_config}") + LOGGER.info(f"Model metadata {model_metadata}") + + self._inputs = {tm.name: tm for tm in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = {tm.name: tm for tm in model_metadata.outputs} + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def Run(self, inputs): + """ + Args: + inputs: list, Each value corresponds to an input name of self._input_names + Returns: + results: dict, {name : numpy.array} + """ + infer_inputs = [] + for idx, data in enumerate(inputs): + data = np.array( + [[x.encode('utf-8')] for x in data], dtype=np.object_) + infer_input = InferInput(self._input_names[idx], [len(data), 1], + "BYTES") + infer_input.set_data_from_numpy(data) + infer_inputs.append(infer_input) + + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=infer_inputs, + outputs=self._outputs_req, + client_timeout=self._response_wait_t, ) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + def _verify_triton_state(self, triton_client): + if not triton_client.is_server_live(): + return f"Triton server {self._server_url} is not live" + elif not triton_client.is_server_ready(): + return f"Triton server {self._server_url} is not ready" + elif not triton_client.is_model_ready(self._model_name, + self._model_version): + return f"Model {self._model_name}:{self._model_version} is not ready" + return None + + +if __name__ == "__main__": + model_name = "ernie_tokencls" + model_version = "1" + url = "localhost:8001" + runner = SyncGRPCTritonRunner(url, model_name, model_version) + dataset = [[ + "北京的涮肉,重庆的火锅,成都的小吃都是极具特色的美食。", + "原产玛雅故国的玉米,早已成为华夏大地主要粮食作物之一。", + ], ] + + for batch_input in dataset: + # input format:[input1, input2 ... 
inputn], n = len(self._input_names) + result = runner.Run([batch_input]) + for i, ret in enumerate(result['OUTPUT']): + ret = ast.literal_eval(ret.decode('utf-8')) + print("input data:", batch_input[i]) + print("The model detects all entities:") + for iterm in ret: + print("entity:", + batch_input[i][iterm["pos"][0]:iterm["pos"][1] + 1], + " label:", iterm["label"], " pos:", iterm["pos"]) diff --git a/serving/src/fastdeploy_runtime.cc b/serving/src/fastdeploy_runtime.cc index 72d7da119..4d9a6d17b 100644 --- a/serving/src/fastdeploy_runtime.cc +++ b/serving/src/fastdeploy_runtime.cc @@ -236,6 +236,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) THROW_IF_BACKEND_MODEL_ERROR( ParseBoolValue(value_string, &pd_enable_mkldnn)); runtime_options_->SetPaddleMKLDNN(pd_enable_mkldnn); + } else if (param_key == "use_paddle_log") { + runtime_options_->EnablePaddleLogInfo(); } } } @@ -305,6 +307,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) } else if (value_string == "trt_int8") { // TODO(liqi): use EnableTrtINT8 runtime_options_->trt_enable_int8 = true; + } else if (value_string == "pd_fp16") { + // TODO(liqi): paddle inference don't currently have interface for fp16. } // } else if( param_key == "max_batch_size") { // THROW_IF_BACKEND_MODEL_ERROR(ParseUnsignedLongLongValue( @@ -317,6 +321,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) runtime_options_->SetTrtCacheFile(value_string); } else if (param_key == "use_paddle") { runtime_options_->EnablePaddleToTrt(); + } else if (param_key == "use_paddle_log") { + runtime_options_->EnablePaddleLogInfo(); } } }
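Note: the `use_paddle_log` switch handled above is read from a model's `config.pbtxt`. A minimal sketch of how it could be turned on, assuming it is placed alongside the other execution accelerator parameters (as with `cpu_threads` above) and that the value string is illustrative:
```
optimization {
  execution_accelerators {
    gpu_execution_accelerator : [
      {
        name: "paddle"
        # Hypothetical example: enable Paddle Inference log output
        parameters { key: "use_paddle_log" value: "true" }
      }
    ]
  }
}
```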