diff --git a/examples/vision/detection/yolov5/serving/README.md b/examples/vision/detection/yolov5/serving/README.md new file mode 100644 index 000000000..827d4fe59 --- /dev/null +++ b/examples/vision/detection/yolov5/serving/README.md @@ -0,0 +1,19 @@ +# YOLOv5 Serving部署示例 + +```bash +#下载yolov5模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s.onnx +wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg + +# 将模型放入 models/infer/1目录下, 并重命名为model.onnx +mv yolov5s.onnx models/infer/1/ + +# 拉取fastdeploy镜像 +docker pull xxx + +# 启动镜像和服务 +docker run xx + +# 客户端请求 +python yolov5_grpc_client.py +``` diff --git a/examples/vision/detection/yolov5/serving/models/postprocess/1/model.py b/examples/vision/detection/yolov5/serving/models/postprocess/1/model.py new file mode 100644 index 000000000..30a744b68 --- /dev/null +++ b/examples/vision/detection/yolov5/serving/models/postprocess/1/model.py @@ -0,0 +1,132 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import numpy as np +import time + +import fastdeploy as fd + +# triton_python_backend_utils is available in every Triton Python model. You +# need to use this module to create inference requests and responses. It also +# contains some utility functions for extracting information from model_config +# and converting Triton input/output types to numpy types. +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to intialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # You must parse model_config. 
JSON string is not parsed here + self.model_config = json.loads(args['model_config']) + print("model_config:", self.model_config) + + self.input_names = [] + for input_config in self.model_config["input"]: + self.input_names.append(input_config["name"]) + print("postprocess input names:", self.input_names) + + self.output_names = [] + self.output_dtype = [] + for output_config in self.model_config["output"]: + self.output_names.append(output_config["name"]) + dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + self.output_dtype.append(dtype) + print("postprocess output names:", self.output_names) + + def yolov5_postprocess(self, infer_outputs, im_infos): + """ + Parameters + ---------- + infer_outputs : numpy.array + Contains the batch of inference results + im_infos : numpy.array(b'{}') + Returns + ------- + numpy.array + yolov5 postprocess result + """ + results = [] + for i_batch in range(len(im_infos)): + new_infer_output = infer_outputs[i_batch:i_batch + 1] + new_im_info = im_infos[i_batch].decode('utf-8').replace("'", '"') + new_im_info = json.loads(new_im_info) + + result = fd.vision.detection.YOLOv5.postprocess( + [new_infer_output, ], new_im_info) + + r_str = fd.vision.utils.fd_result_to_json(result) + results.append(r_str) + return np.array(results, dtype=np.object) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + responses = [] + # print("num:", len(requests), flush=True) + for request in requests: + infer_outputs = pb_utils.get_input_tensor_by_name( + request, self.input_names[0]) + im_infos = pb_utils.get_input_tensor_by_name(request, + self.input_names[1]) + infer_outputs = infer_outputs.as_numpy() + im_infos = im_infos.as_numpy() + + results = self.yolov5_postprocess(infer_outputs, im_infos) + + out_tensor = pb_utils.Tensor(self.output_names[0], results) + inference_response = pb_utils.InferenceResponse( + output_tensors=[out_tensor, ]) + responses.append(inference_response) + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. 
+ """ + print('Cleaning up...') diff --git a/examples/vision/detection/yolov5/serving/models/postprocess/config.pbtxt b/examples/vision/detection/yolov5/serving/models/postprocess/config.pbtxt new file mode 100644 index 000000000..129c979fa --- /dev/null +++ b/examples/vision/detection/yolov5/serving/models/postprocess/config.pbtxt @@ -0,0 +1,30 @@ +name: "postprocess" +backend: "python" + +input [ + { + name: "POST_INPUT_0" + data_type: TYPE_FP32 + dims: [ -1, -1, -1] + }, + { + name: "POST_INPUT_1" + data_type: TYPE_STRING + dims: [ -1 ] + } +] + +output [ + { + name: "POST_OUTPUT" + data_type: TYPE_STRING + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/examples/vision/detection/yolov5/serving/models/preprocess/1/model.py b/examples/vision/detection/yolov5/serving/models/preprocess/1/model.py new file mode 100644 index 000000000..cd22aa37b --- /dev/null +++ b/examples/vision/detection/yolov5/serving/models/preprocess/1/model.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import numpy as np +import time + +import fastdeploy as fd + +# triton_python_backend_utils is available in every Triton Python model. You +# need to use this module to create inference requests and responses. It also +# contains some utility functions for extracting information from model_config +# and converting Triton input/output types to numpy types. +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to intialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # You must parse model_config. 
JSON string is not parsed here + self.model_config = json.loads(args['model_config']) + print("model_config:", self.model_config) + + self.input_names = [] + for input_config in self.model_config["input"]: + self.input_names.append(input_config["name"]) + print("preprocess input names:", self.input_names) + + self.output_names = [] + self.output_dtype = [] + for output_config in self.model_config["output"]: + self.output_names.append(output_config["name"]) + dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + self.output_dtype.append(dtype) + print("preprocess output names:", self.output_names) + + def yolov5_preprocess(self, input_data): + """ + According to Triton input, the preprocessing results of YoloV5 model are obtained. + """ + im_infos = [] + pre_outputs = [] + for i_batch in input_data: + pre_output, im_info = fd.vision.detection.YOLOv5.preprocess( + i_batch) + pre_outputs.append(pre_output) + im_infos.append(im_info) + im_infos = np.array(im_infos, dtype=np.object) + pre_outputs = np.concatenate(pre_outputs, axis=0) + return pre_outputs, im_infos + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + responses = [] + # print("num:", len(requests), flush=True) + for request in requests: + data = pb_utils.get_input_tensor_by_name(request, + self.input_names[0]) + data = data.as_numpy() + outputs = self.yolov5_preprocess(data) + output_tensors = [] + for idx, output in enumerate(outputs): + output_tensors.append( + pb_utils.Tensor(self.output_names[idx], output)) + inference_response = pb_utils.InferenceResponse( + output_tensors=output_tensors) + responses.append(inference_response) + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. 
+ """ + print('Cleaning up...') diff --git a/examples/vision/detection/yolov5/serving/models/preprocess/config.pbtxt b/examples/vision/detection/yolov5/serving/models/preprocess/config.pbtxt new file mode 100644 index 000000000..cb56dcd68 --- /dev/null +++ b/examples/vision/detection/yolov5/serving/models/preprocess/config.pbtxt @@ -0,0 +1,31 @@ +name: "preprocess" +backend: "python" +max_batch_size: 1 + +input [ + { + name: "INPUT_0" + data_type: TYPE_UINT8 + dims: [ -1, -1, 3 ] + } +] + +output [ + { + name: "preprocess_output_0" + data_type: TYPE_FP32 + dims: [ 3, -1, -1 ] + }, + { + name: "preprocess_output_1" + data_type: TYPE_STRING + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/examples/vision/detection/yolov5/serving/models/runtime/1/README.md b/examples/vision/detection/yolov5/serving/models/runtime/1/README.md new file mode 100644 index 000000000..ee05e496d --- /dev/null +++ b/examples/vision/detection/yolov5/serving/models/runtime/1/README.md @@ -0,0 +1,3 @@ +# Runtime Directory + +导出的部署模型需要放在本目录下 diff --git a/examples/vision/detection/yolov5/serving/models/runtime/config.pbtxt b/examples/vision/detection/yolov5/serving/models/runtime/config.pbtxt new file mode 100644 index 000000000..ffed1edf4 --- /dev/null +++ b/examples/vision/detection/yolov5/serving/models/runtime/config.pbtxt @@ -0,0 +1,38 @@ +# optional, If name is specified it must match the name of the model repository directory containing the model. +name: "runtime" +backend: "fastdeploy" +max_batch_size: 16 + +# Input configuration of the model +input [ + # 第一个输入 + { + # input name + name: "images" + # input type such as TYPE_FP32、TYPE_UINT8、TYPE_INT8、TYPE_INT16、TYPE_INT32、TYPE_INT64、TYPE_FP16、TYPE_STRING + data_type: TYPE_FP32 + # input shape, The batch dimension is omitted and the actual shape is [batch, c, h, w] + dims: [ 3, -1, -1 ] + } +] + +# The output of the model is configured in the same format as the input +output [ + { + name: "output" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + } +] + +# Number of instances of the model +instance_group [ + { + # The number of instances is 1 + count: 1 + # Use GPU, CPU inference option is:KIND_CPU + kind: KIND_GPU + # The instance is deployed on the 0th GPU card + gpus: [0] + } +] diff --git a/examples/vision/detection/yolov5/serving/models/yolov5/1/README.md b/examples/vision/detection/yolov5/serving/models/yolov5/1/README.md new file mode 100644 index 000000000..40da59aee --- /dev/null +++ b/examples/vision/detection/yolov5/serving/models/yolov5/1/README.md @@ -0,0 +1,3 @@ +# YOLOV5 Pipeline + +The pipeline directory does not have model files, but a version number directory needs to be maintained. 
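+
+For reference, the complete model repository used by this example is expected to look roughly as follows (layout taken from the files added in this example; `runtime/1/model.onnx` is the exported YOLOv5 model the user copies in, and `yolov5/1/` stays empty apart from marking the ensemble version):
+
+```
+models/
+├── preprocess/
+│   ├── config.pbtxt
+│   └── 1/
+│       └── model.py
+├── runtime/
+│   ├── config.pbtxt
+│   └── 1/
+│       └── model.onnx
+├── postprocess/
+│   ├── config.pbtxt
+│   └── 1/
+│       └── model.py
+└── yolov5/
+    ├── config.pbtxt
+    └── 1/
+```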
diff --git a/examples/vision/detection/yolov5/serving/models/yolov5/config.pbtxt b/examples/vision/detection/yolov5/serving/models/yolov5/config.pbtxt new file mode 100644 index 000000000..9b8a39024 --- /dev/null +++ b/examples/vision/detection/yolov5/serving/models/yolov5/config.pbtxt @@ -0,0 +1,65 @@ +name: "yolov5" +platform: "ensemble" +max_batch_size: 1 +input [ + { + name: "INPUT" + data_type: TYPE_UINT8 + dims: [ -1, -1, 3 ] + } +] +output [ + { + name: "detction_result" + data_type: TYPE_STRING + dims: [ -1 ] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocess" + model_version: 1 + input_map { + key: "INPUT_0" + value: "INPUT" + } + output_map { + key: "preprocess_output_0" + value: "infer_input" + } + output_map { + key: "preprocess_output_1" + value: "postprocess_input_1" + } + }, + { + model_name: "runtime" + model_version: 1 + input_map { + key: "images" + value: "infer_input" + } + output_map { + key: "output" + value: "infer_output" + } + }, + { + model_name: "postprocess" + model_version: 1 + input_map { + key: "POST_INPUT_0" + value: "infer_output" + } + input_map { + key: "POST_INPUT_1" + value: "postprocess_input_1" + } + output_map { + key: "POST_OUTPUT" + value: "detction_result" + } + } + ] +} \ No newline at end of file diff --git a/examples/vision/detection/yolov5/serving/yolov5_grpc_client.py b/examples/vision/detection/yolov5/serving/yolov5_grpc_client.py new file mode 100644 index 000000000..f21991174 --- /dev/null +++ b/examples/vision/detection/yolov5/serving/yolov5_grpc_client.py @@ -0,0 +1,112 @@ +import logging +import numpy as np +import time +from typing import Optional +import cv2 +import json + +from tritonclient import utils as client_utils +from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput, service_pb2_grpc, service_pb2 + +LOGGER = logging.getLogger("run_inference_on_triton") + + +class SyncGRPCTritonRunner: + DEFAULT_MAX_RESP_WAIT_S = 120 + + def __init__( + self, + server_url: str, + model_name: str, + model_version: str, + *, + verbose=False, + resp_wait_s: Optional[float]=None, ): + self._server_url = server_url + self._model_name = model_name + self._model_version = model_version + self._verbose = verbose + self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s + + self._client = InferenceServerClient( + self._server_url, verbose=self._verbose) + error = self._verify_triton_state(self._client) + if error: + raise RuntimeError( + f"Could not communicate to Triton Server: {error}") + + LOGGER.debug( + f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} " + f"are up and ready!") + + model_config = self._client.get_model_config(self._model_name, + self._model_version) + model_metadata = self._client.get_model_metadata(self._model_name, + self._model_version) + LOGGER.info(f"Model config {model_config}") + LOGGER.info(f"Model metadata {model_metadata}") + + for tm in model_metadata.inputs: + print("tm:", tm) + self._inputs = {tm.name: tm for tm in model_metadata.inputs} + self._input_names = list(self._inputs) + self._outputs = {tm.name: tm for tm in model_metadata.outputs} + self._output_names = list(self._outputs) + self._outputs_req = [ + InferRequestedOutput(name) for name in self._outputs + ] + + def Run(self, inputs): + """ + Args: + inputs: list, Each value corresponds to an input name of self._input_names + Returns: + results: dict, {name : numpy.array} + """ + infer_inputs = [] + for idx, data in enumerate(inputs): + 
print("len(data):", len(data)) + print("name:", self._input_names[idx], " shape:", data.shape, + data.dtype) + #data = np.array([[x.encode('utf-8')] for x in data], + # dtype=np.object_) + infer_input = InferInput(self._input_names[idx], data.shape, + "UINT8") + infer_input.set_data_from_numpy(data) + infer_inputs.append(infer_input) + + results = self._client.infer( + model_name=self._model_name, + model_version=self._model_version, + inputs=infer_inputs, + outputs=self._outputs_req, + client_timeout=self._response_wait_t, ) + results = {name: results.as_numpy(name) for name in self._output_names} + return results + + def _verify_triton_state(self, triton_client): + if not triton_client.is_server_live(): + return f"Triton server {self._server_url} is not live" + elif not triton_client.is_server_ready(): + return f"Triton server {self._server_url} is not ready" + elif not triton_client.is_model_ready(self._model_name, + self._model_version): + return f"Model {self._model_name}:{self._model_version} is not ready" + return None + + +if __name__ == "__main__": + model_name = "yolov5" + model_version = "1" + url = "localhost:8001" + runner = SyncGRPCTritonRunner(url, model_name, model_version) + im = cv2.imread("000000014439.jpg") + im = np.array([im, ]) + for i in range(1): + result = runner.Run([im, ]) + for name, values in result.items(): + print("output_name:", name) + for i in range(len(values)): + value = values[i][0] + value = json.loads(value) + print(value) diff --git a/python/fastdeploy/vision/__init__.py b/python/fastdeploy/vision/__init__.py index d774fe783..86f4f978f 100644 --- a/python/fastdeploy/vision/__init__.py +++ b/python/fastdeploy/vision/__init__.py @@ -22,4 +22,5 @@ from . import facedet from . import faceid from . import ocr from . import evaluation +from .utils import fd_result_to_json from .visualize import * diff --git a/python/fastdeploy/vision/utils.py b/python/fastdeploy/vision/utils.py new file mode 100644 index 000000000..e9c5119d4 --- /dev/null +++ b/python/fastdeploy/vision/utils.py @@ -0,0 +1,49 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +import json +from .. 
import c_lib_wrap as C + + +def mask_to_json(result): + r_json = { + "data": result.data, + "shape": result.shape, + } + return json.dumps(r_json) + + +def detection_to_json(result): + masks = [] + for mask in result.masks: + masks.append(mask_to_json(mask)) + r_json = { + "boxes": result.boxes, + "scores": result.scores, + "label_ids": result.label_ids, + "masks": masks, + "contain_masks": result.contain_masks + } + return json.dumps(r_json) + + +def fd_result_to_json(result): + if isinstance(result, C.vision.DetectionResult): + return detection_to_json(result) + elif isinstance(result, C.vision.Mask): + return mask_to_json(result) + else: + assert False, "{} Conversion to JSON format is not supported".format( + type(result)) + return {} diff --git a/serving/CMakeLists.txt b/serving/CMakeLists.txt new file mode 100644 index 000000000..d74940234 --- /dev/null +++ b/serving/CMakeLists.txt @@ -0,0 +1,109 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cmake_minimum_required(VERSION 3.17) + +project(trironpaddlebackend LANGUAGES C CXX) + +set(FASTDEPLOY_DIR "" CACHE PATH "Paths to FastDeploy Directory. Multiple paths may be specified by sparating them with a semicolon.") +set(FASTDEPLOY_INCLUDE_PATHS "${FASTDEPLOY_DIR}/include" + CACHE PATH "Paths to FastDeploy includes. Multiple paths may be specified by sparating them with a semicolon.") +set(FASTDEPLOY_LIB_PATHS "${FASTDEPLOY_DIR}/lib" + CACHE PATH "Paths to FastDeploy libraries. 
Multiple paths may be specified by sparating them with a semicolon.") +set(FASTDEPLOY_LIB_NAME "fastdeploy_runtime") + +set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo") +set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo") +set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo") + +include(FetchContent) + +FetchContent_Declare( + repo-common + GIT_REPOSITORY https://github.com/triton-inference-server/common.git + GIT_TAG ${TRITON_COMMON_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_Declare( + repo-core + GIT_REPOSITORY https://github.com/triton-inference-server/core.git + GIT_TAG ${TRITON_CORE_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_Declare( + repo-backend + GIT_REPOSITORY https://github.com/triton-inference-server/backend.git + GIT_TAG ${TRITON_BACKEND_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_MakeAvailable(repo-common repo-core repo-backend) + +configure_file(src/libtriton_fastdeploy.ldscript libtriton_fastdeploy.ldscript COPYONLY) + +add_library( + triton-fastdeploy-backend SHARED + src/fastdeploy_runtime.cc + src/fastdeploy_backend_utils.cc +) + +target_include_directories( + triton-fastdeploy-backend + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src +) + +target_include_directories( + triton-fastdeploy-backend + PRIVATE ${FASTDEPLOY_INCLUDE_PATHS} +) + +target_link_libraries( + triton-fastdeploy-backend + PRIVATE "-L${FASTDEPLOY_LIB_PATHS} -l${FASTDEPLOY_LIB_NAME}" +) + +target_compile_features(triton-fastdeploy-backend PRIVATE cxx_std_11) +target_compile_options( + triton-fastdeploy-backend PRIVATE + $<$,$,$>: + -Wall -Wextra -Wno-unused-parameter -Wno-type-limits -Werror> +) + +set_target_properties( + triton-fastdeploy-backend PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_fastdeploy + SKIP_BUILD_RPATH TRUE + LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_fastdeploy.ldscript + LINK_FLAGS "-Wl,--version-script libtriton_fastdeploy.ldscript" +) + +target_link_libraries( + triton-fastdeploy-backend + PRIVATE + triton-backend-utils # from repo-backend + triton-core-serverstub # from repo-core +) diff --git a/serving/src/fastdeploy_backend_utils.cc b/serving/src/fastdeploy_backend_utils.cc new file mode 100644 index 000000000..2de0baf37 --- /dev/null +++ b/serving/src/fastdeploy_backend_utils.cc @@ -0,0 +1,128 @@ +// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "fastdeploy_backend_utils.h" + +#include +#include +#include +#include +#include + +namespace triton { +namespace backend { +namespace fastdeploy_runtime { + +TRITONSERVER_DataType ConvertFDType(fastdeploy::FDDataType dtype) { + switch (dtype) { + case fastdeploy::FDDataType::UNKNOWN1: + return TRITONSERVER_TYPE_INVALID; + case ::fastdeploy::FDDataType::UINT8: + return TRITONSERVER_TYPE_UINT8; + case ::fastdeploy::FDDataType::INT8: + return TRITONSERVER_TYPE_INT8; + case ::fastdeploy::FDDataType::INT32: + return TRITONSERVER_TYPE_INT32; + case ::fastdeploy::FDDataType::INT64: + return TRITONSERVER_TYPE_INT64; + case ::fastdeploy::FDDataType::FP32: + return TRITONSERVER_TYPE_FP32; + case ::fastdeploy::FDDataType::FP16: + return TRITONSERVER_TYPE_FP16; + default: + break; + } + return TRITONSERVER_TYPE_INVALID; +} + +fastdeploy::FDDataType ConvertDataTypeToFD(TRITONSERVER_DataType dtype) { + switch (dtype) { + case TRITONSERVER_TYPE_INVALID: + return ::fastdeploy::FDDataType::UNKNOWN1; + case TRITONSERVER_TYPE_UINT8: + return ::fastdeploy::FDDataType::UINT8; + case TRITONSERVER_TYPE_INT8: + return ::fastdeploy::FDDataType::INT8; + case TRITONSERVER_TYPE_INT32: + return ::fastdeploy::FDDataType::INT32; + case TRITONSERVER_TYPE_INT64: + return ::fastdeploy::FDDataType::INT64; + case TRITONSERVER_TYPE_FP32: + return ::fastdeploy::FDDataType::FP32; + case TRITONSERVER_TYPE_FP16: + return ::fastdeploy::FDDataType::FP16; + default: + break; + } + return ::fastdeploy::FDDataType::UNKNOWN1; +} + +fastdeploy::FDDataType ModelConfigDataTypeToFDType( + const std::string& data_type_str) { + // Must start with "TYPE_". 
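+  // For example, "TYPE_FP32" maps to fastdeploy::FDDataType::FP32; a string
+  // without the "TYPE_" prefix, or with an unrecognized suffix, falls through
+  // to UNKNOWN1.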
+ if (data_type_str.rfind("TYPE_", 0) != 0) { + return fastdeploy::FDDataType::UNKNOWN1; + } + + const std::string dtype = data_type_str.substr(strlen("TYPE_")); + + if (dtype == "UINT8") { + return fastdeploy::FDDataType::UINT8; + } else if (dtype == "INT8") { + return fastdeploy::FDDataType::INT8; + } else if (dtype == "INT32") { + return fastdeploy::FDDataType::INT32; + } else if (dtype == "INT64") { + return fastdeploy::FDDataType::INT64; + } else if (dtype == "FP16") { + return fastdeploy::FDDataType::FP16; + } else if (dtype == "FP32") { + return fastdeploy::FDDataType::FP32; + } + return fastdeploy::FDDataType::UNKNOWN1; +} + +std::string FDTypeToModelConfigDataType(fastdeploy::FDDataType data_type) { + if (data_type == fastdeploy::FDDataType::UINT8) { + return "TYPE_UINT8"; + } else if (data_type == fastdeploy::FDDataType::INT8) { + return "TYPE_INT8"; + } else if (data_type == fastdeploy::FDDataType::INT32) { + return "TYPE_INT32"; + } else if (data_type == fastdeploy::FDDataType::INT64) { + return "TYPE_INT64"; + } else if (data_type == fastdeploy::FDDataType::FP16) { + return "TYPE_FP16"; + } else if (data_type == fastdeploy::FDDataType::FP32) { + return "TYPE_FP32"; + } + + return "TYPE_INVALID"; +} + +} // namespace fastdeploy_runtime +} // namespace backend +} // namespace triton \ No newline at end of file diff --git a/serving/src/fastdeploy_backend_utils.h b/serving/src/fastdeploy_backend_utils.h new file mode 100644 index 000000000..2a7cdd100 --- /dev/null +++ b/serving/src/fastdeploy_backend_utils.h @@ -0,0 +1,72 @@ + +// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
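+
+// Shared helpers for the FastDeploy Triton backend: conversions between
+// TRITONSERVER_DataType, the model-config type strings ("TYPE_*") and
+// fastdeploy::FDDataType, plus RESPOND_ALL_AND_SET_TRUE_IF_ERROR, a macro
+// that sends an error on every still-open response and marks the batch as
+// failed.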
+ +#pragma once + +#include +#include +#include +#include + +#include "fastdeploy/core/fd_type.h" +#include "triton/core/tritonserver.h" + +namespace triton { +namespace backend { +namespace fastdeploy_runtime { + +#define RESPOND_ALL_AND_SET_TRUE_IF_ERROR(RESPONSES, RESPONSES_COUNT, BOOL, X) \ + do { \ + TRITONSERVER_Error* raasnie_err__ = (X); \ + if (raasnie_err__ != nullptr) { \ + BOOL = true; \ + for (size_t ridx = 0; ridx < RESPONSES_COUNT; ++ridx) { \ + if (RESPONSES[ridx] != nullptr) { \ + LOG_IF_ERROR( \ + TRITONBACKEND_ResponseSend(RESPONSES[ridx], \ + TRITONSERVER_RESPONSE_COMPLETE_FINAL, \ + raasnie_err__), \ + "failed to send error response"); \ + RESPONSES[ridx] = nullptr; \ + } \ + } \ + TRITONSERVER_ErrorDelete(raasnie_err__); \ + } \ + } while (false) + +fastdeploy::FDDataType ConvertDataTypeToFD(TRITONSERVER_DataType dtype); + +TRITONSERVER_DataType ConvertFDType(fastdeploy::FDDataType dtype); + +fastdeploy::FDDataType ModelConfigDataTypeToFDType( + const std::string& data_type_str); + +std::string FDTypeToModelConfigDataType(fastdeploy::FDDataType data_type); + +} // namespace fastdeploy_runtime +} // namespace backend +} // namespace triton diff --git a/serving/src/fastdeploy_runtime.cc b/serving/src/fastdeploy_runtime.cc new file mode 100644 index 000000000..c918d1e45 --- /dev/null +++ b/serving/src/fastdeploy_runtime.cc @@ -0,0 +1,1269 @@ +// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#include + +#include +#include + +#include "fastdeploy/core/fd_tensor.h" +#include "fastdeploy/core/fd_type.h" +#include "fastdeploy/runtime.h" +#include "fastdeploy/utils/utils.h" +#include "fastdeploy_backend_utils.h" +#include "triton/backend/backend_common.h" +#include "triton/backend/backend_input_collector.h" +#include "triton/backend/backend_memory.h" +#include "triton/backend/backend_model.h" +#include "triton/backend/backend_model_instance.h" +#include "triton/backend/backend_output_responder.h" + +#ifdef TRITON_ENABLE_GPU +#include +#endif // TRITON_ENABLE_GPU + +// +// FastDeploy Backend that implements the TRITONBACKEND API. +// +namespace triton { +namespace backend { +namespace fastdeploy_runtime { + +// +// ModelState +// +// State associated with a model that is using this backend. An object +// of this class is created and associated with each +// TRITONBACKEND_Model. +// +class ModelState : public BackendModel { + public: + static TRITONSERVER_Error* Create(TRITONBACKEND_Model* triton_model, + ModelState** state); + virtual ~ModelState() = default; + + // Load an model. If 'instance_group_kind' is not + // TRITONSERVER_INSTANCEGROUPKIND_AUTO then use it and + // 'instance_group_device_id' to initialize the appropriate + // execution providers. Return in 'model_path' the full path to the + // onnx or paddle file. + TRITONSERVER_Error* LoadModel( + const std::string& artifact_name, + const TRITONSERVER_InstanceGroupKind instance_group_kind, + const int32_t instance_group_device_id, std::string* model_path, + std::string* params_path, fastdeploy::Runtime** runtime, + cudaStream_t stream); + + const std::map>& ModelOutputs() { + return model_outputs_; + } + + private: + ModelState(TRITONBACKEND_Model* triton_model); + TRITONSERVER_Error* AutoCompleteConfig(); + + TRITONSERVER_Error* AutoCompleteIO( + const char* key, const std::vector& io_infos); + + // Runtime options used when creating a FastDeploy Runtime. + std::unique_ptr runtime_options_; + + // model_outputs is a map that contains unique outputs that the model must + // provide. In the model configuration, the output in the state configuration + // can have intersection with the outputs section of the model. If an output + // is specified both in the output section and state section, it indicates + // that the backend must return the output state to the client too. + std::map> model_outputs_; +}; + +TRITONSERVER_Error* ModelState::Create(TRITONBACKEND_Model* triton_model, + ModelState** state) { + try { + *state = new ModelState(triton_model); + } catch (const BackendModelException& ex) { + RETURN_ERROR_IF_TRUE( + ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelException")); + RETURN_IF_ERROR(ex.err_); + } + + // Auto-complete the configuration if requested... 
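+  // TRITONBACKEND_ModelAutoCompleteConfig reports whether Triton expects this
+  // backend to fill in missing configuration (inputs, outputs, max_batch_size)
+  // from the loaded model; AutoCompleteConfig() below does that work.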
+ bool auto_complete_config = false; + RETURN_IF_ERROR(TRITONBACKEND_ModelAutoCompleteConfig(triton_model, + &auto_complete_config)); + if (auto_complete_config) { + RETURN_IF_ERROR((*state)->AutoCompleteConfig()); + // RETURN_IF_ERROR((*state)->SetModelConfig()); + } + + auto& model_outputs = (*state)->model_outputs_; + + // Parse the output states in the model configuration + triton::common::TritonJson::Value sequence_batching; + if ((*state)->ModelConfig().Find("sequence_batching", &sequence_batching)) { + triton::common::TritonJson::Value states; + if (sequence_batching.Find("state", &states)) { + for (size_t i = 0; i < states.ArraySize(); i++) { + triton::common::TritonJson::Value state; + RETURN_IF_ERROR(states.IndexAsObject(i, &state)); + std::string output_state_name; + RETURN_IF_ERROR( + state.MemberAsString("output_name", &output_state_name)); + auto it = model_outputs.find(output_state_name); + if (it == model_outputs.end()) { + model_outputs.insert({output_state_name, std::make_pair(-1, i)}); + } else { + it->second.second = i; + } + } + } + } + + // Parse the output names in the model configuration + triton::common::TritonJson::Value outputs; + RETURN_IF_ERROR((*state)->ModelConfig().MemberAsArray("output", &outputs)); + for (size_t i = 0; i < outputs.ArraySize(); i++) { + triton::common::TritonJson::Value output; + RETURN_IF_ERROR(outputs.IndexAsObject(i, &output)); + + std::string output_name_str; + + RETURN_IF_ERROR(output.MemberAsString("name", &output_name_str)); + auto it = model_outputs.find(output_name_str); + if (it == model_outputs.end()) { + model_outputs.insert({output_name_str, {i, -1}}); + } else { + it->second.first = i; + } + } + + return nullptr; // success +} + +ModelState::ModelState(TRITONBACKEND_Model* triton_model) + : BackendModel(triton_model) { + // Create runtime options that will be cloned and used for each + // instance when creating that instance's runtime. 
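+  // The block below selects the FastDeploy backend and reads its tuning
+  // values from the parsed model configuration. The keys it looks for
+  // (taken from the parsing code below, not from Triton's config schema)
+  // have the shape:
+  //
+  //   "optimization": {
+  //     "onnxruntime": { "graph_level": "1", "inter_op_num_threads": "4" }
+  //   }
+  //
+  // Replacing "onnxruntime" with "tensorrt", "paddle" or "openvino" selects
+  // the corresponding backend; all values arrive as strings and are parsed
+  // with ParseIntValue/ParseBoolValue.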
+ runtime_options_.reset(new fastdeploy::RuntimeOption()); + + { + triton::common::TritonJson::Value optimization; + if (ModelConfig().Find("optimization", &optimization)) { + triton::common::TritonJson::Value backend; + if (optimization.Find("onnxruntime", &backend)) { + runtime_options_->UseOrtBackend(); + std::vector param_keys; + THROW_IF_BACKEND_MODEL_ERROR(backend.Members(¶m_keys)); + for (const auto& param_key : param_keys) { + std::string value_string; + if (param_key == "graph_level") { + THROW_IF_BACKEND_MODEL_ERROR( + backend.MemberAsString(param_key.c_str(), &value_string)); + THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( + value_string, &runtime_options_->ort_graph_opt_level)); + } else if (param_key == "inter_op_num_threads") { + THROW_IF_BACKEND_MODEL_ERROR( + backend.MemberAsString(param_key.c_str(), &value_string)); + THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( + value_string, &runtime_options_->ort_inter_op_num_threads)); + } else if (param_key == "execution_mode") { + THROW_IF_BACKEND_MODEL_ERROR( + backend.MemberAsString(param_key.c_str(), &value_string)); + THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( + value_string, &runtime_options_->ort_execution_mode)); + } + } + } else if (optimization.Find("tensorrt", &backend)) { + runtime_options_->UseTrtBackend(); + std::vector param_keys; + THROW_IF_BACKEND_MODEL_ERROR(backend.Members(¶m_keys)); + for (const auto& param_key : param_keys) { + std::string value_string; + if (param_key == "cpu_threads") { + THROW_IF_BACKEND_MODEL_ERROR( + backend.MemberAsString(param_key.c_str(), &value_string)); + THROW_IF_BACKEND_MODEL_ERROR( + ParseIntValue(value_string, &runtime_options_->cpu_thread_num)); + } + // TODO(liqi): add tensorrt + } + } else if (optimization.Find("paddle", &backend)) { + runtime_options_->UsePaddleBackend(); + std::vector param_keys; + THROW_IF_BACKEND_MODEL_ERROR(backend.Members(¶m_keys)); + for (const auto& param_key : param_keys) { + std::string value_string; + if (param_key == "cpu_threads") { + THROW_IF_BACKEND_MODEL_ERROR( + backend.MemberAsString(param_key.c_str(), &value_string)); + THROW_IF_BACKEND_MODEL_ERROR( + ParseIntValue(value_string, &runtime_options_->cpu_thread_num)); + } else if (param_key == "capacity") { + THROW_IF_BACKEND_MODEL_ERROR( + backend.MemberAsString(param_key.c_str(), &value_string)); + THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue( + value_string, &runtime_options_->pd_mkldnn_cache_size)); + } else if (param_key == "use_mkldnn") { + THROW_IF_BACKEND_MODEL_ERROR( + backend.MemberAsString(param_key.c_str(), &value_string)); + THROW_IF_BACKEND_MODEL_ERROR(ParseBoolValue( + value_string, &runtime_options_->pd_enable_mkldnn)); + } + } + } else if (optimization.Find("openvino", &backend)) { + runtime_options_->UseOpenVINOBackend(); + std::vector param_keys; + THROW_IF_BACKEND_MODEL_ERROR(backend.Members(¶m_keys)); + for (const auto& param_key : param_keys) { + std::string value_string; + if (param_key == "cpu_threads") { + THROW_IF_BACKEND_MODEL_ERROR( + backend.MemberAsString(param_key.c_str(), &value_string)); + THROW_IF_BACKEND_MODEL_ERROR( + ParseIntValue(value_string, &runtime_options_->cpu_thread_num)); + } + // TODO(liqi): add openvino + } + } + } + } +} + +TRITONSERVER_Error* ModelState::LoadModel( + const std::string& artifact_name, + const TRITONSERVER_InstanceGroupKind instance_group_kind, + const int32_t instance_group_device_id, std::string* model_path, + std::string* params_path, fastdeploy::Runtime** runtime, + cudaStream_t stream) { + auto dir_path = JoinPath({RepositoryPath(), 
std::to_string(Version())}); + { + // ONNX Format + bool exists; + *model_path = JoinPath({dir_path, "model.onnx"}); + RETURN_IF_ERROR(FileExists(*model_path, &exists)); + + // Paddle Formax + if (not exists) { + *model_path = JoinPath({dir_path, "model.pdmodel"}); + RETURN_IF_ERROR(FileExists(*model_path, &exists)); + if (not exists) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_NOT_FOUND, + std::string( + "Model should be named as 'model.onnx' or 'model.pdmodel'") + .c_str()); + } + *params_path = JoinPath({dir_path, "model.pdiparams"}); + RETURN_IF_ERROR(FileExists(*params_path, &exists)); + if (not exists) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_NOT_FOUND, + std::string("Paddle params should be named as 'model.pdiparams' or " + "not provided.'") + .c_str()); + } + runtime_options_->model_format = fastdeploy::Frontend::PADDLE; + runtime_options_->model_file = *model_path; + runtime_options_->params_file = *params_path; + } else { + runtime_options_->model_format = fastdeploy::Frontend::ONNX; + runtime_options_->model_file = *model_path; + } + } + + // GPU +#ifdef TRITON_ENABLE_GPU + if ((instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_GPU) || + (instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_AUTO)) { + runtime_options_->UseGpu(instance_group_device_id); + } else { + runtime_options_->UseCpu(); + } +#else + runtime_options_->UseCpu(); +#endif // TRITON_ENABLE_GPU + + *runtime = new fastdeploy::Runtime(); + if (!(*runtime)->Init(*runtime_options_)) { + return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_NOT_FOUND, + std::string("Runtime init error").c_str()); + } + + return nullptr; // success +} + +TRITONSERVER_Error* ModelState::AutoCompleteConfig() { + // If the model configuration already specifies inputs and outputs + // then don't perform any auto-completion. + size_t input_cnt = 0; + size_t output_cnt = 0; + { + triton::common::TritonJson::Value inputs; + if (ModelConfig().Find("input", &inputs)) { + input_cnt = inputs.ArraySize(); + } + + triton::common::TritonJson::Value config_batch_inputs; + if (ModelConfig().Find("batch_input", &config_batch_inputs)) { + input_cnt += config_batch_inputs.ArraySize(); + } + + triton::common::TritonJson::Value outputs; + if (ModelConfig().Find("output", &outputs)) { + output_cnt = outputs.ArraySize(); + } + } + + if ((input_cnt > 0) && (output_cnt > 0)) { + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("skipping model configuration auto-complete for '") + + Name() + "': inputs and outputs already specified") + .c_str()); + return nullptr; // success + } + + std::string artifact_name; + RETURN_IF_ERROR( + ModelConfig().MemberAsString("default_model_filename", &artifact_name)); + + std::string model_path; + std::string params_path; + + TRITONSERVER_InstanceGroupKind kind = TRITONSERVER_INSTANCEGROUPKIND_CPU; + +#ifdef TRITON_ENABLE_GPU + triton::common::TritonJson::Value instance_group; + ModelConfig().Find("instance_group", &instance_group); + + // Earlier in the model lifecycle, device checks for the instance group + // have already occurred. 
If at least one instance group with + // "kind" = "KIND_GPU" then allow model to use GPU else autocomplete to + // "KIND_CPU" + for (size_t i = 0; i < instance_group.ArraySize(); ++i) { + triton::common::TritonJson::Value instance_obj; + instance_group.IndexAsObject(i, &instance_obj); + + triton::common::TritonJson::Value instance_group_kind; + instance_obj.Find("kind", &instance_group_kind); + std::string kind_str; + RETURN_IF_ERROR(instance_group_kind.AsString(&kind_str)); + + if (kind_str == "KIND_GPU") { + kind = TRITONSERVER_INSTANCEGROUPKIND_GPU; + break; + } + } +#endif // TRITON_ENABLE_GPU + + fastdeploy::Runtime* runtime = nullptr; + RETURN_IF_ERROR(LoadModel(artifact_name, kind, 0, &model_path, ¶ms_path, + &runtime, nullptr)); + + // TODO(liqi): need to infer max_batch_size + int max_batch_size = -1; + triton::common::TritonJson::Value mbs_value; + ModelConfig().Find("max_batch_size", &mbs_value); + mbs_value.SetInt(max_batch_size); + SetMaxBatchSize(max_batch_size); + + auto input_infos = runtime->GetInputInfos(); + auto output_infos = runtime->GetOutputInfos(); + if (input_cnt == 0) { + RETURN_IF_ERROR(AutoCompleteIO("input", input_infos)); + } + if (output_cnt == 0) { + RETURN_IF_ERROR(AutoCompleteIO("output", output_infos)); + } + + if (TRITONSERVER_LogIsEnabled(TRITONSERVER_LOG_VERBOSE)) { + triton::common::TritonJson::WriteBuffer buffer; + RETURN_IF_ERROR(ModelConfig().PrettyWrite(&buffer)); + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("post auto-complete:\n") + buffer.Contents()).c_str()); + } + + return nullptr; // success +} + +TRITONSERVER_Error* ModelState::AutoCompleteIO( + const char* key, const std::vector& io_infos) { + triton::common::TritonJson::Value existing_ios; + bool found_ios = ModelConfig().Find(key, &existing_ios); + + triton::common::TritonJson::Value ios( + ModelConfig(), triton::common::TritonJson::ValueType::ARRAY); + for (const auto& io_info : io_infos) { + triton::common::TritonJson::Value io( + ModelConfig(), triton::common::TritonJson::ValueType::OBJECT); + RETURN_IF_ERROR(io.AddString("name", io_info.name)); + RETURN_IF_ERROR( + io.AddString("data_type", FDTypeToModelConfigDataType(io_info.dtype))); + + // The model signature supports batching then the first dimension + // is -1 and should not appear in the model configuration 'dims' + // that we are creating. + const auto& io_info_shape = io_info.shape; + triton::common::TritonJson::Value dims( + ModelConfig(), triton::common::TritonJson::ValueType::ARRAY); + for (size_t i = (MaxBatchSize() > 0) ? 1 : 0; i < io_info_shape.size(); + ++i) { + RETURN_IF_ERROR(dims.AppendInt(io_info_shape[i])); + } + + // If dims are empty then must use a reshape... + if (dims.ArraySize() == 0) { + RETURN_IF_ERROR(dims.AppendInt(1)); + triton::common::TritonJson::Value reshape( + ModelConfig(), triton::common::TritonJson::ValueType::OBJECT); + triton::common::TritonJson::Value reshape_dims( + ModelConfig(), triton::common::TritonJson::ValueType::ARRAY); + RETURN_IF_ERROR(reshape.Add("shape", std::move(reshape_dims))); + RETURN_IF_ERROR(io.Add("reshape", std::move(reshape))); + } + RETURN_IF_ERROR(io.Add("dims", std::move(dims))); + RETURN_IF_ERROR(ios.Append(std::move(io))); + } + + if (found_ios) { + existing_ios.Swap(ios); + } else { + ModelConfig().Add(key, std::move(ios)); + } + + return nullptr; // success +} + +// +// ModelInstanceState +// +// State associated with a model instance. An object of this class is +// created and associated with each TRITONBACKEND_ModelInstance. 
+// +class ModelInstanceState : public BackendModelInstance { + public: + static TRITONSERVER_Error* Create( + ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state); + virtual ~ModelInstanceState(); + + void ReleaseRunResources(); + + // Get the state of the model that corresponds to this instance. + ModelState* StateForModel() const { return model_state_; } + + // Execute... + void ProcessRequests(TRITONBACKEND_Request** requests, + const uint32_t request_count); + + private: + ModelInstanceState(ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance); + void ReleaseOrtRunResources(); + int GetInfoIndex(const std::string& name, + const std::vector& infos); + void GetInfoNames(const std::vector& infos, + std::vector& names); + TRITONSERVER_Error* ValidateInputs(); + TRITONSERVER_Error* ValidateOutputs(); + TRITONSERVER_Error* Run(std::vector* responses, + const uint32_t response_count); + TRITONSERVER_Error* SetInputTensors( + size_t total_batch_size, TRITONBACKEND_Request** requests, + const uint32_t request_count, + std::vector* responses, + BackendInputCollector* collector, bool* cuda_copy); + + TRITONSERVER_Error* ReadOutputTensors( + size_t total_batch_size, TRITONBACKEND_Request** requests, + const uint32_t request_count, + std::vector* responses); + + ModelState* model_state_; + + // The full path to the model file. + std::string model_path_; + std::string params_path_; + + std::shared_ptr runtime_; + + std::vector input_names_; + std::vector output_names_; + std::vector input_tensor_infos_; + std::vector output_tensor_infos_; + + std::vector input_tensors_; + std::vector output_tensors_; +}; + +TRITONSERVER_Error* ModelInstanceState::Create( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state) { + try { + *state = new ModelInstanceState(model_state, triton_model_instance); + } catch (const BackendModelInstanceException& ex) { + RETURN_ERROR_IF_TRUE( + ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelInstanceException")); + RETURN_IF_ERROR(ex.err_); + } + return nullptr; // success +} + +ModelInstanceState::ModelInstanceState( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance) + : BackendModelInstance(model_state, triton_model_instance), + model_state_(model_state), + runtime_(nullptr) { + fastdeploy::Runtime* runtime = nullptr; + THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel( + ArtifactFilename(), Kind(), DeviceId(), &model_path_, ¶ms_path_, + &runtime, CudaStream())); + runtime_.reset(runtime); + runtime = nullptr; + + THROW_IF_BACKEND_INSTANCE_ERROR(ValidateInputs()); + THROW_IF_BACKEND_INSTANCE_ERROR(ValidateOutputs()); +} + +ModelInstanceState::~ModelInstanceState() { ReleaseRunResources(); } + +void ModelInstanceState::ReleaseRunResources() { + input_names_.clear(); + output_names_.clear(); + input_tensors_.clear(); + output_tensors_.clear(); + input_tensor_infos_.clear(); + output_tensor_infos_.clear(); +} + +int ModelInstanceState::GetInfoIndex( + const std::string& name, const std::vector& infos) { + for (size_t i = 0; i < infos.size(); ++i) { + if (name == infos[i].name) return int(i); + } + return -1; +} + +void ModelInstanceState::GetInfoNames( + const std::vector& infos, + std::vector& names) { + for (const auto& info : infos) names.emplace_back(info.name); +} + +TRITONSERVER_Error* ModelInstanceState::ValidateInputs() { + input_tensor_infos_ = 
runtime_->GetInputInfos(); + std::vector names; + GetInfoNames(input_tensor_infos_, names); + input_tensors_.clear(); + input_names_.clear(); + input_tensors_.reserve(input_tensor_infos_.size()); + + triton::common::TritonJson::Value ios; + RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("input", &ios)); + if (input_tensor_infos_.size() != ios.ArraySize()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("unable to load model '") + model_state_->Name() + + "', configuration expects " + std::to_string(ios.ArraySize()) + + " inputs, model provides " + + std::to_string(input_tensor_infos_.size())) + .c_str()); + } + for (size_t i = 0; i < ios.ArraySize(); i++) { + triton::common::TritonJson::Value io; + RETURN_IF_ERROR(ios.IndexAsObject(i, &io)); + std::string io_name; + RETURN_IF_ERROR(io.MemberAsString("name", &io_name)); + std::string io_dtype; + RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype)); + + input_names_.emplace_back(io_name); + int index = GetInfoIndex(io_name, input_tensor_infos_); + if (index < 0) { + std::set inames(names.begin(), names.end()); + RETURN_IF_ERROR(CheckAllowedModelInput(io, inames)); + } + input_tensors_.emplace_back(io_name); + + auto fd_data_type = ModelConfigDataTypeToFDType(io_dtype); + if (fd_data_type == fastdeploy::FDDataType::UNKNOWN1) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string("unsupported datatype ") + io_dtype + " for input '" + + io_name + "' for model '" + model_state_->Name() + "'") + .c_str()); + } else if (fd_data_type != input_tensor_infos_[index].dtype) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("unable to load model '") + model_state_->Name() + + "', configuration expects datatype " + io_dtype + " for input '" + + io_name + "', model provides TYPE_" + + TRITONSERVER_DataTypeString( + ConvertFDType(input_tensor_infos_[index].dtype))) + .c_str()); + } + + // If a reshape is provided for the input then use that when + // validating that the model matches what is expected. 
+ std::vector dims; + triton::common::TritonJson::Value reshape; + if (io.Find("reshape", &reshape)) { + RETURN_IF_ERROR(ParseShape(reshape, "shape", &dims)); + } else { + RETURN_IF_ERROR(ParseShape(io, "dims", &dims)); + } + + triton::common::TritonJson::Value allow_ragged_batch_json; + bool allow_ragged_batch = false; + if (io.Find("allow_ragged_batch", &allow_ragged_batch_json)) { + RETURN_IF_ERROR(allow_ragged_batch_json.AsBool(&allow_ragged_batch)); + } + if (allow_ragged_batch) { + const std::vector model_shape( + input_tensor_infos_[index].shape.begin(), + input_tensor_infos_[index].shape.end()); + // Make sure the input has shpae [-1] + if ((model_shape.size() != 1) || (model_shape[0] != WILDCARD_DIM)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("unable to load model '") + model_state_->Name() + + "', configuration expects model provides input with shape [-1] " + "for ragged input '" + + io_name + "', model provides " + ShapeToString(model_shape)) + .c_str()); + } + } else { + // TODO: Implement shape checking + // RETURN_IF_ERROR(CompareDimsSupported(); + } + } + return nullptr; // success +} + +TRITONSERVER_Error* ModelInstanceState::ValidateOutputs() { + output_tensor_infos_ = runtime_->GetOutputInfos(); + output_tensors_.clear(); + output_tensors_.reserve(output_tensor_infos_.size()); + std::set out_names; + for (const auto& info : output_tensor_infos_) { + output_tensors_.emplace_back(info.name); + out_names.insert(info.name); + } + output_names_.clear(); + + triton::common::TritonJson::Value ios; + RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios)); + // It is possible not to return all output! + // if (output_tensor_infos_.size() != ios.ArraySize()) { + // return TRITONSERVER_ErrorNew( + // TRITONSERVER_ERROR_INVALID_ARG, + // (std::string("unable to load model '") + model_state_->Name() + + // "', configuration expects " + std::to_string(ios.ArraySize()) + + // " outputs, model provides " + + // std::to_string(output_tensor_infos_.size())) + // .c_str()); + // } + for (size_t i = 0; i < ios.ArraySize(); i++) { + triton::common::TritonJson::Value io; + RETURN_IF_ERROR(ios.IndexAsObject(i, &io)); + std::string io_name; + RETURN_IF_ERROR(io.MemberAsString("name", &io_name)); + std::string io_dtype; + RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype)); + + output_names_.emplace_back(io_name); + int index = GetInfoIndex(io_name, output_tensor_infos_); + if (index < 0) { + RETURN_IF_ERROR(CheckAllowedModelInput(io, out_names)); + } + // output_tensors_.emplace_back(io_name); + + auto fd_data_type = ModelConfigDataTypeToFDType(io_dtype); + if (fd_data_type == fastdeploy::FDDataType::UNKNOWN1) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string("unsupported datatype ") + io_dtype + " for output '" + + io_name + "' for model '" + model_state_->Name() + "'") + .c_str()); + } else if (fd_data_type != output_tensor_infos_[index].dtype) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("unable to load model '") + model_state_->Name() + + "', configuration expects datatype " + io_dtype + " for output '" + + io_name + "', model provides TYPE_" + + TRITONSERVER_DataTypeString( + ConvertFDType(output_tensor_infos_[index].dtype))) + .c_str()); + } + + // If a reshape is provided for the input then use that when + // validating that the model matches what is expected. 
+    std::vector<int64_t> dims;
+    triton::common::TritonJson::Value reshape;
+    if (io.Find("reshape", &reshape)) {
+      RETURN_IF_ERROR(ParseShape(reshape, "shape", &dims));
+    } else {
+      RETURN_IF_ERROR(ParseShape(io, "dims", &dims));
+    }
+
+    // The batch output shape doesn't necessarily match the model
+    if (model_state_->FindBatchOutput(io_name) == nullptr) {
+      // TODO: Implement shape checking
+      // RETURN_IF_ERROR(CompareDimsSupported());
+    }
+  }
+  return nullptr;  // success
+}
+
+void ModelInstanceState::ProcessRequests(TRITONBACKEND_Request** requests,
+                                         const uint32_t request_count) {
+  LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
+              (std::string("TRITONBACKEND_ModelExecute: Running ") + Name() +
+               " with " + std::to_string(request_count) + " requests")
+                  .c_str());
+
+  uint64_t exec_start_ns = 0;
+  SET_TIMESTAMP(exec_start_ns);
+
+  const int max_batch_size = model_state_->MaxBatchSize();
+  // For each request collect the total batch size for this inference
+  // execution. The batch size, number of inputs, and size of each
+  // input have already been checked, so we don't need to do that here.
+  size_t total_batch_size = 0;
+  for (size_t i = 0; i < request_count; i++) {
+    // If we get a nullptr request then something is badly wrong. Fail
+    // and release all requests.
+    if (requests[i] == nullptr) {
+      RequestsRespondWithError(
+          requests, request_count,
+          TRITONSERVER_ErrorNew(
+              TRITONSERVER_ERROR_INTERNAL,
+              std::string(
+                  "null request given to FastDeploy Runtime backend for '" +
+                  Name() + "'")
+                  .c_str()));
+      return;
+    }
+
+    if (max_batch_size > 0) {
+      // Retrieve the batch size from one of the inputs. If the model
+      // supports batching, the first dimension size is the batch size.
+      TRITONBACKEND_Input* input;
+      TRITONSERVER_Error* err =
+          TRITONBACKEND_RequestInputByIndex(requests[i], 0 /* index */, &input);
+      if (err == nullptr) {
+        const int64_t* shape;
+        err = TRITONBACKEND_InputProperties(input, nullptr, nullptr, &shape,
+                                            nullptr, nullptr, nullptr);
+        total_batch_size += shape[0];
+      }
+      if (err != nullptr) {
+        RequestsRespondWithError(requests, request_count, err);
+        return;
+      }
+    } else {
+      total_batch_size += 1;
+    }
+  }
+
+  // If there are no valid payloads then there is no need to run the inference.
+  if (total_batch_size == 0) {
+    return;
+  }
+
+  // Make sure the maximum batch size is not exceeded. The
+  // total_batch_size must be 1 for models that don't support batching
+  // (i.e. max_batch_size == 0). If max_batch_size is exceeded then the
+  // scheduler has done something badly wrong, so fail and release all
+  // requests.
+  if ((total_batch_size != 1) && (total_batch_size > (size_t)max_batch_size)) {
+    RequestsRespondWithError(
+        requests, request_count,
+        TRITONSERVER_ErrorNew(
+            TRITONSERVER_ERROR_INTERNAL,
+            std::string("batch size " + std::to_string(total_batch_size) +
+                        " for '" + Name() + "', max allowed is " +
+                        std::to_string(max_batch_size))
+                .c_str()));
+    return;
+  }
+
+  // At this point we are committed to running inference with all
+  // 'requests'. Create a response for each request. During input
+  // processing if there is an error with any request that error will
+  // be sent immediately with the corresponding response (and the
+  // response pointer will then be nullptr). The request object
+  // itself will not be released until after all inferencing is done
+  // (below) as we may need to access the request object when
+  // determining how to process outputs (for example, even if we don't
+  // need the outputs for a request that has an error, we do need to
+  // know the size of those outputs associated with the request so we
+  // can skip them in the output tensors).
+  std::vector<TRITONBACKEND_Response*> responses;
+  responses.reserve(request_count);
+  bool all_response_failed = false;
+
+  for (size_t i = 0; i < request_count; i++) {
+    TRITONBACKEND_Response* response;
+    auto err = TRITONBACKEND_ResponseNew(&response, requests[i]);
+    if (err == nullptr) {
+      responses.emplace_back(response);
+    } else {
+      responses.emplace_back(nullptr);
+      LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Failed to create response");
+      TRITONSERVER_ErrorDelete(err);
+    }
+  }
+
+  bool cuda_copy = false;
+  BackendInputCollector collector(
+      requests, request_count, &responses, model_state_->TritonMemoryManager(),
+      model_state_->EnablePinnedInput(), CudaStream(), nullptr, nullptr, 0,
+      HostPolicyName().c_str());
+  RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
+      responses, request_count, all_response_failed,
+      SetInputTensors(total_batch_size, requests, request_count, &responses,
+                      &collector, &cuda_copy));
+
+  // Wait for any in-flight input tensor copies to complete.
+#ifdef TRITON_ENABLE_GPU
+  if (cuda_copy) {
+    cudaStreamSynchronize(CudaStream());
+  }
+#endif
+
+  uint64_t compute_start_ns = 0;
+  SET_TIMESTAMP(compute_start_ns);
+
+  if (!all_response_failed) {
+    RESPOND_ALL_AND_SET_TRUE_IF_ERROR(responses, request_count,
+                                      all_response_failed,
+                                      Run(&responses, request_count));
+  }
+
+  uint64_t compute_end_ns = 0;
+  SET_TIMESTAMP(compute_end_ns);
+
+  if (!all_response_failed) {
+    RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
+        responses, request_count, all_response_failed,
+        ReadOutputTensors(total_batch_size, requests, request_count,
+                          &responses));
+  }
+
+  uint64_t exec_end_ns = 0;
+  SET_TIMESTAMP(exec_end_ns);
+
+  // Send all the responses that haven't already been sent because of
+  // an earlier error. Note that the responses are not set to nullptr
+  // here as we need that indication below to determine if the request
+  // was successful or not.
+  for (auto& response : responses) {
+    if (response != nullptr) {
+      LOG_IF_ERROR(TRITONBACKEND_ResponseSend(
+                       response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr),
+                   "failed to send fastdeploy backend response");
+    }
+  }
+
+  // Report statistics for each request.
+  for (uint32_t r = 0; r < request_count; ++r) {
+    auto& request = requests[r];
+    LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportStatistics(
+                     TritonModelInstance(), request,
+                     (responses[r] != nullptr) /* success */, exec_start_ns,
+                     compute_start_ns, compute_end_ns, exec_end_ns),
+                 "failed reporting request statistics");
+
+    LOG_IF_ERROR(
+        TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL),
+        "failed releasing request");
+  }
+
+  if (!all_response_failed) {
+    // Report the entire batch statistics.
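+    // The timestamps passed below are, in order: execution start, compute
+    // start, compute end, and execution end, all captured above in
+    // nanoseconds via SET_TIMESTAMP.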
+    LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportBatchStatistics(
+                     TritonModelInstance(), total_batch_size, exec_start_ns,
+                     compute_start_ns, compute_end_ns, exec_end_ns),
+                 "failed reporting batch request statistics");
+  }
+}
+
+TRITONSERVER_Error* ModelInstanceState::Run(
+    std::vector<TRITONBACKEND_Response*>* responses,
+    const uint32_t response_count) {
+  runtime_->Infer(input_tensors_, &output_tensors_);
+#ifdef TRITON_ENABLE_GPU
+  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+    // TODO: stream control
+    cudaDeviceSynchronize();
+    // cudaStreamSynchronize(CudaStream());
+  }
+#endif
+  return nullptr;
+}
+
+TRITONSERVER_Error* ModelInstanceState::SetInputTensors(
+    size_t total_batch_size, TRITONBACKEND_Request** requests,
+    const uint32_t request_count,
+    std::vector<TRITONBACKEND_Response*>* responses,
+    BackendInputCollector* collector, bool* cuda_copy) {
+  const int max_batch_size = model_state_->MaxBatchSize();
+  // All requests must have equally-sized input tensors so use any
+  // request as the representative for the input tensors.
+  uint32_t input_count;
+  RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));
+
+  for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) {
+    TRITONBACKEND_Input* input;
+    RETURN_IF_ERROR(
+        TRITONBACKEND_RequestInputByIndex(requests[0], input_idx, &input));
+
+    const char* input_name;
+    TRITONSERVER_DataType input_datatype;
+    const int64_t* input_shape;
+    uint32_t input_dims_count;
+    RETURN_IF_ERROR(TRITONBACKEND_InputProperties(
+        input, &input_name, &input_datatype, &input_shape, &input_dims_count,
+        nullptr, nullptr));
+
+    if (input_tensors_[input_idx].name != std::string(input_name)) {
+      auto err = TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INTERNAL,
+          (std::string("Input name [") + input_name +
+           std::string("] is not one of the FD predictor input: ") +
+           input_tensors_[input_idx].name)
+              .c_str());
+      // SendErrorForResponses(responses, request_count, err);
+      return err;
+    }
+
+    std::vector<int64_t> batchn_shape;
+    // For a ragged input tensor, the tensor shape should be
+    // the flattened shape of the whole batch
+    if (StateForModel()->IsInputRagged(input_name)) {
+      batchn_shape = std::vector<int64_t>{0};
+      for (size_t idx = 0; idx < request_count; idx++) {
+        TRITONBACKEND_Input* input;
+        RESPOND_AND_SET_NULL_IF_ERROR(
+            &((*responses)[idx]),
+            TRITONBACKEND_RequestInput(requests[idx], input_name, &input));
+        const int64_t* input_shape;
+        uint32_t input_dims_count;
+        RESPOND_AND_SET_NULL_IF_ERROR(
+            &((*responses)[idx]),
+            TRITONBACKEND_InputProperties(input, nullptr, nullptr, &input_shape,
+                                          &input_dims_count, nullptr, nullptr));
+
+        batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
+      }
+    } else {
+      // The shape for the entire input batch, [total_batch_size, ...]
+      batchn_shape =
+          std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+      if (max_batch_size != 0) {
+        batchn_shape[0] = total_batch_size;
+      }
+    }
+
+    TRITONSERVER_MemoryType memory_type;
+    int64_t device_id = 0;
+    fastdeploy::Device device;
+    if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
+      memory_type = TRITONSERVER_MEMORY_GPU;
+      device_id = DeviceId();
+      device = fastdeploy::Device::GPU;
+    } else {
+      memory_type = TRITONSERVER_MEMORY_CPU;
+      device = fastdeploy::Device::CPU;
+    }
+    input_tensors_[input_idx].Resize(
+        batchn_shape, ConvertDataTypeToFD(input_datatype), input_name, device);
+    collector->ProcessTensor(
+        input_name,
+        reinterpret_cast<char*>(input_tensors_[input_idx].MutableData()),
+        input_tensors_[input_idx].Nbytes(), memory_type, device_id);
+  }
+
+  // Finalize...
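+  // Finalize() flushes any copies still queued by ProcessTensor(); it returns
+  // true if a copy was issued on the CUDA stream, in which case the caller
+  // (ProcessRequests above) synchronizes the stream before running inference.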
+  *cuda_copy |= collector->Finalize();
+  return nullptr;
+}
+
+TRITONSERVER_Error* ModelInstanceState::ReadOutputTensors(
+    size_t total_batch_size, TRITONBACKEND_Request** requests,
+    const uint32_t request_count,
+    std::vector<TRITONBACKEND_Response*>* responses) {
+  // r22.03
+  // BackendOutputResponder responder(
+  //     requests, request_count, responses,
+  //     model_state_->TritonMemoryManager(), model_state_->MaxBatchSize() > 0,
+  //     model_state_->EnablePinnedInput(), CudaStream());
+  // r21.10
+  BackendOutputResponder responder(
+      requests, request_count, responses, StateForModel()->MaxBatchSize(),
+      StateForModel()->TritonMemoryManager(),
+      StateForModel()->EnablePinnedOutput(), CudaStream());
+
+  // Track whether any output copy is pending on the CUDA stream.
+  bool cuda_copy = false;
+
+  // It is possible that only a subset of the model outputs is returned,
+  // so the counts are not required to match.
+  // auto& model_outputs = StateForModel()->ModelOutputs();
+  // size_t output_count = output_tensors_.size();
+  // if (output_count != model_outputs.size()) {
+  //   RETURN_IF_ERROR(TRITONSERVER_ErrorNew(
+  //       TRITONSERVER_ERROR_INTERNAL,
+  //       ("Retrieved output count is not equal to expected count.")));
+  // }
+
+  for (auto& output_name : output_names_) {
+    int idx = GetInfoIndex(output_name, output_tensor_infos_);
+    responder.ProcessTensor(
+        output_tensors_[idx].name, ConvertFDType(output_tensors_[idx].dtype),
+        output_tensors_[idx].shape,
+        reinterpret_cast<char*>(output_tensors_[idx].MutableData()),
+        TRITONSERVER_MEMORY_CPU, 0);
+  }
+
+  // Finalize and wait for any pending buffer copies.
+  cuda_copy |= responder.Finalize();
+
+#ifdef TRITON_ENABLE_GPU
+  if (cuda_copy) {
+    cudaStreamSynchronize(stream_);
+  }
+#endif  // TRITON_ENABLE_GPU
+  return nullptr;
+}
+
+/////////////
+
+extern "C" {
+
+TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_Initialize(
+    TRITONBACKEND_Backend* backend) {
+  const char* cname;
+  RETURN_IF_ERROR(TRITONBACKEND_BackendName(backend, &cname));
+  std::string name(cname);
+
+  LOG_MESSAGE(TRITONSERVER_LOG_INFO,
+              (std::string("TRITONBACKEND_Initialize: ") + name).c_str());
+
+  // Check the backend API version that Triton supports vs. what this
+  // backend was compiled against.
+  uint32_t api_version_major, api_version_minor;
+  RETURN_IF_ERROR(
+      TRITONBACKEND_ApiVersion(&api_version_major, &api_version_minor));
+
+  LOG_MESSAGE(TRITONSERVER_LOG_INFO,
+              (std::string("Triton TRITONBACKEND API version: ") +
+               std::to_string(api_version_major) + "." +
+               std::to_string(api_version_minor))
+                  .c_str());
+  LOG_MESSAGE(TRITONSERVER_LOG_INFO,
+              (std::string("'") + name + "' TRITONBACKEND API version: " +
+               std::to_string(TRITONBACKEND_API_VERSION_MAJOR) + "." +
+               std::to_string(TRITONBACKEND_API_VERSION_MINOR))
+                  .c_str());
+
+  if ((api_version_major != TRITONBACKEND_API_VERSION_MAJOR) ||
+      (api_version_minor < TRITONBACKEND_API_VERSION_MINOR)) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_UNSUPPORTED,
+        (std::string("Triton TRITONBACKEND API version: ") +
+         std::to_string(api_version_major) + "." +
+         std::to_string(api_version_minor) + " does not support '" + name +
+         "' TRITONBACKEND API version: " +
+         std::to_string(TRITONBACKEND_API_VERSION_MAJOR) + "." +
+         std::to_string(TRITONBACKEND_API_VERSION_MINOR))
+            .c_str());
+  }
+
+  // The backend configuration may contain information needed by the
+  // fastdeploy backend, such as command-line arguments.
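+  // For example (hypothetical invocation, not from this repository), starting
+  // the server with
+  //   tritonserver --backend-config=fastdeploy,some-setting=some-value
+  // surfaces here as a "cmdline" member of the parsed JSON, roughly
+  //   {"cmdline": {"some-setting": "some-value"}}.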
+  TRITONSERVER_Message* backend_config_message;
+  RETURN_IF_ERROR(
+      TRITONBACKEND_BackendConfig(backend, &backend_config_message));
+
+  const char* buffer;
+  size_t byte_size;
+  RETURN_IF_ERROR(TRITONSERVER_MessageSerializeToJson(backend_config_message,
+                                                      &buffer, &byte_size));
+  LOG_MESSAGE(TRITONSERVER_LOG_INFO,
+              (std::string("backend configuration:\n") + buffer).c_str());
+
+  triton::common::TritonJson::Value backend_config;
+  TRITONSERVER_Error* err = nullptr;
+  if (byte_size != 0) {
+    err = backend_config.Parse(buffer, byte_size);
+  }
+  RETURN_IF_ERROR(err);
+
+  return nullptr;  // success
+}
+
+TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_Finalize(
+    TRITONBACKEND_Backend* backend) {
+  void* state = nullptr;
+  LOG_IF_ERROR(TRITONBACKEND_BackendState(backend, &state),
+               "failed to get backend state");
+  return nullptr;  // success
+}
+
+TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelInitialize(
+    TRITONBACKEND_Model* model) {
+  const char* cname;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelName(model, &cname));
+  std::string name(cname);
+
+  uint64_t version;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelVersion(model, &version));
+
+  LOG_MESSAGE(TRITONSERVER_LOG_INFO,
+              (std::string("TRITONBACKEND_ModelInitialize: ") + name +
+               " (version " + std::to_string(version) + ")")
+                  .c_str());
+
+  // Create a ModelState object and associate it with the
+  // TRITONBACKEND_Model.
+  ModelState* model_state;
+  RETURN_IF_ERROR(ModelState::Create(model, &model_state));
+  RETURN_IF_ERROR(TRITONBACKEND_ModelSetState(
+      model, reinterpret_cast<void*>(model_state)));
+  return nullptr;  // success
+}
+
+TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelFinalize(
+    TRITONBACKEND_Model* model) {
+  void* vstate;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate));
+  ModelState* model_state = reinterpret_cast<ModelState*>(vstate);
+
+  LOG_MESSAGE(TRITONSERVER_LOG_INFO,
+              "TRITONBACKEND_ModelFinalize: delete model state");
+
+  delete model_state;
+
+  return nullptr;  // success
+}
+
+TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceInitialize(
+    TRITONBACKEND_ModelInstance* instance) {
+  const char* cname;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceName(instance, &cname));
+  std::string name(cname);
+
+  int32_t device_id;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceDeviceId(instance, &device_id));
+  TRITONSERVER_InstanceGroupKind kind;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceKind(instance, &kind));
+
+  LOG_MESSAGE(TRITONSERVER_LOG_INFO,
+              (std::string("TRITONBACKEND_ModelInstanceInitialize: ") + name +
+               " (" + TRITONSERVER_InstanceGroupKindString(kind) + " device " +
+               std::to_string(device_id) + ")")
+                  .c_str());
+
+  // Get the model state associated with this instance's model.
+  TRITONBACKEND_Model* model;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model));
+
+  void* vmodelstate;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate));
+  ModelState* model_state = reinterpret_cast<ModelState*>(vmodelstate);
+
+  // Create a ModelInstanceState object and associate it with the
+  // TRITONBACKEND_ModelInstance.
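+  // The instance state stored here is looked up again in
+  // TRITONBACKEND_ModelInstanceExecute and TRITONBACKEND_ModelInstanceFinalize
+  // via TRITONBACKEND_ModelInstanceState.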
+  ModelInstanceState* instance_state;
+  RETURN_IF_ERROR(
+      ModelInstanceState::Create(model_state, instance, &instance_state));
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState(
+      instance, reinterpret_cast<void*>(instance_state)));
+
+  return nullptr;  // success
+}
+
+TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceFinalize(
+    TRITONBACKEND_ModelInstance* instance) {
+  void* vstate;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate));
+  ModelInstanceState* instance_state =
+      reinterpret_cast<ModelInstanceState*>(vstate);
+
+  LOG_MESSAGE(TRITONSERVER_LOG_INFO,
+              "TRITONBACKEND_ModelInstanceFinalize: delete instance state");
+
+  delete instance_state;
+
+  return nullptr;  // success
+}
+
+TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelInstanceExecute(
+    TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests,
+    const uint32_t request_count) {
+  // Triton will not call this function simultaneously for the same
+  // 'instance'. But since this backend could be used by multiple
+  // instances from multiple models the implementation needs to handle
+  // multiple calls to this function at the same time (with different
+  // 'instance' objects). Suggested practice for this is to use only
+  // function-local and model-instance-specific state (obtained from
+  // 'instance'), which is what we do here.
+  ModelInstanceState* instance_state;
+  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(
+      instance, reinterpret_cast<void**>(&instance_state)));
+  ModelState* model_state = instance_state->StateForModel();
+
+  // This backend specifies BLOCKING execution policy. That means that
+  // we should not return from this function until execution is
+  // complete. Triton will automatically release 'instance' on return
+  // from this function so that it is again available to be used for
+  // another call to TRITONBACKEND_ModelInstanceExecute.
+
+  LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
+              (std::string("model ") + model_state->Name() + ", instance " +
+               instance_state->Name() + ", executing " +
+               std::to_string(request_count) + " requests")
+                  .c_str());
+
+  // At this point we accept ownership of 'requests', which means that
+  // even if something goes wrong we must still return success from
+  // this function. If something does go wrong in processing a
+  // particular request then we send an error response just for the
+  // specific request.
+  instance_state->ProcessRequests(requests, request_count);
+
+  return nullptr;  // success
+}
+
+}  // extern "C"
+
+}  // namespace fastdeploy_runtime
+}  // namespace backend
+}  // namespace triton
diff --git a/serving/src/libtriton_fastdeploy.ldscript b/serving/src/libtriton_fastdeploy.ldscript
new file mode 100644
index 000000000..fbe3520e5
--- /dev/null
+++ b/serving/src/libtriton_fastdeploy.ldscript
@@ -0,0 +1,30 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+{
+  global:
+    TRITONBACKEND_*;
+  local: *;
+};
\ No newline at end of file