examples/text/ernie-3.0/serving/README.md

# ERNIE 3.0 Serving Deployment Example

## Prepare the models

Download the ERNIE 3.0 news classification model and sequence labeling model (skip this step if you already have trained models):

```bash
# Download and extract the news classification model
wget https://paddlenlp.bj.bcebos.com/models/transformers/ernie_3.0/tnews_pruned_infer_model.zip
unzip tnews_pruned_infer_model.zip

# Move the downloaded model into the model repository directory of the classification task
mv tnews_pruned_infer_model/float32.pdmodel models/ernie_seqcls_model/1/model.pdmodel
mv tnews_pruned_infer_model/float32.pdiparams models/ernie_seqcls_model/1/model.pdiparams

# Download and extract the sequence labeling model
wget https://paddlenlp.bj.bcebos.com/models/transformers/ernie_3.0/msra_ner_pruned_infer_model.zip
unzip msra_ner_pruned_infer_model.zip

# Move the downloaded model into the model repository directory of the sequence labeling task
mv msra_ner_pruned_infer_model/float32.pdmodel models/ernie_tokencls_model/1/model.pdmodel
mv msra_ner_pruned_infer_model/float32.pdiparams models/ernie_tokencls_model/1/model.pdiparams
```

After the models are downloaded and moved into place, the models directory of the classification task looks like this:

```
models
├── ernie_seqcls                # pipeline of the classification task
│   ├── 1
│   └── config.pbtxt            # this file chains pre/post processing and model inference
├── ernie_seqcls_model          # model inference of the classification task
│   ├── 1
│   │   ├── model.pdmodel
│   │   └── model.pdiparams
│   └── config.pbtxt
├── ernie_seqcls_postprocess    # postprocessing of the classification task
│   ├── 1
│   │   └── model.py
│   └── config.pbtxt
└── ernie_tokenizer             # tokenization preprocessing
    ├── 1
    │   └── model.py
    └── config.pbtxt
```

## Pull and run the image

```bash
# CPU image: only supports serving Paddle/ONNX models on CPU; supported inference backends are OpenVINO, Paddle Inference and ONNX Runtime
docker pull paddlepaddle/fastdeploy:0.3.0-cpu-only-21.10

# GPU image: supports serving Paddle/ONNX models on GPU/CPU; supported inference backends are OpenVINO, TensorRT, Paddle Inference and ONNX Runtime
docker pull paddlepaddle/fastdeploy:0.3.0-gpu-cuda11.4-trt8.4-21.10

# Run the container
docker run -it --net=host --name fastdeploy_server --shm-size="1g" -v /path/serving/models:/models paddlepaddle/fastdeploy:0.3.0-cpu-only-21.10 bash
```

## Deploy the models

The serving directory contains the configuration for starting the pipeline service and the code for sending prediction requests:

```
models                       # model repository required by the service, containing the models and serving config files
seq_cls_grpc_client.py       # script that sends pipeline prediction requests for the news classification task
token_cls_grpc_client.py     # script that sends pipeline prediction requests for the sequence labeling task
```

*Note*: when the service starts, each Python backend process of the server requests `64M` of shared memory by default, so a container started with the default settings cannot run more than one Python backend node. There are two solutions:
- 1. Set the `shm-size` option when starting the container, e.g. `docker run -it --net=host --name fastdeploy_server --shm-size="1g" -v /path/serving/models:/models paddlepaddle/fastdeploy:0.3.0-gpu-cuda11.4-trt8.4-21.10 bash`
- 2. Set the Python backend's `shm-default-byte-size` option when starting the service, e.g. set the Python backend's default shared memory to 10MB: `tritonserver --model-repository=/models --backend-config=python,shm-default-byte-size=10485760`

### Classification task

Run the following command inside the container to start the service:

```
# By default all models under models are started
fastdeployserver --model-repository=/models

# You can also start only the classification task with these options
fastdeployserver --model-repository=/models --model-control-mode=explicit --load-model=ernie_seqcls
```

The output looks like this:

```
I1019 09:41:15.375496 2823 model_repository_manager.cc:1183] successfully loaded 'ernie_tokenizer' version 1
I1019 09:41:15.375987 2823 model_repository_manager.cc:1022] loading: ernie_seqcls:1
I1019 09:41:15.477147 2823 model_repository_manager.cc:1183] successfully loaded 'ernie_seqcls' version 1
I1019 09:41:15.477325 2823 server.cc:522]
...
I0613 08:59:20.577820 10021 server.cc:592]
+----------------------------+---------+--------+
| Model                      | Version | Status |
+----------------------------+---------+--------+
| ernie_seqcls               | 1       | READY  |
| ernie_seqcls_model         | 1       | READY  |
| ernie_seqcls_postprocess   | 1       | READY  |
| ernie_tokenizer            | 1       | READY  |
+----------------------------+---------+--------+
...
I0601 07:15:15.923270 8059 grpc_server.cc:4117] Started GRPCInferenceService at 0.0.0.0:8001
I0601 07:15:15.923604 8059 http_server.cc:2815] Started HTTPService at 0.0.0.0:8000
I0601 07:15:15.964984 8059 http_server.cc:167] Started Metrics Service at 0.0.0.0:8002
```
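
Once the log shows the gRPC endpoint listening on 8001, you can optionally check readiness before sending requests. The snippet below is only a minimal sketch, not part of this example's scripts; it assumes the default port printed above and the `tritonclient` dependency installed as described in the client requests section below.

```python
# Minimal readiness probe (a sketch; assumes the default gRPC port 8001 from the log above).
from tritonclient.grpc import InferenceServerClient

client = InferenceServerClient("localhost:8001")
print("server live:", client.is_server_live())
print("ernie_seqcls ready:", client.is_model_ready("ernie_seqcls", "1"))
```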

### Sequence labeling task

Run the following command inside the container to start the sequence labeling service:

```
fastdeployserver --model-repository=/models --model-control-mode=explicit --load-model=ernie_tokencls --backend-config=python,shm-default-byte-size=10485760
```

The output looks like this:

```
I1019 09:41:15.375496 2823 model_repository_manager.cc:1183] successfully loaded 'ernie_tokenizer' version 1
I1019 09:41:15.375987 2823 model_repository_manager.cc:1022] loading: ernie_seqcls:1
I1019 09:41:15.477147 2823 model_repository_manager.cc:1183] successfully loaded 'ernie_seqcls' version 1
I1019 09:41:15.477325 2823 server.cc:522]
...
I0613 08:59:20.577820 10021 server.cc:592]
+----------------------------+---------+--------+
| Model                      | Version | Status |
+----------------------------+---------+--------+
| ernie_tokencls             | 1       | READY  |
| ernie_tokencls_model       | 1       | READY  |
| ernie_tokencls_postprocess | 1       | READY  |
| ernie_tokenizer            | 1       | READY  |
+----------------------------+---------+--------+
...
I0601 07:15:15.923270 8059 grpc_server.cc:4117] Started GRPCInferenceService at 0.0.0.0:8001
I0601 07:15:15.923604 8059 http_server.cc:2815] Started HTTPService at 0.0.0.0:8000
I0601 07:15:15.964984 8059 http_server.cc:167] Started Metrics Service at 0.0.0.0:8002
```

## Client requests

Client requests can be sent by running the scripts locally, or from inside the container.

To run the scripts locally, first install the dependencies:

```
pip install grpcio
pip install tritonclient[all]

# If your shell cannot parse the brackets, install with:
pip install tritonclient\[all\]
```
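
The scripts in this directory use the gRPC endpoint on port 8001. If you prefer the HTTP endpoint on port 8000, the sketch below shows the same request path; it is only an illustration (not one of the shipped clients) and assumes the `ernie_seqcls` ensemble input/output names (`INPUT`, `label`, `confidence`) declared in `models/ernie_seqcls/config.pbtxt` and the default HTTP port printed in the server log.

```python
# Hedged HTTP example (a sketch): send one sentence to the ernie_seqcls ensemble over the
# HTTP service on port 8000 and read back label/confidence.
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

texts = ["你家拆迁,要钱还是要房?答案一目了然"]
data = np.array([[t.encode("utf-8")] for t in texts], dtype=np.object_)

infer_input = httpclient.InferInput("INPUT", [len(texts), 1], "BYTES")
infer_input.set_data_from_numpy(data)

result = client.infer(
    model_name="ernie_seqcls",
    inputs=[infer_input],
    outputs=[
        httpclient.InferRequestedOutput("label"),
        httpclient.InferRequestedOutput("confidence"),
    ])
print(result.as_numpy("label"), result.as_numpy("confidence"))
```

The gRPC client scripts below do the same thing in batched form.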

### Classification task

Note: disable any proxies before running the client, and change the IP address in the `main` block to the machine where the service is running.

```
python seq_cls_grpc_client.py
```

The output looks like this:

```
{'label': array([5, 9]), 'confidence': array([0.6425664 , 0.66534853], dtype=float32)}
{'label': array([4]), 'confidence': array([0.53198355], dtype=float32)}
acc: 0.5731
```

### Sequence labeling task

Note: disable any proxies before running the client, and change the IP address in the `main` block to the machine where the service is running.

```
python token_cls_grpc_client.py
```

The output looks like this:

```
input data: 北京的涮肉,重庆的火锅,成都的小吃都是极具特色的美食。
The model detects all entities:
entity: 北京   label: LOC   pos: [0, 1]
entity: 重庆   label: LOC   pos: [6, 7]
entity: 成都   label: LOC   pos: [12, 13]
input data: 原产玛雅故国的玉米,早已成为华夏大地主要粮食作物之一。
The model detects all entities:
entity: 玛雅   label: LOC   pos: [2, 3]
entity: 华夏   label: LOC   pos: [14, 15]
```

## Modifying the configuration

The classification task (ernie_seqcls_model/config.pbtxt) is configured by default to run the OpenVINO engine on CPU, and the sequence labeling task is configured by default to run the Paddle engine on GPU. To run on CPU/GPU or with another inference engine, edit the corresponding config.pbtxt (for example the `kind` field of `instance_group` and the accelerator `name` under `optimization.execution_accelerators`); see the [configuration document](../../../../../serving/docs/zh_CN/model_configuration.md) for details.

examples/text/ernie-3.0/serving/models/ernie_seqcls/config.pbtxt

name: "ernie_seqcls"
platform: "ensemble"
max_batch_size: 64
input [
  {
    name: "INPUT"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]
output [
  {
    name: "label"
    data_type: TYPE_INT64
    dims: [ 1 ]
  },
  {
    name: "confidence"
    data_type: TYPE_FP32
    dims: [ 1 ]
  }
]
ensemble_scheduling {
  step [
    {
      model_name: "ernie_tokenizer"
      model_version: 1
      input_map {
        key: "INPUT_0"
        value: "INPUT"
      }
      output_map {
        key: "OUTPUT_0"
        value: "tokenizer_input_ids"
      }
      output_map {
        key: "OUTPUT_1"
        value: "tokenizer_token_type_ids"
      }
    },
    {
      model_name: "ernie_seqcls_model"
      model_version: 1
      input_map {
        key: "input_ids"
        value: "tokenizer_input_ids"
      }
      input_map {
        key: "token_type_ids"
        value: "tokenizer_token_type_ids"
      }
      output_map {
        key: "linear_113.tmp_1"
        value: "OUTPUT_2"
      }
    },
    {
      model_name: "ernie_seqcls_postprocess"
      model_version: 1
      input_map {
        key: "POST_INPUT"
        value: "OUTPUT_2"
      }
      output_map {
        key: "POST_label"
        value: "label"
      }
      output_map {
        key: "POST_confidence"
        value: "confidence"
      }
    }
  ]
}

This directory stores the ERNIE 3.0 model.

examples/text/ernie-3.0/serving/models/ernie_seqcls_model/config.pbtxt

backend: "fastdeploy"
max_batch_size: 64
input [
  {
    name: "input_ids"
    data_type: TYPE_INT64
    dims: [ -1 ]
  },
  {
    name: "token_type_ids"
    data_type: TYPE_INT64
    dims: [ -1 ]
  }
]
output [
  {
    name: "linear_113.tmp_1"
    data_type: TYPE_FP32
    dims: [ 15 ]
  }
]

instance_group [
  {
    # create 1 instance
    count: 1
    # run inference on CPU (KIND_CPU, KIND_GPU)
    kind: KIND_CPU
  }
]

optimization {
  execution_accelerators {
    cpu_execution_accelerator : [
      {
        # use openvino backend
        name: "openvino"
        parameters { key: "cpu_threads" value: "5" }
      }
    ]
  }
}

examples/text/ernie-3.0/serving/models/ernie_seqcls_postprocess/1/model.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import paddle
import numpy as np
import time

# triton_python_backend_utils is available in every Triton Python model. You
# need to use this module to create inference requests and responses. It also
# contains some utility functions for extracting information from model_config
# and converting Triton input/output types to numpy types.
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        self.model_config = model_config = json.loads(args['model_config'])
        print("model_config:", self.model_config)

        self.input_names = []
        for input_config in self.model_config["input"]:
            self.input_names.append(input_config["name"])
        print("input:", self.input_names)

        self.output_names = []
        self.output_dtype = []
        for output_config in self.model_config["output"]:
            self.output_names.append(output_config["name"])
            dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
            self.output_dtype.append(dtype)
        print("output:", self.output_names)

    def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference is requested
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse.

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        responses = []
        # print("num:", len(requests), flush=True)
        for request in requests:
            data = pb_utils.get_input_tensor_by_name(request,
                                                     self.input_names[0])
            data = data.as_numpy()
            # print("post data:", data)
            max_value = np.max(data, axis=1, keepdims=True)
            exp_data = np.exp(data - max_value)
            probs = exp_data / np.sum(exp_data, axis=1, keepdims=True)
            probs = probs.max(axis=-1)
            # print("label:", data.argmax(axis=-1))
            # print("probs:", probs)
            out_tensor1 = pb_utils.Tensor(
                self.output_names[0], data.argmax(axis=-1))
            out_tensor2 = pb_utils.Tensor(self.output_names[1], probs)
            inference_response = pb_utils.InferenceResponse(
                output_tensors=[out_tensor1, out_tensor2])
            responses.append(inference_response)
        return responses

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is optional. This function allows
        the model to perform any necessary clean ups before exit.
        """
        print('Cleaning up...')

examples/text/ernie-3.0/serving/models/ernie_seqcls_postprocess/config.pbtxt

name: "ernie_seqcls_postprocess"
backend: "python"
max_batch_size: 64

input [
  {
    name: "POST_INPUT"
    data_type: TYPE_FP32
    dims: [ 15 ]
  }
]

output [
  {
    name: "POST_label"
    data_type: TYPE_INT64
    dims: [ 1 ]
  },
  {
    name: "POST_confidence"
    data_type: TYPE_FP32
    dims: [ 1 ]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]

examples/text/ernie-3.0/serving/models/ernie_tokencls/config.pbtxt

name: "ernie_tokencls"
platform: "ensemble"
max_batch_size: 64
input [
  {
    name: "INPUT"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]
output [
  {
    name: "OUTPUT"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]
ensemble_scheduling {
  step [
    {
      model_name: "ernie_tokenizer"
      model_version: 1
      input_map {
        key: "INPUT_0"
        value: "INPUT"
      }
      output_map {
        key: "OUTPUT_0"
        value: "tokenizer_input_ids"
      }
      output_map {
        key: "OUTPUT_1"
        value: "tokenizer_token_type_ids"
      }
    },
    {
      model_name: "ernie_tokencls_model"
      model_version: 1
      input_map {
        key: "input_ids"
        value: "tokenizer_input_ids"
      }
      input_map {
        key: "token_type_ids"
        value: "tokenizer_token_type_ids"
      }
      output_map {
        key: "linear_113.tmp_1"
        value: "OUTPUT_2"
      }
    },
    {
      model_name: "ernie_tokencls_postprocess"
      model_version: 1
      input_map {
        key: "POST_INPUT"
        value: "OUTPUT_2"
      }
      output_map {
        key: "POST_OUTPUT"
        value: "OUTPUT"
      }
    }
  ]
}

This directory stores the ERNIE 3.0 model.

examples/text/ernie-3.0/serving/models/ernie_tokencls_model/config.pbtxt

backend: "fastdeploy"
max_batch_size: 64
input [
  {
    name: "input_ids"
    data_type: TYPE_INT64
    dims: [ -1 ]
  },
  {
    name: "token_type_ids"
    data_type: TYPE_INT64
    dims: [ -1 ]
  }
]
output [
  {
    name: "linear_113.tmp_1"
    data_type: TYPE_FP32
    dims: [ -1, 7 ]
  }
]

instance_group [
  {
    # create 1 instance
    count: 1
    # run inference on GPU (KIND_CPU, KIND_GPU)
    kind: KIND_GPU
  }
]

optimization {
  execution_accelerators {
    gpu_execution_accelerator : [
      {
        name: "paddle"
      }
    ]
  }
}

examples/text/ernie-3.0/serving/models/ernie_tokencls_postprocess/1/model.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import paddle
import numpy as np
import time

# triton_python_backend_utils is available in every Triton Python model. You
# need to use this module to create inference requests and responses. It also
# contains some utility functions for extracting information from model_config
# and converting Triton input/output types to numpy types.
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        self.model_config = model_config = json.loads(args['model_config'])
        print("model_config:", self.model_config)

        self.input_names = []
        for input_config in self.model_config["input"]:
            self.input_names.append(input_config["name"])
        print("input:", self.input_names)

        self.output_names = []
        self.output_dtype = []
        for output_config in self.model_config["output"]:
            self.output_names.append(output_config["name"])
            dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
            self.output_dtype.append(dtype)
        print("output:", self.output_names)
        # The label names of NER models trained by different data sets may be different
        self.label_names = [
            'O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'
        ]

    def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference is requested
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse.

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        responses = []
        # print("num:", len(requests), flush=True)
        for request in requests:
            data = pb_utils.get_input_tensor_by_name(request,
                                                     self.input_names[0])
            data = data.as_numpy()
            # print("post data:", data)
            tokens_label = data.argmax(axis=-1).tolist()
            value = []
            for _, token_label in enumerate(tokens_label):
                start = -1
                label_name = ""
                items = []
                for i, label in enumerate(token_label):
                    if self.label_names[label] == "O" and start >= 0:
                        items.append({
                            "pos": [start, i - 2],
                            "label": label_name,
                        })
                        start = -1
                    elif "B-" in self.label_names[label]:
                        start = i - 1
                        label_name = self.label_names[label][2:]
                if start >= 0:
                    items.append({
                        "pos": [start, len(token_label) - 1],
                        "label": label_name,
                    })
                value.append(items)
            out_result = np.array(value, dtype='object')
            out_tensor = pb_utils.Tensor(self.output_names[0], out_result)
            inference_response = pb_utils.InferenceResponse(output_tensors=[
                out_tensor,
            ])
            responses.append(inference_response)
        return responses

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is optional. This function allows
        the model to perform any necessary clean ups before exit.
        """
        print('Cleaning up...')

examples/text/ernie-3.0/serving/models/ernie_tokencls_postprocess/config.pbtxt

name: "ernie_tokencls_postprocess"
backend: "python"
max_batch_size: 64

input [
  {
    name: "POST_INPUT"
    data_type: TYPE_FP32
    dims: [ -1, 7 ]
  }
]

output [
  {
    name: "POST_OUTPUT"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]

examples/text/ernie-3.0/serving/models/ernie_tokenizer/1/model.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import paddle
import numpy as np
import time

from paddlenlp.transformers import AutoTokenizer

# triton_python_backend_utils is available in every Triton Python model. You
# need to use this module to create inference requests and responses. It also
# contains some utility functions for extracting information from model_config
# and converting Triton input/output types to numpy types.
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            "ernie-3.0-medium-zh", use_faster=True)
        # You must parse model_config. JSON string is not parsed here
        self.model_config = json.loads(args['model_config'])
        print("model_config:", self.model_config)

        self.input_names = []
        for input_config in self.model_config["input"]:
            self.input_names.append(input_config["name"])
        print("input:", self.input_names)

        self.output_names = []
        self.output_dtype = []
        for output_config in self.model_config["output"]:
            self.output_names.append(output_config["name"])
            dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
            self.output_dtype.append(dtype)
        print("output:", self.output_names)

    def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference is requested
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse.

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        responses = []
        # print("num:", len(requests), flush=True)
        for request in requests:
            data = pb_utils.get_input_tensor_by_name(request,
                                                     self.input_names[0])
            data = data.as_numpy()
            data = [i[0].decode('utf-8') for i in data]
            data = self.tokenizer(
                data, max_length=128, padding=True, truncation=True)
            input_ids = np.array(data["input_ids"], dtype=self.output_dtype[0])
            token_type_ids = np.array(
                data["token_type_ids"], dtype=self.output_dtype[1])

            # print("input_ids:", input_ids)
            # print("token_type_ids:", token_type_ids)

            out_tensor1 = pb_utils.Tensor(self.output_names[0], input_ids)
            out_tensor2 = pb_utils.Tensor(self.output_names[1], token_type_ids)
            inference_response = pb_utils.InferenceResponse(
                output_tensors=[out_tensor1, out_tensor2])
            responses.append(inference_response)
        return responses

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is optional. This function allows
        the model to perform any necessary clean ups before exit.
        """
        print('Cleaning up...')

examples/text/ernie-3.0/serving/models/ernie_tokenizer/config.pbtxt

name: "ernie_tokenizer"
backend: "python"
max_batch_size: 64

input [
  {
    name: "INPUT_0"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]

output [
  {
    name: "OUTPUT_0"
    data_type: TYPE_INT64
    dims: [ -1 ]
  },
  {
    name: "OUTPUT_1"
    data_type: TYPE_INT64
    dims: [ -1 ]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]

examples/text/ernie-3.0/serving/seq_cls_grpc_client.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import numpy as np
import time
from typing import Optional

from tritonclient import utils as client_utils
from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput, service_pb2_grpc, service_pb2

LOGGER = logging.getLogger("run_inference_on_triton")


class SyncGRPCTritonRunner:
    DEFAULT_MAX_RESP_WAIT_S = 120

    def __init__(
            self,
            server_url: str,
            model_name: str,
            model_version: str,
            *,
            verbose=False,
            resp_wait_s: Optional[float]=None, ):
        self._server_url = server_url
        self._model_name = model_name
        self._model_version = model_version
        self._verbose = verbose
        self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s

        self._client = InferenceServerClient(
            self._server_url, verbose=self._verbose)
        error = self._verify_triton_state(self._client)
        if error:
            raise RuntimeError(
                f"Could not communicate to Triton Server: {error}")

        LOGGER.debug(
            f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
            f"are up and ready!")

        model_config = self._client.get_model_config(self._model_name,
                                                     self._model_version)
        model_metadata = self._client.get_model_metadata(self._model_name,
                                                         self._model_version)
        LOGGER.info(f"Model config {model_config}")
        LOGGER.info(f"Model metadata {model_metadata}")

        self._inputs = {tm.name: tm for tm in model_metadata.inputs}
        self._input_names = list(self._inputs)
        self._outputs = {tm.name: tm for tm in model_metadata.outputs}
        self._output_names = list(self._outputs)
        self._outputs_req = [
            InferRequestedOutput(name) for name in self._outputs
        ]

    def Run(self, inputs):
        """
        Args:
            inputs: list, Each value corresponds to an input name of self._input_names
        Returns:
            results: dict, {name : numpy.array}
        """
        infer_inputs = []
        for idx, data in enumerate(inputs):
            data = np.array(
                [[x.encode('utf-8')] for x in data], dtype=np.object_)
            infer_input = InferInput(self._input_names[idx], [len(data), 1],
                                     "BYTES")
            infer_input.set_data_from_numpy(data)
            infer_inputs.append(infer_input)

        results = self._client.infer(
            model_name=self._model_name,
            model_version=self._model_version,
            inputs=infer_inputs,
            outputs=self._outputs_req,
            client_timeout=self._response_wait_t, )
        results = {name: results.as_numpy(name) for name in self._output_names}
        return results

    def _verify_triton_state(self, triton_client):
        if not triton_client.is_server_live():
            return f"Triton server {self._server_url} is not live"
        elif not triton_client.is_server_ready():
            return f"Triton server {self._server_url} is not ready"
        elif not triton_client.is_model_ready(self._model_name,
                                              self._model_version):
            return f"Model {self._model_name}:{self._model_version} is not ready"
        return None


def test_tnews_dataset(runner):
    from paddlenlp.datasets import load_dataset
    dev_ds = load_dataset('clue', "tnews", splits='dev')

    batches = []
    labels = []
    idx = 0
    batch_size = 32
    while idx < len(dev_ds):
        data = []
        label = []
        for i in range(batch_size):
            if idx + i >= len(dev_ds):
                break
            data.append(dev_ds[idx + i]["sentence"])
            label.append(dev_ds[idx + i]["label"])
        batches.append(data)
        labels.append(np.array(label))
        idx += batch_size

    accuracy = 0
    for i, data in enumerate(batches):
        ret = runner.Run([data])
        # print("ret:", ret)
        accuracy += np.sum(labels[i] == ret["label"])
    print("acc:", 1.0 * accuracy / len(dev_ds))


if __name__ == "__main__":
    from paddlenlp.datasets import load_dataset
    dev_ds = load_dataset('clue', "tnews", splits='dev')
    model_name = "ernie_seqcls"
    model_version = "1"
    url = "localhost:8001"
    runner = SyncGRPCTritonRunner(url, model_name, model_version)
    texts = [["你家拆迁,要钱还是要房?答案一目了然", "军嫂探亲拧包入住,部队家属临时来队房标准有了规定,全面落实!"], [
        "区块链投资心得,能做到就不会亏钱",
    ]]

    for text in texts:
        # input format:[input1, input2 ... inputn], n = len(self._input_names)
        result = runner.Run([text])
        print(result)

    test_tnews_dataset(runner)

examples/text/ernie-3.0/serving/token_cls_grpc_client.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import logging
import numpy as np
import time
from typing import Optional

from tritonclient import utils as client_utils
from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput, service_pb2_grpc, service_pb2

LOGGER = logging.getLogger("run_inference_on_triton")


class SyncGRPCTritonRunner:
    DEFAULT_MAX_RESP_WAIT_S = 120

    def __init__(
            self,
            server_url: str,
            model_name: str,
            model_version: str,
            *,
            verbose=False,
            resp_wait_s: Optional[float]=None, ):
        self._server_url = server_url
        self._model_name = model_name
        self._model_version = model_version
        self._verbose = verbose
        self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s

        self._client = InferenceServerClient(
            self._server_url, verbose=self._verbose)
        error = self._verify_triton_state(self._client)
        if error:
            raise RuntimeError(
                f"Could not communicate to Triton Server: {error}")

        LOGGER.debug(
            f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
            f"are up and ready!")

        model_config = self._client.get_model_config(self._model_name,
                                                     self._model_version)
        model_metadata = self._client.get_model_metadata(self._model_name,
                                                         self._model_version)
        LOGGER.info(f"Model config {model_config}")
        LOGGER.info(f"Model metadata {model_metadata}")

        self._inputs = {tm.name: tm for tm in model_metadata.inputs}
        self._input_names = list(self._inputs)
        self._outputs = {tm.name: tm for tm in model_metadata.outputs}
        self._output_names = list(self._outputs)
        self._outputs_req = [
            InferRequestedOutput(name) for name in self._outputs
        ]

    def Run(self, inputs):
        """
        Args:
            inputs: list, Each value corresponds to an input name of self._input_names
        Returns:
            results: dict, {name : numpy.array}
        """
        infer_inputs = []
        for idx, data in enumerate(inputs):
            data = np.array(
                [[x.encode('utf-8')] for x in data], dtype=np.object_)
            infer_input = InferInput(self._input_names[idx], [len(data), 1],
                                     "BYTES")
            infer_input.set_data_from_numpy(data)
            infer_inputs.append(infer_input)

        results = self._client.infer(
            model_name=self._model_name,
            model_version=self._model_version,
            inputs=infer_inputs,
            outputs=self._outputs_req,
            client_timeout=self._response_wait_t, )
        results = {name: results.as_numpy(name) for name in self._output_names}
        return results

    def _verify_triton_state(self, triton_client):
        if not triton_client.is_server_live():
            return f"Triton server {self._server_url} is not live"
        elif not triton_client.is_server_ready():
            return f"Triton server {self._server_url} is not ready"
        elif not triton_client.is_model_ready(self._model_name,
                                              self._model_version):
            return f"Model {self._model_name}:{self._model_version} is not ready"
        return None


if __name__ == "__main__":
    model_name = "ernie_tokencls"
    model_version = "1"
    url = "localhost:8001"
    runner = SyncGRPCTritonRunner(url, model_name, model_version)
    dataset = [[
        "北京的涮肉,重庆的火锅,成都的小吃都是极具特色的美食。",
        "原产玛雅故国的玉米,早已成为华夏大地主要粮食作物之一。",
    ], ]

    for batch_input in dataset:
        # input format:[input1, input2 ... inputn], n = len(self._input_names)
        result = runner.Run([batch_input])
        for i, ret in enumerate(result['OUTPUT']):
            ret = ast.literal_eval(ret.decode('utf-8'))
            print("input data:", batch_input[i])
            print("The model detects all entities:")
            for iterm in ret:
                print("entity:",
                      batch_input[i][iterm["pos"][0]:iterm["pos"][1] + 1],
                      " label:", iterm["label"], " pos:", iterm["pos"])

@@ -236,6 +236,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       THROW_IF_BACKEND_MODEL_ERROR(
           ParseBoolValue(value_string, &pd_enable_mkldnn));
       runtime_options_->SetPaddleMKLDNN(pd_enable_mkldnn);
+    } else if (param_key == "use_paddle_log") {
+      runtime_options_->EnablePaddleLogInfo();
     }
   }
 }
@@ -305,6 +307,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       } else if (value_string == "trt_int8") {
         // TODO(liqi): use EnableTrtINT8
         runtime_options_->trt_enable_int8 = true;
+      } else if (value_string == "pd_fp16") {
+        // TODO(liqi): paddle inference don't currently have interface for fp16.
       }
       // } else if( param_key == "max_batch_size") {
       //   THROW_IF_BACKEND_MODEL_ERROR(ParseUnsignedLongLongValue(
@@ -317,6 +321,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
         runtime_options_->SetTrtCacheFile(value_string);
       } else if (param_key == "use_paddle") {
         runtime_options_->EnablePaddleToTrt();
+      } else if (param_key == "use_paddle_log") {
+        runtime_options_->EnablePaddleLogInfo();
       }
     }
   }