mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 01:22:59 +08:00
[Serving]Add PPCls serving examples (#555)
* add ppcls serving examples * fix ppcls/serving docs * fix code style
This commit is contained in:
73
examples/vision/classification/paddleclas/serving/README.md
Normal file
73
examples/vision/classification/paddleclas/serving/README.md
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
# PaddleClas 服务化部署示例
|
||||||
|
|
||||||
|
## 启动服务
|
||||||
|
|
||||||
|
```bash
|
||||||
|
#下载部署示例代码
|
||||||
|
git clone https://github.com/PaddlePaddle/FastDeploy.git
|
||||||
|
cd FastDeploy/examples/vision/classification/paddleclas/serving
|
||||||
|
|
||||||
|
# 下载ResNet50_vd模型文件和测试图片
|
||||||
|
wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz
|
||||||
|
tar -xvf ResNet50_vd_infer.tgz
|
||||||
|
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
|
||||||
|
|
||||||
|
# 将配置文件放入预处理目录
|
||||||
|
mv ResNet50_vd_infer/inference_cls.yaml models/preprocess/1/
|
||||||
|
|
||||||
|
# 将模型放入 models/runtime/1目录下, 并重命名为model.pdmodel和model.pdiparams
|
||||||
|
mv ResNet50_vd_infer/inference.pdmodel models/runtime/1/model.pdmodel
|
||||||
|
mv ResNet50_vd_infer/inference.pdiparams models/runtime/1/model.pdiparams
|
||||||
|
|
||||||
|
# 拉取fastdeploy镜像
|
||||||
|
# GPU镜像
|
||||||
|
docker pull paddlepaddle/fastdeploy:0.3.0-gpu-cuda11.4-trt8.4-21.10
|
||||||
|
# CPU镜像
|
||||||
|
docker pull paddlepaddle/fastdeploy:0.3.0-cpu-only-21.10
|
||||||
|
|
||||||
|
# 运行容器.容器名字为 fd_serving, 并挂载当前目录为容器的 /serving 目录
|
||||||
|
nvidia-docker run -it --net=host --name fd_serving -v `pwd`/:/serving paddlepaddle/fastdeploy:0.3.0-gpu-cuda11.4-trt8.4-21.10 bash
|
||||||
|
|
||||||
|
# 启动服务(不设置CUDA_VISIBLE_DEVICES环境变量,会拥有所有GPU卡的调度权限)
|
||||||
|
CUDA_VISIBLE_DEVICES=0 fastdeployserver --model-repository=/serving/models --backend-config=python,shm-default-byte-size=10485760
|
||||||
|
```
|
||||||
|
>> **注意**:
|
||||||
|
|
||||||
|
>> 拉取其他硬件上的镜像请看[服务化部署主文档](../../../../../serving/README.md)
|
||||||
|
|
||||||
|
>> 执行fastdeployserver启动服务出现"Address already in use", 请使用`--grpc-port`指定端口号来启动服务,同时更改客户端示例中的请求端口号.
|
||||||
|
|
||||||
|
>> 其他启动参数可以使用 fastdeployserver --help 查看
|
||||||
|
|
||||||
|
服务启动成功后, 会有以下输出:
|
||||||
|
```
|
||||||
|
......
|
||||||
|
I0928 04:51:15.784517 206 grpc_server.cc:4117] Started GRPCInferenceService at 0.0.0.0:8001
|
||||||
|
I0928 04:51:15.785177 206 http_server.cc:2815] Started HTTPService at 0.0.0.0:8000
|
||||||
|
I0928 04:51:15.826578 206 http_server.cc:167] Started Metrics Service at 0.0.0.0:8002
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## 客户端请求
|
||||||
|
|
||||||
|
在物理机器中执行以下命令,发送grpc请求并输出结果
|
||||||
|
```
|
||||||
|
#下载测试图片
|
||||||
|
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
|
||||||
|
|
||||||
|
#安装客户端依赖
|
||||||
|
python3 -m pip install tritonclient\[all\]
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
python3 paddlecls_grpc_client.py
|
||||||
|
```
|
||||||
|
|
||||||
|
发送请求成功后,会返回json格式的检测结果并打印输出:
|
||||||
|
```
|
||||||
|
output_name: CLAS_RESULT
|
||||||
|
{'label_ids': [153], 'scores': [0.6862289905548096]}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 配置修改
|
||||||
|
|
||||||
|
当前默认配置在GPU上运行TensorRT引擎, 如果要在CPU或其他推理引擎上运行。 需要修改`models/runtime/config.pbtxt`中配置,详情请参考[配置文档](../../../../../serving/docs/zh_CN/model_configuration.md)
|
@@ -0,0 +1,3 @@
|
|||||||
|
# PaddleCls Pipeline
|
||||||
|
|
||||||
|
The pipeline directory does not have model files, but a version number directory needs to be maintained.
|
@@ -0,0 +1,57 @@
|
|||||||
|
name: "paddlecls"
|
||||||
|
platform: "ensemble"
|
||||||
|
max_batch_size: 16
|
||||||
|
input [
|
||||||
|
{
|
||||||
|
name: "INPUT"
|
||||||
|
data_type: TYPE_UINT8
|
||||||
|
dims: [ -1, -1, 3 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
output [
|
||||||
|
{
|
||||||
|
name: "CLAS_RESULT"
|
||||||
|
data_type: TYPE_STRING
|
||||||
|
dims: [ -1 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
ensemble_scheduling {
|
||||||
|
step [
|
||||||
|
{
|
||||||
|
model_name: "preprocess"
|
||||||
|
model_version: 1
|
||||||
|
input_map {
|
||||||
|
key: "preprocess_input"
|
||||||
|
value: "INPUT"
|
||||||
|
}
|
||||||
|
output_map {
|
||||||
|
key: "preprocess_output"
|
||||||
|
value: "RUNTIME_INPUT"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
model_name: "runtime"
|
||||||
|
model_version: 1
|
||||||
|
input_map {
|
||||||
|
key: "inputs"
|
||||||
|
value: "RUNTIME_INPUT"
|
||||||
|
}
|
||||||
|
output_map {
|
||||||
|
key: "save_infer_model/scale_0.tmp_1"
|
||||||
|
value: "RUNTIME_OUTPUT"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
model_name: "postprocess"
|
||||||
|
model_version: 1
|
||||||
|
input_map {
|
||||||
|
key: "post_input"
|
||||||
|
value: "RUNTIME_OUTPUT"
|
||||||
|
}
|
||||||
|
output_map {
|
||||||
|
key: "post_output"
|
||||||
|
value: "CLAS_RESULT"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
@@ -0,0 +1,108 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
|
||||||
|
import fastdeploy as fd
|
||||||
|
|
||||||
|
# triton_python_backend_utils is available in every Triton Python model. You
|
||||||
|
# need to use this module to create inference requests and responses. It also
|
||||||
|
# contains some utility functions for extracting information from model_config
|
||||||
|
# and converting Triton input/output types to numpy types.
|
||||||
|
import triton_python_backend_utils as pb_utils
|
||||||
|
|
||||||
|
|
||||||
|
class TritonPythonModel:
|
||||||
|
"""Your Python model must use the same class name. Every Python model
|
||||||
|
that is created must have "TritonPythonModel" as the class name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def initialize(self, args):
|
||||||
|
"""`initialize` is called only once when the model is being loaded.
|
||||||
|
Implementing `initialize` function is optional. This function allows
|
||||||
|
the model to intialize any state associated with this model.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
args : dict
|
||||||
|
Both keys and values are strings. The dictionary keys and values are:
|
||||||
|
* model_config: A JSON string containing the model configuration
|
||||||
|
* model_instance_kind: A string containing model instance kind
|
||||||
|
* model_instance_device_id: A string containing model instance device ID
|
||||||
|
* model_repository: Model repository path
|
||||||
|
* model_version: Model version
|
||||||
|
* model_name: Model name
|
||||||
|
"""
|
||||||
|
# You must parse model_config. JSON string is not parsed here
|
||||||
|
self.model_config = json.loads(args['model_config'])
|
||||||
|
print("model_config:", self.model_config)
|
||||||
|
|
||||||
|
self.input_names = []
|
||||||
|
for input_config in self.model_config["input"]:
|
||||||
|
self.input_names.append(input_config["name"])
|
||||||
|
print("postprocess input names:", self.input_names)
|
||||||
|
|
||||||
|
self.output_names = []
|
||||||
|
self.output_dtype = []
|
||||||
|
for output_config in self.model_config["output"]:
|
||||||
|
self.output_names.append(output_config["name"])
|
||||||
|
dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
|
||||||
|
self.output_dtype.append(dtype)
|
||||||
|
print("postprocess output names:", self.output_names)
|
||||||
|
|
||||||
|
self.postprocess_ = fd.vision.classification.PaddleClasPostprocessor()
|
||||||
|
|
||||||
|
def execute(self, requests):
|
||||||
|
"""`execute` must be implemented in every Python model. `execute`
|
||||||
|
function receives a list of pb_utils.InferenceRequest as the only
|
||||||
|
argument. This function is called when an inference is requested
|
||||||
|
for this model. Depending on the batching configuration (e.g. Dynamic
|
||||||
|
Batching) used, `requests` may contain multiple requests. Every
|
||||||
|
Python model, must create one pb_utils.InferenceResponse for every
|
||||||
|
pb_utils.InferenceRequest in `requests`. If there is an error, you can
|
||||||
|
set the error argument when creating a pb_utils.InferenceResponse.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
requests : list
|
||||||
|
A list of pb_utils.InferenceRequest
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list
|
||||||
|
A list of pb_utils.InferenceResponse. The length of this list must
|
||||||
|
be the same as `requests`
|
||||||
|
"""
|
||||||
|
responses = []
|
||||||
|
# print("num:", len(requests), flush=True)
|
||||||
|
for request in requests:
|
||||||
|
infer_outputs = pb_utils.get_input_tensor_by_name(
|
||||||
|
request, self.input_names[0])
|
||||||
|
infer_outputs = infer_outputs.as_numpy()
|
||||||
|
|
||||||
|
results = self.postprocess_.run([infer_outputs, ])
|
||||||
|
r_str = fd.vision.utils.fd_result_to_json(results)
|
||||||
|
|
||||||
|
r_np = np.array(r_str, dtype=np.object)
|
||||||
|
out_tensor = pb_utils.Tensor(self.output_names[0], r_np)
|
||||||
|
inference_response = pb_utils.InferenceResponse(
|
||||||
|
output_tensors=[out_tensor, ])
|
||||||
|
responses.append(inference_response)
|
||||||
|
return responses
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
"""`finalize` is called only once when the model is being unloaded.
|
||||||
|
Implementing `finalize` function is optional. This function allows
|
||||||
|
the model to perform any necessary clean ups before exit.
|
||||||
|
"""
|
||||||
|
print('Cleaning up...')
|
@@ -0,0 +1,26 @@
|
|||||||
|
name: "postprocess"
|
||||||
|
backend: "python"
|
||||||
|
max_batch_size: 16
|
||||||
|
|
||||||
|
input [
|
||||||
|
{
|
||||||
|
name: "post_input"
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
dims: [ 1000 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
output [
|
||||||
|
{
|
||||||
|
name: "post_output"
|
||||||
|
data_type: TYPE_STRING
|
||||||
|
dims: [ -1 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
instance_group [
|
||||||
|
{
|
||||||
|
count: 1
|
||||||
|
kind: KIND_CPU
|
||||||
|
}
|
||||||
|
]
|
@@ -0,0 +1,35 @@
|
|||||||
|
Global:
|
||||||
|
infer_imgs: "./images/ImageNet/ILSVRC2012_val_00000010.jpeg"
|
||||||
|
inference_model_dir: "./models"
|
||||||
|
batch_size: 1
|
||||||
|
use_gpu: True
|
||||||
|
enable_mkldnn: True
|
||||||
|
cpu_num_threads: 10
|
||||||
|
enable_benchmark: True
|
||||||
|
use_fp16: False
|
||||||
|
ir_optim: True
|
||||||
|
use_tensorrt: False
|
||||||
|
gpu_mem: 8000
|
||||||
|
enable_profile: False
|
||||||
|
|
||||||
|
PreProcess:
|
||||||
|
transform_ops:
|
||||||
|
- ResizeImage:
|
||||||
|
resize_short: 256
|
||||||
|
- CropImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 0.00392157
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
channel_num: 3
|
||||||
|
- ToCHWImage:
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
main_indicator: Topk
|
||||||
|
Topk:
|
||||||
|
topk: 5
|
||||||
|
class_id_map_file: "../ppcls/utils/imagenet1k_label_list.txt"
|
||||||
|
SavePreLabel:
|
||||||
|
save_dir: ./pre_label/
|
@@ -0,0 +1,113 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
import fastdeploy as fd
|
||||||
|
|
||||||
|
# triton_python_backend_utils is available in every Triton Python model. You
|
||||||
|
# need to use this module to create inference requests and responses. It also
|
||||||
|
# contains some utility functions for extracting information from model_config
|
||||||
|
# and converting Triton input/output types to numpy types.
|
||||||
|
import triton_python_backend_utils as pb_utils
|
||||||
|
|
||||||
|
|
||||||
|
class TritonPythonModel:
|
||||||
|
"""Your Python model must use the same class name. Every Python model
|
||||||
|
that is created must have "TritonPythonModel" as the class name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def initialize(self, args):
|
||||||
|
"""`initialize` is called only once when the model is being loaded.
|
||||||
|
Implementing `initialize` function is optional. This function allows
|
||||||
|
the model to intialize any state associated with this model.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
args : dict
|
||||||
|
Both keys and values are strings. The dictionary keys and values are:
|
||||||
|
* model_config: A JSON string containing the model configuration
|
||||||
|
* model_instance_kind: A string containing model instance kind
|
||||||
|
* model_instance_device_id: A string containing model instance device ID
|
||||||
|
* model_repository: Model repository path
|
||||||
|
* model_version: Model version
|
||||||
|
* model_name: Model name
|
||||||
|
"""
|
||||||
|
# You must parse model_config. JSON string is not parsed here
|
||||||
|
self.model_config = json.loads(args['model_config'])
|
||||||
|
print("model_config:", self.model_config)
|
||||||
|
|
||||||
|
self.input_names = []
|
||||||
|
for input_config in self.model_config["input"]:
|
||||||
|
self.input_names.append(input_config["name"])
|
||||||
|
print("preprocess input names:", self.input_names)
|
||||||
|
|
||||||
|
self.output_names = []
|
||||||
|
self.output_dtype = []
|
||||||
|
for output_config in self.model_config["output"]:
|
||||||
|
self.output_names.append(output_config["name"])
|
||||||
|
# dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
|
||||||
|
# self.output_dtype.append(dtype)
|
||||||
|
self.output_dtype.append(output_config["data_type"])
|
||||||
|
print("preprocess output names:", self.output_names)
|
||||||
|
|
||||||
|
# init PaddleClasPreprocess class
|
||||||
|
yaml_path = os.path.abspath(os.path.dirname(
|
||||||
|
__file__)) + "/inference_cls.yaml"
|
||||||
|
self.preprocess_ = fd.vision.classification.PaddleClasPreprocessor(
|
||||||
|
yaml_path)
|
||||||
|
|
||||||
|
def execute(self, requests):
|
||||||
|
"""`execute` must be implemented in every Python model. `execute`
|
||||||
|
function receives a list of pb_utils.InferenceRequest as the only
|
||||||
|
argument. This function is called when an inference is requested
|
||||||
|
for this model. Depending on the batching configuration (e.g. Dynamic
|
||||||
|
Batching) used, `requests` may contain multiple requests. Every
|
||||||
|
Python model, must create one pb_utils.InferenceResponse for every
|
||||||
|
pb_utils.InferenceRequest in `requests`. If there is an error, you can
|
||||||
|
set the error argument when creating a pb_utils.InferenceResponse.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
requests : list
|
||||||
|
A list of pb_utils.InferenceRequest
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list
|
||||||
|
A list of pb_utils.InferenceResponse. The length of this list must
|
||||||
|
be the same as `requests`
|
||||||
|
"""
|
||||||
|
responses = []
|
||||||
|
for request in requests:
|
||||||
|
data = pb_utils.get_input_tensor_by_name(request,
|
||||||
|
self.input_names[0])
|
||||||
|
data = data.as_numpy()
|
||||||
|
outputs = self.preprocess_.run(data)
|
||||||
|
|
||||||
|
# PaddleCls preprocess has only one output
|
||||||
|
dlpack_tensor = outputs[0].to_dlpack()
|
||||||
|
output_tensor = pb_utils.Tensor.from_dlpack(self.output_names[0],
|
||||||
|
dlpack_tensor)
|
||||||
|
|
||||||
|
inference_response = pb_utils.InferenceResponse(
|
||||||
|
output_tensors=[output_tensor, ])
|
||||||
|
responses.append(inference_response)
|
||||||
|
return responses
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
"""`finalize` is called only once when the model is being unloaded.
|
||||||
|
Implementing `finalize` function is optional. This function allows
|
||||||
|
the model to perform any necessary clean ups before exit.
|
||||||
|
"""
|
||||||
|
print('Cleaning up...')
|
@@ -0,0 +1,26 @@
|
|||||||
|
name: "preprocess"
|
||||||
|
backend: "python"
|
||||||
|
max_batch_size: 16
|
||||||
|
|
||||||
|
input [
|
||||||
|
{
|
||||||
|
name: "preprocess_input"
|
||||||
|
data_type: TYPE_UINT8
|
||||||
|
dims: [ -1, -1, 3 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
output [
|
||||||
|
{
|
||||||
|
name: "preprocess_output"
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
dims: [ 3, 224, 224 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
instance_group [
|
||||||
|
{
|
||||||
|
count: 1
|
||||||
|
kind: KIND_CPU
|
||||||
|
}
|
||||||
|
]
|
@@ -0,0 +1,5 @@
|
|||||||
|
# Runtime Directory
|
||||||
|
|
||||||
|
This directory holds the model files.
|
||||||
|
Paddle models must be model.pdmodel and model.pdiparams files.
|
||||||
|
ONNX models must be model.onnx files.
|
@@ -0,0 +1,60 @@
|
|||||||
|
# optional, If name is specified it must match the name of the model repository directory containing the model.
|
||||||
|
name: "runtime"
|
||||||
|
backend: "fastdeploy"
|
||||||
|
max_batch_size: 16
|
||||||
|
|
||||||
|
# Input configuration of the model
|
||||||
|
input [
|
||||||
|
{
|
||||||
|
# input name
|
||||||
|
name: "inputs"
|
||||||
|
# input type such as TYPE_FP32、TYPE_UINT8、TYPE_INT8、TYPE_INT16、TYPE_INT32、TYPE_INT64、TYPE_FP16、TYPE_STRING
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
# input shape, The batch dimension is omitted and the actual shape is [batch, c, h, w]
|
||||||
|
dims: [ 3, 224, 224 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# The output of the model is configured in the same format as the input
|
||||||
|
output [
|
||||||
|
{
|
||||||
|
name: "save_infer_model/scale_0.tmp_1"
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
dims: [ 1000 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Number of instances of the model
|
||||||
|
instance_group [
|
||||||
|
{
|
||||||
|
# The number of instances is 1
|
||||||
|
count: 1
|
||||||
|
# Use GPU, CPU inference option is:KIND_CPU
|
||||||
|
kind: KIND_GPU
|
||||||
|
# The instance is deployed on the 0th GPU card
|
||||||
|
gpus: [0]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
optimization {
|
||||||
|
execution_accelerators {
|
||||||
|
gpu_execution_accelerator : [ {
|
||||||
|
# use TRT engine
|
||||||
|
name: "tensorrt",
|
||||||
|
# use fp16 on TRT engine
|
||||||
|
parameters { key: "precision" value: "trt_fp16" }
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "min_shape"
|
||||||
|
parameters { key: "inputs" value: "1 3 224 224" }
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "opt_shape"
|
||||||
|
parameters { key: "inputs" value: "1 3 224 224" }
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "max_shape"
|
||||||
|
parameters { key: "inputs" value: "16 3 224 224" }
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}}
|
@@ -0,0 +1,109 @@
|
|||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
import cv2
|
||||||
|
import json
|
||||||
|
|
||||||
|
from tritonclient import utils as client_utils
|
||||||
|
from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput, service_pb2_grpc, service_pb2
|
||||||
|
|
||||||
|
LOGGER = logging.getLogger("run_inference_on_triton")
|
||||||
|
|
||||||
|
|
||||||
|
class SyncGRPCTritonRunner:
|
||||||
|
DEFAULT_MAX_RESP_WAIT_S = 120
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_url: str,
|
||||||
|
model_name: str,
|
||||||
|
model_version: str,
|
||||||
|
*,
|
||||||
|
verbose=False,
|
||||||
|
resp_wait_s: Optional[float]=None, ):
|
||||||
|
self._server_url = server_url
|
||||||
|
self._model_name = model_name
|
||||||
|
self._model_version = model_version
|
||||||
|
self._verbose = verbose
|
||||||
|
self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s
|
||||||
|
|
||||||
|
self._client = InferenceServerClient(
|
||||||
|
self._server_url, verbose=self._verbose)
|
||||||
|
error = self._verify_triton_state(self._client)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not communicate to Triton Server: {error}")
|
||||||
|
|
||||||
|
LOGGER.debug(
|
||||||
|
f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
|
||||||
|
f"are up and ready!")
|
||||||
|
|
||||||
|
model_config = self._client.get_model_config(self._model_name,
|
||||||
|
self._model_version)
|
||||||
|
model_metadata = self._client.get_model_metadata(self._model_name,
|
||||||
|
self._model_version)
|
||||||
|
LOGGER.info(f"Model config {model_config}")
|
||||||
|
LOGGER.info(f"Model metadata {model_metadata}")
|
||||||
|
|
||||||
|
for tm in model_metadata.inputs:
|
||||||
|
print("tm:", tm)
|
||||||
|
self._inputs = {tm.name: tm for tm in model_metadata.inputs}
|
||||||
|
self._input_names = list(self._inputs)
|
||||||
|
self._outputs = {tm.name: tm for tm in model_metadata.outputs}
|
||||||
|
self._output_names = list(self._outputs)
|
||||||
|
self._outputs_req = [
|
||||||
|
InferRequestedOutput(name) for name in self._outputs
|
||||||
|
]
|
||||||
|
|
||||||
|
def Run(self, inputs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
inputs: list, Each value corresponds to an input name of self._input_names
|
||||||
|
Returns:
|
||||||
|
results: dict, {name : numpy.array}
|
||||||
|
"""
|
||||||
|
infer_inputs = []
|
||||||
|
for idx, data in enumerate(inputs):
|
||||||
|
infer_input = InferInput(self._input_names[idx], data.shape,
|
||||||
|
"UINT8")
|
||||||
|
infer_input.set_data_from_numpy(data)
|
||||||
|
infer_inputs.append(infer_input)
|
||||||
|
|
||||||
|
results = self._client.infer(
|
||||||
|
model_name=self._model_name,
|
||||||
|
model_version=self._model_version,
|
||||||
|
inputs=infer_inputs,
|
||||||
|
outputs=self._outputs_req,
|
||||||
|
client_timeout=self._response_wait_t, )
|
||||||
|
results = {name: results.as_numpy(name) for name in self._output_names}
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _verify_triton_state(self, triton_client):
|
||||||
|
if not triton_client.is_server_live():
|
||||||
|
return f"Triton server {self._server_url} is not live"
|
||||||
|
elif not triton_client.is_server_ready():
|
||||||
|
return f"Triton server {self._server_url} is not ready"
|
||||||
|
elif not triton_client.is_model_ready(self._model_name,
|
||||||
|
self._model_version):
|
||||||
|
return f"Model {self._model_name}:{self._model_version} is not ready"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
model_name = "paddlecls"
|
||||||
|
model_version = "1"
|
||||||
|
url = "localhost:8001"
|
||||||
|
runner = SyncGRPCTritonRunner(url, model_name, model_version)
|
||||||
|
im = cv2.imread("ILSVRC2012_val_00000010.jpeg")
|
||||||
|
im = np.array([im, ])
|
||||||
|
# batch input
|
||||||
|
# im = np.array([im, im, im])
|
||||||
|
for i in range(1):
|
||||||
|
result = runner.Run([im, ])
|
||||||
|
for name, values in result.items():
|
||||||
|
print("output_name:", name)
|
||||||
|
# values is batch
|
||||||
|
for value in values:
|
||||||
|
value = json.loads(value)
|
||||||
|
print(value)
|
@@ -12,11 +12,162 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <dlpack/dlpack.h>
|
||||||
|
|
||||||
|
#include "fastdeploy/core/fd_type.h"
|
||||||
|
#include "fastdeploy/utils/utils.h"
|
||||||
#include "fastdeploy/fastdeploy_model.h"
|
#include "fastdeploy/fastdeploy_model.h"
|
||||||
#include "fastdeploy/pybind/main.h"
|
#include "fastdeploy/pybind/main.h"
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
|
|
||||||
|
DLDataType FDToDlpackType(FDDataType fd_dtype) {
|
||||||
|
DLDataType dl_dtype;
|
||||||
|
DLDataTypeCode dl_code;
|
||||||
|
|
||||||
|
// Number of bits required for the data type.
|
||||||
|
size_t dt_size = 0;
|
||||||
|
|
||||||
|
dl_dtype.lanes = 1;
|
||||||
|
switch (fd_dtype) {
|
||||||
|
case FDDataType::BOOL:
|
||||||
|
dl_code = DLDataTypeCode::kDLInt;
|
||||||
|
dt_size = 1;
|
||||||
|
break;
|
||||||
|
case FDDataType::UINT8:
|
||||||
|
dl_code = DLDataTypeCode::kDLUInt;
|
||||||
|
dt_size = 8;
|
||||||
|
break;
|
||||||
|
case FDDataType::INT8:
|
||||||
|
dl_code = DLDataTypeCode::kDLInt;
|
||||||
|
dt_size = 8;
|
||||||
|
break;
|
||||||
|
case FDDataType::INT16:
|
||||||
|
dl_code = DLDataTypeCode::kDLInt;
|
||||||
|
dt_size = 16;
|
||||||
|
break;
|
||||||
|
case FDDataType::INT32:
|
||||||
|
dl_code = DLDataTypeCode::kDLInt;
|
||||||
|
dt_size = 32;
|
||||||
|
break;
|
||||||
|
case FDDataType::INT64:
|
||||||
|
dl_code = DLDataTypeCode::kDLInt;
|
||||||
|
dt_size = 64;
|
||||||
|
break;
|
||||||
|
case FDDataType::FP16:
|
||||||
|
dl_code = DLDataTypeCode::kDLFloat;
|
||||||
|
dt_size = 16;
|
||||||
|
break;
|
||||||
|
case FDDataType::FP32:
|
||||||
|
dl_code = DLDataTypeCode::kDLFloat;
|
||||||
|
dt_size = 32;
|
||||||
|
break;
|
||||||
|
case FDDataType::FP64:
|
||||||
|
dl_code = DLDataTypeCode::kDLFloat;
|
||||||
|
dt_size = 64;
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
FDASSERT(false,
|
||||||
|
"Convert to DlPack, FDType \"%s\" is not supported.", Str(fd_dtype));
|
||||||
|
}
|
||||||
|
|
||||||
|
dl_dtype.code = dl_code;
|
||||||
|
dl_dtype.bits = dt_size;
|
||||||
|
return dl_dtype;
|
||||||
|
}
|
||||||
|
|
||||||
|
FDDataType
|
||||||
|
DlpackToFDType(const DLDataType& data_type) {
|
||||||
|
FDASSERT(data_type.lanes == 1,
|
||||||
|
"FDTensor does not support dlpack lanes != 1")
|
||||||
|
|
||||||
|
if (data_type.code == DLDataTypeCode::kDLFloat) {
|
||||||
|
if (data_type.bits == 16) {
|
||||||
|
return FDDataType::FP16;
|
||||||
|
} else if (data_type.bits == 32) {
|
||||||
|
return FDDataType::FP32;
|
||||||
|
} else if (data_type.bits == 64) {
|
||||||
|
return FDDataType::FP64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data_type.code == DLDataTypeCode::kDLInt) {
|
||||||
|
if (data_type.bits == 8) {
|
||||||
|
return FDDataType::INT8;
|
||||||
|
} else if (data_type.bits == 16) {
|
||||||
|
return FDDataType::INT16;
|
||||||
|
} else if (data_type.bits == 32) {
|
||||||
|
return FDDataType::INT32;
|
||||||
|
} else if (data_type.bits == 64) {
|
||||||
|
return FDDataType::INT64;
|
||||||
|
} else if (data_type.bits == 1) {
|
||||||
|
return FDDataType::BOOL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data_type.code == DLDataTypeCode::kDLUInt) {
|
||||||
|
if (data_type.bits == 8) {
|
||||||
|
return FDDataType::UINT8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return FDDataType::UNKNOWN1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeleteUnusedDltensor(PyObject* dlp) {
|
||||||
|
if (PyCapsule_IsValid(dlp, "dltensor")) {
|
||||||
|
DLManagedTensor* dl_managed_tensor =
|
||||||
|
static_cast<DLManagedTensor*>(PyCapsule_GetPointer(dlp, "dltensor"));
|
||||||
|
dl_managed_tensor->deleter(dl_managed_tensor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pybind11::capsule FDTensorToDLPack(FDTensor& fd_tensor) {
|
||||||
|
DLManagedTensor* dlpack_tensor = new DLManagedTensor;
|
||||||
|
dlpack_tensor->dl_tensor.ndim = fd_tensor.shape.size();
|
||||||
|
dlpack_tensor->dl_tensor.byte_offset = 0;
|
||||||
|
dlpack_tensor->dl_tensor.data = fd_tensor.MutableData();
|
||||||
|
dlpack_tensor->dl_tensor.shape = &(fd_tensor.shape[0]);
|
||||||
|
dlpack_tensor->dl_tensor.strides = nullptr;
|
||||||
|
dlpack_tensor->manager_ctx = &fd_tensor;
|
||||||
|
dlpack_tensor->deleter = [](DLManagedTensor* m) {
|
||||||
|
if (m->manager_ctx == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
FDTensor* tensor_ptr = reinterpret_cast<FDTensor*>(m->manager_ctx);
|
||||||
|
pybind11::handle tensor_handle = pybind11::cast(tensor_ptr);
|
||||||
|
tensor_handle.dec_ref();
|
||||||
|
free(m);
|
||||||
|
};
|
||||||
|
|
||||||
|
pybind11::handle tensor_handle = pybind11::cast(&fd_tensor);
|
||||||
|
|
||||||
|
// Increase the reference count by one to make sure that the DLPack
|
||||||
|
// represenation doesn't become invalid when the tensor object goes out of
|
||||||
|
// scope.
|
||||||
|
tensor_handle.inc_ref();
|
||||||
|
|
||||||
|
dlpack_tensor->dl_tensor.dtype = FDToDlpackType(fd_tensor.dtype);
|
||||||
|
|
||||||
|
// TODO(liqi): FDTensor add device_id
|
||||||
|
dlpack_tensor->dl_tensor.device.device_id = 0;
|
||||||
|
if(fd_tensor.device == Device::GPU) {
|
||||||
|
if (fd_tensor.is_pinned_memory) {
|
||||||
|
dlpack_tensor->dl_tensor.device.device_type = DLDeviceType::kDLCUDAHost;
|
||||||
|
} else {
|
||||||
|
dlpack_tensor->dl_tensor.device.device_type = DLDeviceType::kDLCUDA;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dlpack_tensor->dl_tensor.device.device_type = DLDeviceType::kDLCPU;
|
||||||
|
}
|
||||||
|
|
||||||
|
return pybind11::capsule(
|
||||||
|
static_cast<void*>(dlpack_tensor), "dltensor", &DeleteUnusedDltensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void BindFDTensor(pybind11::module& m) {
|
void BindFDTensor(pybind11::module& m) {
|
||||||
pybind11::class_<FDTensor>(m, "FDTensor")
|
pybind11::class_<FDTensor>(m, "FDTensor")
|
||||||
.def(pybind11::init<>(), "Default Constructor")
|
.def(pybind11::init<>(), "Default Constructor")
|
||||||
@@ -27,9 +178,11 @@ void BindFDTensor(pybind11::module& m) {
|
|||||||
.def("numpy", [](FDTensor& self) {
|
.def("numpy", [](FDTensor& self) {
|
||||||
return TensorToPyArray(self);
|
return TensorToPyArray(self);
|
||||||
})
|
})
|
||||||
|
.def("data", &FDTensor::MutableData)
|
||||||
.def("from_numpy", [](FDTensor& self, pybind11::array& pyarray, bool share_buffer = false) {
|
.def("from_numpy", [](FDTensor& self, pybind11::array& pyarray, bool share_buffer = false) {
|
||||||
PyArrayToTensor(pyarray, &self, share_buffer);
|
PyArrayToTensor(pyarray, &self, share_buffer);
|
||||||
});
|
})
|
||||||
|
.def("to_dlpack", &FDTensorToDLPack);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
||||||
|
@@ -38,11 +38,26 @@ def detection_to_json(result):
|
|||||||
return json.dumps(r_json)
|
return json.dumps(r_json)
|
||||||
|
|
||||||
|
|
||||||
|
def classify_to_json(result):
|
||||||
|
r_json = {
|
||||||
|
"label_ids": result.label_ids,
|
||||||
|
"scores": result.scores,
|
||||||
|
}
|
||||||
|
return json.dumps(r_json)
|
||||||
|
|
||||||
|
|
||||||
def fd_result_to_json(result):
|
def fd_result_to_json(result):
|
||||||
if isinstance(result, C.vision.DetectionResult):
|
if isinstance(result, list):
|
||||||
|
r_list = []
|
||||||
|
for r in result:
|
||||||
|
r_list.append(fd_result_to_json(r))
|
||||||
|
return r_list
|
||||||
|
elif isinstance(result, C.vision.DetectionResult):
|
||||||
return detection_to_json(result)
|
return detection_to_json(result)
|
||||||
elif isinstance(result, C.vision.Mask):
|
elif isinstance(result, C.vision.Mask):
|
||||||
return mask_to_json(result)
|
return mask_to_json(result)
|
||||||
|
elif isinstance(result, C.vision.ClassifyResult):
|
||||||
|
return classify_to_json(result)
|
||||||
else:
|
else:
|
||||||
assert False, "{} Conversion to JSON format is not supported".format(
|
assert False, "{} Conversion to JSON format is not supported".format(
|
||||||
type(result))
|
type(result))
|
||||||
|
Reference in New Issue
Block a user