mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
fastdeploy support serving (#272)
* fd support serving * fd support serving optimize dir * optimize code Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
19
examples/vision/detection/yolov5/serving/README.md
Normal file
19
examples/vision/detection/yolov5/serving/README.md
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# YOLOv5 Serving部署示例
|
||||||
|
|
||||||
|
```bash
|
||||||
|
#下载yolov5模型文件和测试图片
|
||||||
|
wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s.onnx
|
||||||
|
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
|
||||||
|
|
||||||
|
# 将模型放入 models/infer/1目录下, 并重命名为model.onnx
|
||||||
|
mv yolov5s.onnx models/infer/1/
|
||||||
|
|
||||||
|
# 拉取fastdeploy镜像
|
||||||
|
docker pull xxx
|
||||||
|
|
||||||
|
# 启动镜像和服务
|
||||||
|
docker run xx
|
||||||
|
|
||||||
|
# 客户端请求
|
||||||
|
python yolov5_grpc_client.py
|
||||||
|
```
|
@@ -0,0 +1,132 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
|
||||||
|
import fastdeploy as fd
|
||||||
|
|
||||||
|
# triton_python_backend_utils is available in every Triton Python model. You
|
||||||
|
# need to use this module to create inference requests and responses. It also
|
||||||
|
# contains some utility functions for extracting information from model_config
|
||||||
|
# and converting Triton input/output types to numpy types.
|
||||||
|
import triton_python_backend_utils as pb_utils
|
||||||
|
|
||||||
|
|
||||||
|
class TritonPythonModel:
|
||||||
|
"""Your Python model must use the same class name. Every Python model
|
||||||
|
that is created must have "TritonPythonModel" as the class name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def initialize(self, args):
|
||||||
|
"""`initialize` is called only once when the model is being loaded.
|
||||||
|
Implementing `initialize` function is optional. This function allows
|
||||||
|
the model to intialize any state associated with this model.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
args : dict
|
||||||
|
Both keys and values are strings. The dictionary keys and values are:
|
||||||
|
* model_config: A JSON string containing the model configuration
|
||||||
|
* model_instance_kind: A string containing model instance kind
|
||||||
|
* model_instance_device_id: A string containing model instance device ID
|
||||||
|
* model_repository: Model repository path
|
||||||
|
* model_version: Model version
|
||||||
|
* model_name: Model name
|
||||||
|
"""
|
||||||
|
# You must parse model_config. JSON string is not parsed here
|
||||||
|
self.model_config = json.loads(args['model_config'])
|
||||||
|
print("model_config:", self.model_config)
|
||||||
|
|
||||||
|
self.input_names = []
|
||||||
|
for input_config in self.model_config["input"]:
|
||||||
|
self.input_names.append(input_config["name"])
|
||||||
|
print("postprocess input names:", self.input_names)
|
||||||
|
|
||||||
|
self.output_names = []
|
||||||
|
self.output_dtype = []
|
||||||
|
for output_config in self.model_config["output"]:
|
||||||
|
self.output_names.append(output_config["name"])
|
||||||
|
dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
|
||||||
|
self.output_dtype.append(dtype)
|
||||||
|
print("postprocess output names:", self.output_names)
|
||||||
|
|
||||||
|
def yolov5_postprocess(self, infer_outputs, im_infos):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
infer_outputs : numpy.array
|
||||||
|
Contains the batch of inference results
|
||||||
|
im_infos : numpy.array(b'{}')
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
numpy.array
|
||||||
|
yolov5 postprocess result
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for i_batch in range(len(im_infos)):
|
||||||
|
new_infer_output = infer_outputs[i_batch:i_batch + 1]
|
||||||
|
new_im_info = im_infos[i_batch].decode('utf-8').replace("'", '"')
|
||||||
|
new_im_info = json.loads(new_im_info)
|
||||||
|
|
||||||
|
result = fd.vision.detection.YOLOv5.postprocess(
|
||||||
|
[new_infer_output, ], new_im_info)
|
||||||
|
|
||||||
|
r_str = fd.vision.utils.fd_result_to_json(result)
|
||||||
|
results.append(r_str)
|
||||||
|
return np.array(results, dtype=np.object)
|
||||||
|
|
||||||
|
def execute(self, requests):
|
||||||
|
"""`execute` must be implemented in every Python model. `execute`
|
||||||
|
function receives a list of pb_utils.InferenceRequest as the only
|
||||||
|
argument. This function is called when an inference is requested
|
||||||
|
for this model. Depending on the batching configuration (e.g. Dynamic
|
||||||
|
Batching) used, `requests` may contain multiple requests. Every
|
||||||
|
Python model, must create one pb_utils.InferenceResponse for every
|
||||||
|
pb_utils.InferenceRequest in `requests`. If there is an error, you can
|
||||||
|
set the error argument when creating a pb_utils.InferenceResponse.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
requests : list
|
||||||
|
A list of pb_utils.InferenceRequest
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list
|
||||||
|
A list of pb_utils.InferenceResponse. The length of this list must
|
||||||
|
be the same as `requests`
|
||||||
|
"""
|
||||||
|
responses = []
|
||||||
|
# print("num:", len(requests), flush=True)
|
||||||
|
for request in requests:
|
||||||
|
infer_outputs = pb_utils.get_input_tensor_by_name(
|
||||||
|
request, self.input_names[0])
|
||||||
|
im_infos = pb_utils.get_input_tensor_by_name(request,
|
||||||
|
self.input_names[1])
|
||||||
|
infer_outputs = infer_outputs.as_numpy()
|
||||||
|
im_infos = im_infos.as_numpy()
|
||||||
|
|
||||||
|
results = self.yolov5_postprocess(infer_outputs, im_infos)
|
||||||
|
|
||||||
|
out_tensor = pb_utils.Tensor(self.output_names[0], results)
|
||||||
|
inference_response = pb_utils.InferenceResponse(
|
||||||
|
output_tensors=[out_tensor, ])
|
||||||
|
responses.append(inference_response)
|
||||||
|
return responses
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
"""`finalize` is called only once when the model is being unloaded.
|
||||||
|
Implementing `finalize` function is optional. This function allows
|
||||||
|
the model to perform any necessary clean ups before exit.
|
||||||
|
"""
|
||||||
|
print('Cleaning up...')
|
@@ -0,0 +1,30 @@
|
|||||||
|
name: "postprocess"
|
||||||
|
backend: "python"
|
||||||
|
|
||||||
|
input [
|
||||||
|
{
|
||||||
|
name: "POST_INPUT_0"
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
dims: [ -1, -1, -1]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "POST_INPUT_1"
|
||||||
|
data_type: TYPE_STRING
|
||||||
|
dims: [ -1 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
output [
|
||||||
|
{
|
||||||
|
name: "POST_OUTPUT"
|
||||||
|
data_type: TYPE_STRING
|
||||||
|
dims: [ -1 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
instance_group [
|
||||||
|
{
|
||||||
|
count: 1
|
||||||
|
kind: KIND_CPU
|
||||||
|
}
|
||||||
|
]
|
@@ -0,0 +1,120 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
|
||||||
|
import fastdeploy as fd
|
||||||
|
|
||||||
|
# triton_python_backend_utils is available in every Triton Python model. You
|
||||||
|
# need to use this module to create inference requests and responses. It also
|
||||||
|
# contains some utility functions for extracting information from model_config
|
||||||
|
# and converting Triton input/output types to numpy types.
|
||||||
|
import triton_python_backend_utils as pb_utils
|
||||||
|
|
||||||
|
|
||||||
|
class TritonPythonModel:
|
||||||
|
"""Your Python model must use the same class name. Every Python model
|
||||||
|
that is created must have "TritonPythonModel" as the class name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def initialize(self, args):
|
||||||
|
"""`initialize` is called only once when the model is being loaded.
|
||||||
|
Implementing `initialize` function is optional. This function allows
|
||||||
|
the model to intialize any state associated with this model.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
args : dict
|
||||||
|
Both keys and values are strings. The dictionary keys and values are:
|
||||||
|
* model_config: A JSON string containing the model configuration
|
||||||
|
* model_instance_kind: A string containing model instance kind
|
||||||
|
* model_instance_device_id: A string containing model instance device ID
|
||||||
|
* model_repository: Model repository path
|
||||||
|
* model_version: Model version
|
||||||
|
* model_name: Model name
|
||||||
|
"""
|
||||||
|
# You must parse model_config. JSON string is not parsed here
|
||||||
|
self.model_config = json.loads(args['model_config'])
|
||||||
|
print("model_config:", self.model_config)
|
||||||
|
|
||||||
|
self.input_names = []
|
||||||
|
for input_config in self.model_config["input"]:
|
||||||
|
self.input_names.append(input_config["name"])
|
||||||
|
print("preprocess input names:", self.input_names)
|
||||||
|
|
||||||
|
self.output_names = []
|
||||||
|
self.output_dtype = []
|
||||||
|
for output_config in self.model_config["output"]:
|
||||||
|
self.output_names.append(output_config["name"])
|
||||||
|
dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
|
||||||
|
self.output_dtype.append(dtype)
|
||||||
|
print("preprocess output names:", self.output_names)
|
||||||
|
|
||||||
|
def yolov5_preprocess(self, input_data):
|
||||||
|
"""
|
||||||
|
According to Triton input, the preprocessing results of YoloV5 model are obtained.
|
||||||
|
"""
|
||||||
|
im_infos = []
|
||||||
|
pre_outputs = []
|
||||||
|
for i_batch in input_data:
|
||||||
|
pre_output, im_info = fd.vision.detection.YOLOv5.preprocess(
|
||||||
|
i_batch)
|
||||||
|
pre_outputs.append(pre_output)
|
||||||
|
im_infos.append(im_info)
|
||||||
|
im_infos = np.array(im_infos, dtype=np.object)
|
||||||
|
pre_outputs = np.concatenate(pre_outputs, axis=0)
|
||||||
|
return pre_outputs, im_infos
|
||||||
|
|
||||||
|
def execute(self, requests):
|
||||||
|
"""`execute` must be implemented in every Python model. `execute`
|
||||||
|
function receives a list of pb_utils.InferenceRequest as the only
|
||||||
|
argument. This function is called when an inference is requested
|
||||||
|
for this model. Depending on the batching configuration (e.g. Dynamic
|
||||||
|
Batching) used, `requests` may contain multiple requests. Every
|
||||||
|
Python model, must create one pb_utils.InferenceResponse for every
|
||||||
|
pb_utils.InferenceRequest in `requests`. If there is an error, you can
|
||||||
|
set the error argument when creating a pb_utils.InferenceResponse.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
requests : list
|
||||||
|
A list of pb_utils.InferenceRequest
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list
|
||||||
|
A list of pb_utils.InferenceResponse. The length of this list must
|
||||||
|
be the same as `requests`
|
||||||
|
"""
|
||||||
|
responses = []
|
||||||
|
# print("num:", len(requests), flush=True)
|
||||||
|
for request in requests:
|
||||||
|
data = pb_utils.get_input_tensor_by_name(request,
|
||||||
|
self.input_names[0])
|
||||||
|
data = data.as_numpy()
|
||||||
|
outputs = self.yolov5_preprocess(data)
|
||||||
|
output_tensors = []
|
||||||
|
for idx, output in enumerate(outputs):
|
||||||
|
output_tensors.append(
|
||||||
|
pb_utils.Tensor(self.output_names[idx], output))
|
||||||
|
inference_response = pb_utils.InferenceResponse(
|
||||||
|
output_tensors=output_tensors)
|
||||||
|
responses.append(inference_response)
|
||||||
|
return responses
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
"""`finalize` is called only once when the model is being unloaded.
|
||||||
|
Implementing `finalize` function is optional. This function allows
|
||||||
|
the model to perform any necessary clean ups before exit.
|
||||||
|
"""
|
||||||
|
print('Cleaning up...')
|
@@ -0,0 +1,31 @@
|
|||||||
|
name: "preprocess"
|
||||||
|
backend: "python"
|
||||||
|
max_batch_size: 1
|
||||||
|
|
||||||
|
input [
|
||||||
|
{
|
||||||
|
name: "INPUT_0"
|
||||||
|
data_type: TYPE_UINT8
|
||||||
|
dims: [ -1, -1, 3 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
output [
|
||||||
|
{
|
||||||
|
name: "preprocess_output_0"
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
dims: [ 3, -1, -1 ]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "preprocess_output_1"
|
||||||
|
data_type: TYPE_STRING
|
||||||
|
dims: [ -1 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
instance_group [
|
||||||
|
{
|
||||||
|
count: 1
|
||||||
|
kind: KIND_CPU
|
||||||
|
}
|
||||||
|
]
|
@@ -0,0 +1,3 @@
|
|||||||
|
# Runtime Directory
|
||||||
|
|
||||||
|
导出的部署模型需要放在本目录下
|
@@ -0,0 +1,38 @@
|
|||||||
|
# optional, If name is specified it must match the name of the model repository directory containing the model.
|
||||||
|
name: "runtime"
|
||||||
|
backend: "fastdeploy"
|
||||||
|
max_batch_size: 16
|
||||||
|
|
||||||
|
# Input configuration of the model
|
||||||
|
input [
|
||||||
|
# 第一个输入
|
||||||
|
{
|
||||||
|
# input name
|
||||||
|
name: "images"
|
||||||
|
# input type such as TYPE_FP32、TYPE_UINT8、TYPE_INT8、TYPE_INT16、TYPE_INT32、TYPE_INT64、TYPE_FP16、TYPE_STRING
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
# input shape, The batch dimension is omitted and the actual shape is [batch, c, h, w]
|
||||||
|
dims: [ 3, -1, -1 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# The output of the model is configured in the same format as the input
|
||||||
|
output [
|
||||||
|
{
|
||||||
|
name: "output"
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
dims: [ -1, -1 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Number of instances of the model
|
||||||
|
instance_group [
|
||||||
|
{
|
||||||
|
# The number of instances is 1
|
||||||
|
count: 1
|
||||||
|
# Use GPU, CPU inference option is:KIND_CPU
|
||||||
|
kind: KIND_GPU
|
||||||
|
# The instance is deployed on the 0th GPU card
|
||||||
|
gpus: [0]
|
||||||
|
}
|
||||||
|
]
|
@@ -0,0 +1,3 @@
|
|||||||
|
# YOLOV5 Pipeline
|
||||||
|
|
||||||
|
The pipeline directory does not have model files, but a version number directory needs to be maintained.
|
@@ -0,0 +1,65 @@
|
|||||||
|
name: "yolov5"
|
||||||
|
platform: "ensemble"
|
||||||
|
max_batch_size: 1
|
||||||
|
input [
|
||||||
|
{
|
||||||
|
name: "INPUT"
|
||||||
|
data_type: TYPE_UINT8
|
||||||
|
dims: [ -1, -1, 3 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
output [
|
||||||
|
{
|
||||||
|
name: "detction_result"
|
||||||
|
data_type: TYPE_STRING
|
||||||
|
dims: [ -1 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
ensemble_scheduling {
|
||||||
|
step [
|
||||||
|
{
|
||||||
|
model_name: "preprocess"
|
||||||
|
model_version: 1
|
||||||
|
input_map {
|
||||||
|
key: "INPUT_0"
|
||||||
|
value: "INPUT"
|
||||||
|
}
|
||||||
|
output_map {
|
||||||
|
key: "preprocess_output_0"
|
||||||
|
value: "infer_input"
|
||||||
|
}
|
||||||
|
output_map {
|
||||||
|
key: "preprocess_output_1"
|
||||||
|
value: "postprocess_input_1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
model_name: "runtime"
|
||||||
|
model_version: 1
|
||||||
|
input_map {
|
||||||
|
key: "images"
|
||||||
|
value: "infer_input"
|
||||||
|
}
|
||||||
|
output_map {
|
||||||
|
key: "output"
|
||||||
|
value: "infer_output"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
model_name: "postprocess"
|
||||||
|
model_version: 1
|
||||||
|
input_map {
|
||||||
|
key: "POST_INPUT_0"
|
||||||
|
value: "infer_output"
|
||||||
|
}
|
||||||
|
input_map {
|
||||||
|
key: "POST_INPUT_1"
|
||||||
|
value: "postprocess_input_1"
|
||||||
|
}
|
||||||
|
output_map {
|
||||||
|
key: "POST_OUTPUT"
|
||||||
|
value: "detction_result"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
112
examples/vision/detection/yolov5/serving/yolov5_grpc_client.py
Normal file
112
examples/vision/detection/yolov5/serving/yolov5_grpc_client.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
import cv2
|
||||||
|
import json
|
||||||
|
|
||||||
|
from tritonclient import utils as client_utils
|
||||||
|
from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput, service_pb2_grpc, service_pb2
|
||||||
|
|
||||||
|
LOGGER = logging.getLogger("run_inference_on_triton")
|
||||||
|
|
||||||
|
|
||||||
|
class SyncGRPCTritonRunner:
|
||||||
|
DEFAULT_MAX_RESP_WAIT_S = 120
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_url: str,
|
||||||
|
model_name: str,
|
||||||
|
model_version: str,
|
||||||
|
*,
|
||||||
|
verbose=False,
|
||||||
|
resp_wait_s: Optional[float]=None, ):
|
||||||
|
self._server_url = server_url
|
||||||
|
self._model_name = model_name
|
||||||
|
self._model_version = model_version
|
||||||
|
self._verbose = verbose
|
||||||
|
self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s
|
||||||
|
|
||||||
|
self._client = InferenceServerClient(
|
||||||
|
self._server_url, verbose=self._verbose)
|
||||||
|
error = self._verify_triton_state(self._client)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not communicate to Triton Server: {error}")
|
||||||
|
|
||||||
|
LOGGER.debug(
|
||||||
|
f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
|
||||||
|
f"are up and ready!")
|
||||||
|
|
||||||
|
model_config = self._client.get_model_config(self._model_name,
|
||||||
|
self._model_version)
|
||||||
|
model_metadata = self._client.get_model_metadata(self._model_name,
|
||||||
|
self._model_version)
|
||||||
|
LOGGER.info(f"Model config {model_config}")
|
||||||
|
LOGGER.info(f"Model metadata {model_metadata}")
|
||||||
|
|
||||||
|
for tm in model_metadata.inputs:
|
||||||
|
print("tm:", tm)
|
||||||
|
self._inputs = {tm.name: tm for tm in model_metadata.inputs}
|
||||||
|
self._input_names = list(self._inputs)
|
||||||
|
self._outputs = {tm.name: tm for tm in model_metadata.outputs}
|
||||||
|
self._output_names = list(self._outputs)
|
||||||
|
self._outputs_req = [
|
||||||
|
InferRequestedOutput(name) for name in self._outputs
|
||||||
|
]
|
||||||
|
|
||||||
|
def Run(self, inputs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
inputs: list, Each value corresponds to an input name of self._input_names
|
||||||
|
Returns:
|
||||||
|
results: dict, {name : numpy.array}
|
||||||
|
"""
|
||||||
|
infer_inputs = []
|
||||||
|
for idx, data in enumerate(inputs):
|
||||||
|
print("len(data):", len(data))
|
||||||
|
print("name:", self._input_names[idx], " shape:", data.shape,
|
||||||
|
data.dtype)
|
||||||
|
#data = np.array([[x.encode('utf-8')] for x in data],
|
||||||
|
# dtype=np.object_)
|
||||||
|
infer_input = InferInput(self._input_names[idx], data.shape,
|
||||||
|
"UINT8")
|
||||||
|
infer_input.set_data_from_numpy(data)
|
||||||
|
infer_inputs.append(infer_input)
|
||||||
|
|
||||||
|
results = self._client.infer(
|
||||||
|
model_name=self._model_name,
|
||||||
|
model_version=self._model_version,
|
||||||
|
inputs=infer_inputs,
|
||||||
|
outputs=self._outputs_req,
|
||||||
|
client_timeout=self._response_wait_t, )
|
||||||
|
results = {name: results.as_numpy(name) for name in self._output_names}
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _verify_triton_state(self, triton_client):
|
||||||
|
if not triton_client.is_server_live():
|
||||||
|
return f"Triton server {self._server_url} is not live"
|
||||||
|
elif not triton_client.is_server_ready():
|
||||||
|
return f"Triton server {self._server_url} is not ready"
|
||||||
|
elif not triton_client.is_model_ready(self._model_name,
|
||||||
|
self._model_version):
|
||||||
|
return f"Model {self._model_name}:{self._model_version} is not ready"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
model_name = "yolov5"
|
||||||
|
model_version = "1"
|
||||||
|
url = "localhost:8001"
|
||||||
|
runner = SyncGRPCTritonRunner(url, model_name, model_version)
|
||||||
|
im = cv2.imread("000000014439.jpg")
|
||||||
|
im = np.array([im, ])
|
||||||
|
for i in range(1):
|
||||||
|
result = runner.Run([im, ])
|
||||||
|
for name, values in result.items():
|
||||||
|
print("output_name:", name)
|
||||||
|
for i in range(len(values)):
|
||||||
|
value = values[i][0]
|
||||||
|
value = json.loads(value)
|
||||||
|
print(value)
|
@@ -22,4 +22,5 @@ from . import facedet
|
|||||||
from . import faceid
|
from . import faceid
|
||||||
from . import ocr
|
from . import ocr
|
||||||
from . import evaluation
|
from . import evaluation
|
||||||
|
from .utils import fd_result_to_json
|
||||||
from .visualize import *
|
from .visualize import *
|
||||||
|
49
python/fastdeploy/vision/utils.py
Normal file
49
python/fastdeploy/vision/utils.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from __future__ import absolute_import
|
||||||
|
import json
|
||||||
|
from .. import c_lib_wrap as C
|
||||||
|
|
||||||
|
|
||||||
|
def mask_to_json(result):
|
||||||
|
r_json = {
|
||||||
|
"data": result.data,
|
||||||
|
"shape": result.shape,
|
||||||
|
}
|
||||||
|
return json.dumps(r_json)
|
||||||
|
|
||||||
|
|
||||||
|
def detection_to_json(result):
|
||||||
|
masks = []
|
||||||
|
for mask in result.masks:
|
||||||
|
masks.append(mask_to_json(mask))
|
||||||
|
r_json = {
|
||||||
|
"boxes": result.boxes,
|
||||||
|
"scores": result.scores,
|
||||||
|
"label_ids": result.label_ids,
|
||||||
|
"masks": masks,
|
||||||
|
"contain_masks": result.contain_masks
|
||||||
|
}
|
||||||
|
return json.dumps(r_json)
|
||||||
|
|
||||||
|
|
||||||
|
def fd_result_to_json(result):
|
||||||
|
if isinstance(result, C.vision.DetectionResult):
|
||||||
|
return detection_to_json(result)
|
||||||
|
elif isinstance(result, C.vision.Mask):
|
||||||
|
return mask_to_json(result)
|
||||||
|
else:
|
||||||
|
assert False, "{} Conversion to JSON format is not supported".format(
|
||||||
|
type(result))
|
||||||
|
return {}
|
109
serving/CMakeLists.txt
Normal file
109
serving/CMakeLists.txt
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions
|
||||||
|
# are met:
|
||||||
|
# * Redistributions of source code must retain the above copyright
|
||||||
|
# notice, this list of conditions and the following disclaimer.
|
||||||
|
# * Redistributions in binary form must reproduce the above copyright
|
||||||
|
# notice, this list of conditions and the following disclaimer in the
|
||||||
|
# documentation and/or other materials provided with the distribution.
|
||||||
|
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||||
|
# contributors may be used to endorse or promote products derived
|
||||||
|
# from this software without specific prior written permission.
|
||||||
|
#
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||||
|
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||||
|
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||||
|
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||||
|
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||||
|
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||||
|
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||||
|
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
cmake_minimum_required(VERSION 3.17)
|
||||||
|
|
||||||
|
project(trironpaddlebackend LANGUAGES C CXX)
|
||||||
|
|
||||||
|
set(FASTDEPLOY_DIR "" CACHE PATH "Paths to FastDeploy Directory. Multiple paths may be specified by sparating them with a semicolon.")
|
||||||
|
set(FASTDEPLOY_INCLUDE_PATHS "${FASTDEPLOY_DIR}/include"
|
||||||
|
CACHE PATH "Paths to FastDeploy includes. Multiple paths may be specified by sparating them with a semicolon.")
|
||||||
|
set(FASTDEPLOY_LIB_PATHS "${FASTDEPLOY_DIR}/lib"
|
||||||
|
CACHE PATH "Paths to FastDeploy libraries. Multiple paths may be specified by sparating them with a semicolon.")
|
||||||
|
set(FASTDEPLOY_LIB_NAME "fastdeploy_runtime")
|
||||||
|
|
||||||
|
set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo")
|
||||||
|
set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo")
|
||||||
|
set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo")
|
||||||
|
|
||||||
|
include(FetchContent)
|
||||||
|
|
||||||
|
FetchContent_Declare(
|
||||||
|
repo-common
|
||||||
|
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
|
||||||
|
GIT_TAG ${TRITON_COMMON_REPO_TAG}
|
||||||
|
GIT_SHALLOW ON
|
||||||
|
)
|
||||||
|
FetchContent_Declare(
|
||||||
|
repo-core
|
||||||
|
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
|
||||||
|
GIT_TAG ${TRITON_CORE_REPO_TAG}
|
||||||
|
GIT_SHALLOW ON
|
||||||
|
)
|
||||||
|
FetchContent_Declare(
|
||||||
|
repo-backend
|
||||||
|
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
|
||||||
|
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
|
||||||
|
GIT_SHALLOW ON
|
||||||
|
)
|
||||||
|
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
|
||||||
|
|
||||||
|
configure_file(src/libtriton_fastdeploy.ldscript libtriton_fastdeploy.ldscript COPYONLY)
|
||||||
|
|
||||||
|
add_library(
|
||||||
|
triton-fastdeploy-backend SHARED
|
||||||
|
src/fastdeploy_runtime.cc
|
||||||
|
src/fastdeploy_backend_utils.cc
|
||||||
|
)
|
||||||
|
|
||||||
|
target_include_directories(
|
||||||
|
triton-fastdeploy-backend
|
||||||
|
PRIVATE
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/src
|
||||||
|
)
|
||||||
|
|
||||||
|
target_include_directories(
|
||||||
|
triton-fastdeploy-backend
|
||||||
|
PRIVATE ${FASTDEPLOY_INCLUDE_PATHS}
|
||||||
|
)
|
||||||
|
|
||||||
|
target_link_libraries(
|
||||||
|
triton-fastdeploy-backend
|
||||||
|
PRIVATE "-L${FASTDEPLOY_LIB_PATHS} -l${FASTDEPLOY_LIB_NAME}"
|
||||||
|
)
|
||||||
|
|
||||||
|
target_compile_features(triton-fastdeploy-backend PRIVATE cxx_std_11)
|
||||||
|
target_compile_options(
|
||||||
|
triton-fastdeploy-backend PRIVATE
|
||||||
|
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:
|
||||||
|
-Wall -Wextra -Wno-unused-parameter -Wno-type-limits -Werror>
|
||||||
|
)
|
||||||
|
|
||||||
|
set_target_properties(
|
||||||
|
triton-fastdeploy-backend PROPERTIES
|
||||||
|
POSITION_INDEPENDENT_CODE ON
|
||||||
|
OUTPUT_NAME triton_fastdeploy
|
||||||
|
SKIP_BUILD_RPATH TRUE
|
||||||
|
LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_fastdeploy.ldscript
|
||||||
|
LINK_FLAGS "-Wl,--version-script libtriton_fastdeploy.ldscript"
|
||||||
|
)
|
||||||
|
|
||||||
|
target_link_libraries(
|
||||||
|
triton-fastdeploy-backend
|
||||||
|
PRIVATE
|
||||||
|
triton-backend-utils # from repo-backend
|
||||||
|
triton-core-serverstub # from repo-core
|
||||||
|
)
|
128
serving/src/fastdeploy_backend_utils.cc
Normal file
128
serving/src/fastdeploy_backend_utils.cc
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions
|
||||||
|
// are met:
|
||||||
|
// * Redistributions of source code must retain the above copyright
|
||||||
|
// notice, this list of conditions and the following disclaimer.
|
||||||
|
// * Redistributions in binary form must reproduce the above copyright
|
||||||
|
// notice, this list of conditions and the following disclaimer in the
|
||||||
|
// documentation and/or other materials provided with the distribution.
|
||||||
|
// * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||||
|
// contributors may be used to endorse or promote products derived
|
||||||
|
// from this software without specific prior written permission.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||||
|
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||||
|
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||||
|
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||||
|
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||||
|
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||||
|
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||||
|
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
#include "fastdeploy_backend_utils.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
#include <iterator>
|
||||||
|
#include <numeric>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
namespace triton {
|
||||||
|
namespace backend {
|
||||||
|
namespace fastdeploy_runtime {
|
||||||
|
|
||||||
|
TRITONSERVER_DataType ConvertFDType(fastdeploy::FDDataType dtype) {
|
||||||
|
switch (dtype) {
|
||||||
|
case fastdeploy::FDDataType::UNKNOWN1:
|
||||||
|
return TRITONSERVER_TYPE_INVALID;
|
||||||
|
case ::fastdeploy::FDDataType::UINT8:
|
||||||
|
return TRITONSERVER_TYPE_UINT8;
|
||||||
|
case ::fastdeploy::FDDataType::INT8:
|
||||||
|
return TRITONSERVER_TYPE_INT8;
|
||||||
|
case ::fastdeploy::FDDataType::INT32:
|
||||||
|
return TRITONSERVER_TYPE_INT32;
|
||||||
|
case ::fastdeploy::FDDataType::INT64:
|
||||||
|
return TRITONSERVER_TYPE_INT64;
|
||||||
|
case ::fastdeploy::FDDataType::FP32:
|
||||||
|
return TRITONSERVER_TYPE_FP32;
|
||||||
|
case ::fastdeploy::FDDataType::FP16:
|
||||||
|
return TRITONSERVER_TYPE_FP16;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return TRITONSERVER_TYPE_INVALID;
|
||||||
|
}
|
||||||
|
|
||||||
|
fastdeploy::FDDataType ConvertDataTypeToFD(TRITONSERVER_DataType dtype) {
|
||||||
|
switch (dtype) {
|
||||||
|
case TRITONSERVER_TYPE_INVALID:
|
||||||
|
return ::fastdeploy::FDDataType::UNKNOWN1;
|
||||||
|
case TRITONSERVER_TYPE_UINT8:
|
||||||
|
return ::fastdeploy::FDDataType::UINT8;
|
||||||
|
case TRITONSERVER_TYPE_INT8:
|
||||||
|
return ::fastdeploy::FDDataType::INT8;
|
||||||
|
case TRITONSERVER_TYPE_INT32:
|
||||||
|
return ::fastdeploy::FDDataType::INT32;
|
||||||
|
case TRITONSERVER_TYPE_INT64:
|
||||||
|
return ::fastdeploy::FDDataType::INT64;
|
||||||
|
case TRITONSERVER_TYPE_FP32:
|
||||||
|
return ::fastdeploy::FDDataType::FP32;
|
||||||
|
case TRITONSERVER_TYPE_FP16:
|
||||||
|
return ::fastdeploy::FDDataType::FP16;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return ::fastdeploy::FDDataType::UNKNOWN1;
|
||||||
|
}
|
||||||
|
|
||||||
|
fastdeploy::FDDataType ModelConfigDataTypeToFDType(
|
||||||
|
const std::string& data_type_str) {
|
||||||
|
// Must start with "TYPE_".
|
||||||
|
if (data_type_str.rfind("TYPE_", 0) != 0) {
|
||||||
|
return fastdeploy::FDDataType::UNKNOWN1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string dtype = data_type_str.substr(strlen("TYPE_"));
|
||||||
|
|
||||||
|
if (dtype == "UINT8") {
|
||||||
|
return fastdeploy::FDDataType::UINT8;
|
||||||
|
} else if (dtype == "INT8") {
|
||||||
|
return fastdeploy::FDDataType::INT8;
|
||||||
|
} else if (dtype == "INT32") {
|
||||||
|
return fastdeploy::FDDataType::INT32;
|
||||||
|
} else if (dtype == "INT64") {
|
||||||
|
return fastdeploy::FDDataType::INT64;
|
||||||
|
} else if (dtype == "FP16") {
|
||||||
|
return fastdeploy::FDDataType::FP16;
|
||||||
|
} else if (dtype == "FP32") {
|
||||||
|
return fastdeploy::FDDataType::FP32;
|
||||||
|
}
|
||||||
|
return fastdeploy::FDDataType::UNKNOWN1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string FDTypeToModelConfigDataType(fastdeploy::FDDataType data_type) {
|
||||||
|
if (data_type == fastdeploy::FDDataType::UINT8) {
|
||||||
|
return "TYPE_UINT8";
|
||||||
|
} else if (data_type == fastdeploy::FDDataType::INT8) {
|
||||||
|
return "TYPE_INT8";
|
||||||
|
} else if (data_type == fastdeploy::FDDataType::INT32) {
|
||||||
|
return "TYPE_INT32";
|
||||||
|
} else if (data_type == fastdeploy::FDDataType::INT64) {
|
||||||
|
return "TYPE_INT64";
|
||||||
|
} else if (data_type == fastdeploy::FDDataType::FP16) {
|
||||||
|
return "TYPE_FP16";
|
||||||
|
} else if (data_type == fastdeploy::FDDataType::FP32) {
|
||||||
|
return "TYPE_FP32";
|
||||||
|
}
|
||||||
|
|
||||||
|
return "TYPE_INVALID";
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace fastdeploy_runtime
|
||||||
|
} // namespace backend
|
||||||
|
} // namespace triton
|
72
serving/src/fastdeploy_backend_utils.h
Normal file
72
serving/src/fastdeploy_backend_utils.h
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
|
||||||
|
// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions
|
||||||
|
// are met:
|
||||||
|
// * Redistributions of source code must retain the above copyright
|
||||||
|
// notice, this list of conditions and the following disclaimer.
|
||||||
|
// * Redistributions in binary form must reproduce the above copyright
|
||||||
|
// notice, this list of conditions and the following disclaimer in the
|
||||||
|
// documentation and/or other materials provided with the distribution.
|
||||||
|
// * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||||
|
// contributors may be used to endorse or promote products derived
|
||||||
|
// from this software without specific prior written permission.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||||
|
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||||
|
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||||
|
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||||
|
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||||
|
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||||
|
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||||
|
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstring>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "fastdeploy/core/fd_type.h"
|
||||||
|
#include "triton/core/tritonserver.h"
|
||||||
|
|
||||||
|
namespace triton {
|
||||||
|
namespace backend {
|
||||||
|
namespace fastdeploy_runtime {
|
||||||
|
|
||||||
|
#define RESPOND_ALL_AND_SET_TRUE_IF_ERROR(RESPONSES, RESPONSES_COUNT, BOOL, X) \
|
||||||
|
do { \
|
||||||
|
TRITONSERVER_Error* raasnie_err__ = (X); \
|
||||||
|
if (raasnie_err__ != nullptr) { \
|
||||||
|
BOOL = true; \
|
||||||
|
for (size_t ridx = 0; ridx < RESPONSES_COUNT; ++ridx) { \
|
||||||
|
if (RESPONSES[ridx] != nullptr) { \
|
||||||
|
LOG_IF_ERROR( \
|
||||||
|
TRITONBACKEND_ResponseSend(RESPONSES[ridx], \
|
||||||
|
TRITONSERVER_RESPONSE_COMPLETE_FINAL, \
|
||||||
|
raasnie_err__), \
|
||||||
|
"failed to send error response"); \
|
||||||
|
RESPONSES[ridx] = nullptr; \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
TRITONSERVER_ErrorDelete(raasnie_err__); \
|
||||||
|
} \
|
||||||
|
} while (false)
|
||||||
|
|
||||||
|
fastdeploy::FDDataType ConvertDataTypeToFD(TRITONSERVER_DataType dtype);
|
||||||
|
|
||||||
|
TRITONSERVER_DataType ConvertFDType(fastdeploy::FDDataType dtype);
|
||||||
|
|
||||||
|
fastdeploy::FDDataType ModelConfigDataTypeToFDType(
|
||||||
|
const std::string& data_type_str);
|
||||||
|
|
||||||
|
std::string FDTypeToModelConfigDataType(fastdeploy::FDDataType data_type);
|
||||||
|
|
||||||
|
} // namespace fastdeploy_runtime
|
||||||
|
} // namespace backend
|
||||||
|
} // namespace triton
|
1269
serving/src/fastdeploy_runtime.cc
Normal file
1269
serving/src/fastdeploy_runtime.cc
Normal file
File diff suppressed because it is too large
Load Diff
30
serving/src/libtriton_fastdeploy.ldscript
Normal file
30
serving/src/libtriton_fastdeploy.ldscript
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions
|
||||||
|
# are met:
|
||||||
|
# * Redistributions of source code must retain the above copyright
|
||||||
|
# notice, this list of conditions and the following disclaimer.
|
||||||
|
# * Redistributions in binary form must reproduce the above copyright
|
||||||
|
# notice, this list of conditions and the following disclaimer in the
|
||||||
|
# documentation and/or other materials provided with the distribution.
|
||||||
|
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||||
|
# contributors may be used to endorse or promote products derived
|
||||||
|
# from this software without specific prior written permission.
|
||||||
|
#
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||||
|
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||||
|
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||||
|
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||||
|
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||||
|
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||||
|
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||||
|
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
{
|
||||||
|
global:
|
||||||
|
TRITONBACKEND_*;
|
||||||
|
local: *;
|
||||||
|
};
|
Reference in New Issue
Block a user