diff --git a/examples/vision/segmentation/paddleseg/serving/models/paddleseg/1/README.md b/examples/vision/segmentation/paddleseg/serving/models/paddleseg/1/README.md
new file mode 100644
index 000000000..42ae7e483
--- /dev/null
+++ b/examples/vision/segmentation/paddleseg/serving/models/paddleseg/1/README.md
@@ -0,0 +1,3 @@
+# PaddleSeg Pipeline
+
+The pipeline directory does not contain model files, but a version-number directory still needs to be maintained.
diff --git a/examples/vision/segmentation/paddleseg/serving/models/paddleseg/config.pbtxt b/examples/vision/segmentation/paddleseg/serving/models/paddleseg/config.pbtxt
new file mode 100644
index 000000000..9571a5b91
--- /dev/null
+++ b/examples/vision/segmentation/paddleseg/serving/models/paddleseg/config.pbtxt
@@ -0,0 +1,67 @@
+platform: "ensemble"
+
+input [
+  {
+    name: "INPUT"
+    data_type: TYPE_UINT8
+    dims: [ -1, -1, -1, 3 ]
+  }
+]
+
+output [
+  {
+    name: "SEG_RESULT"
+    data_type: TYPE_STRING
+    dims: [ -1 ]
+  }
+]
+
+ensemble_scheduling {
+  step [
+    {
+      model_name: "preprocess"
+      model_version: 1
+      input_map {
+        key: "preprocess_input"
+        value: "INPUT"
+      }
+      output_map {
+        key: "preprocess_output_1"
+        value: "RUNTIME_INPUT_1"
+      }
+      output_map {
+        key: "preprocess_output_2"
+        value: "POSTPROCESS_INPUT_2"
+      }
+    },
+    {
+      model_name: "runtime"
+      model_version: 1
+      input_map {
+        key: "x"
+        value: "RUNTIME_INPUT_1"
+      }
+      output_map {
+        key: "argmax_0.tmp_0"
+        value: "RUNTIME_OUTPUT"
+      }
+    },
+    {
+      model_name: "postprocess"
+      model_version: 1
+      input_map {
+        key: "post_input_1"
+        value: "RUNTIME_OUTPUT"
+      }
+      input_map {
+        key: "post_input_2"
+        value: "POSTPROCESS_INPUT_2"
+      }
+      output_map {
+        key: "post_output"
+        value: "SEG_RESULT"
+      }
+    }
+  ]
+}
+
diff --git a/examples/vision/segmentation/paddleseg/serving/models/postprocess/1/model.py b/examples/vision/segmentation/paddleseg/serving/models/postprocess/1/model.py
new file mode 100755
index 000000000..510aad6ea
--- /dev/null
+++ b/examples/vision/segmentation/paddleseg/serving/models/postprocess/1/model.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import numpy as np
+import time
+import os
+import fastdeploy as fd
+
+# triton_python_backend_utils is available in every Triton Python model. You
+# need to use this module to create inference requests and responses. It also
+# contains some utility functions for extracting information from model_config
+# and converting Triton input/output types to numpy types.
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    """Your Python model must use the same class name. Every Python model
+    that is created must have "TritonPythonModel" as the class name.
+    """
+
+    def initialize(self, args):
+        """`initialize` is called only once when the model is being loaded.
+        Implementing `initialize` function is optional. This function allows
+        the model to initialize any state associated with this model.
+        Parameters
+        ----------
+        args : dict
+          Both keys and values are strings. The dictionary keys and values are:
+          * model_config: A JSON string containing the model configuration
+          * model_instance_kind: A string containing model instance kind
+          * model_instance_device_id: A string containing model instance device ID
+          * model_repository: Model repository path
+          * model_version: Model version
+          * model_name: Model name
+        """
+        # You must parse model_config. JSON string is not parsed here
+        self.model_config = json.loads(args['model_config'])
+        print("model_config:", self.model_config)
+
+        self.input_names = []
+        for input_config in self.model_config["input"]:
+            self.input_names.append(input_config["name"])
+        print("postprocess input names:", self.input_names)
+
+        self.output_names = []
+        self.output_dtype = []
+        for output_config in self.model_config["output"]:
+            self.output_names.append(output_config["name"])
+            dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
+            self.output_dtype.append(dtype)
+        print("postprocess output names:", self.output_names)
+
+        yaml_path = os.path.abspath(os.path.dirname(__file__)) + "/deploy.yaml"
+        self.postprocess_ = fd.vision.segmentation.PaddleSegPostprocessor(
+            yaml_path)
+
+    def execute(self, requests):
+        """`execute` must be implemented in every Python model. `execute`
+        function receives a list of pb_utils.InferenceRequest as the only
+        argument. This function is called when an inference is requested
+        for this model. Depending on the batching configuration (e.g. Dynamic
+        Batching) used, `requests` may contain multiple requests. Every
+        Python model must create one pb_utils.InferenceResponse for every
+        pb_utils.InferenceRequest in `requests`. If there is an error, you can
+        set the error argument when creating a pb_utils.InferenceResponse.
+        Parameters
+        ----------
+        requests : list
+          A list of pb_utils.InferenceRequest
+        Returns
+        -------
+        list
+          A list of pb_utils.InferenceResponse. The length of this list must
+          be the same as `requests`
+        """
+        responses = []
+        for request in requests:
+            infer_outputs = pb_utils.get_input_tensor_by_name(
+                request, self.input_names[0])
+            im_info = pb_utils.get_input_tensor_by_name(request,
+                                                        self.input_names[1])
+            infer_outputs = infer_outputs.as_numpy()
+            im_info = im_info.as_numpy()
+            for i in range(im_info.shape[0]):
+                im_info[i] = json.loads(im_info[i].decode('utf-8').replace(
+                    "'", '"'))
+
+            results = self.postprocess_.run([infer_outputs], im_info[0])
+            r_str = fd.vision.utils.fd_result_to_json(results)
+
+            r_np = np.array(r_str, dtype=np.object_)
+            out_tensor = pb_utils.Tensor(self.output_names[0], r_np)
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[out_tensor, ])
+            responses.append(inference_response)
+        return responses
+
+    def finalize(self):
+        """`finalize` is called only once when the model is being unloaded.
+        Implementing `finalize` function is optional. This function allows
+        the model to perform any necessary cleanup before exit.
+ """ + print('Cleaning up...') diff --git a/examples/vision/segmentation/paddleseg/serving/models/postprocess/config.pbtxt b/examples/vision/segmentation/paddleseg/serving/models/postprocess/config.pbtxt new file mode 100644 index 000000000..81f31ba08 --- /dev/null +++ b/examples/vision/segmentation/paddleseg/serving/models/postprocess/config.pbtxt @@ -0,0 +1,30 @@ +name: "postprocess" +backend: "python" + +input [ + { + name: "post_input_1" + data_type: TYPE_INT32 + dims: [-1, -1, -1] + }, + { + name: "post_input_2" + data_type: TYPE_STRING + dims: [ -1 ] + } +] + +output [ + { + name: "post_output" + data_type: TYPE_STRING + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/examples/vision/segmentation/paddleseg/serving/models/preprocess/1/deploy.yaml b/examples/vision/segmentation/paddleseg/serving/models/preprocess/1/deploy.yaml new file mode 100644 index 000000000..6d33e5009 --- /dev/null +++ b/examples/vision/segmentation/paddleseg/serving/models/preprocess/1/deploy.yaml @@ -0,0 +1,12 @@ +Deploy: + input_shape: + - -1 + - 3 + - -1 + - -1 + model: model.pdmodel + output_dtype: int32 + output_op: argmax + params: model.pdiparams + transforms: + - type: Normalize diff --git a/examples/vision/segmentation/paddleseg/serving/models/preprocess/1/model.py b/examples/vision/segmentation/paddleseg/serving/models/preprocess/1/model.py new file mode 100644 index 000000000..48a72d6fa --- /dev/null +++ b/examples/vision/segmentation/paddleseg/serving/models/preprocess/1/model.py @@ -0,0 +1,117 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import numpy as np +import os + +import fastdeploy as fd + +# triton_python_backend_utils is available in every Triton Python model. You +# need to use this module to create inference requests and responses. It also +# contains some utility functions for extracting information from model_config +# and converting Triton input/output types to numpy types. +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to intialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # You must parse model_config. 
+        self.model_config = json.loads(args['model_config'])
+        print("model_config:", self.model_config)
+
+        self.input_names = []
+        for input_config in self.model_config["input"]:
+            self.input_names.append(input_config["name"])
+        print("preprocess input names:", self.input_names)
+
+        self.output_names = []
+        self.output_dtype = []
+        for output_config in self.model_config["output"]:
+            self.output_names.append(output_config["name"])
+            # dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
+            # self.output_dtype.append(dtype)
+            self.output_dtype.append(output_config["data_type"])
+        print("preprocess output names:", self.output_names)
+
+        # Initialize the PaddleSegPreprocessor
+        yaml_path = os.path.abspath(os.path.dirname(__file__)) + "/deploy.yaml"
+        self.preprocess_ = fd.vision.segmentation.PaddleSegPreprocessor(
+            yaml_path)
+        # if args['model_instance_kind'] == 'GPU':
+        #     device_id = int(args['model_instance_device_id'])
+        #     self.preprocess_.use_gpu(device_id)
+
+    def execute(self, requests):
+        """`execute` must be implemented in every Python model. `execute`
+        function receives a list of pb_utils.InferenceRequest as the only
+        argument. This function is called when an inference is requested
+        for this model. Depending on the batching configuration (e.g. Dynamic
+        Batching) used, `requests` may contain multiple requests. Every
+        Python model must create one pb_utils.InferenceResponse for every
+        pb_utils.InferenceRequest in `requests`. If there is an error, you can
+        set the error argument when creating a pb_utils.InferenceResponse.
+        Parameters
+        ----------
+        requests : list
+          A list of pb_utils.InferenceRequest
+        Returns
+        -------
+        list
+          A list of pb_utils.InferenceResponse. The length of this list must
+          be the same as `requests`
+        """
+        responses = []
+        for request in requests:
+            data = pb_utils.get_input_tensor_by_name(request,
+                                                     self.input_names[0])
+            data = data.as_numpy()
+            outputs, im_info = self.preprocess_.run(data)
+
+            # PaddleSeg preprocess has two outputs
+            dlpack_tensor = outputs[0].to_dlpack()
+            output_tensor_0 = pb_utils.Tensor.from_dlpack(self.output_names[0],
+                                                          dlpack_tensor)
+            output_tensor_1 = pb_utils.Tensor(
+                self.output_names[1], np.array(
+                    [im_info], dtype=np.object_))
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[output_tensor_0, output_tensor_1])
+            responses.append(inference_response)
+        return responses
+
+    def finalize(self):
+        """`finalize` is called only once when the model is being unloaded.
+        Implementing `finalize` function is optional. This function allows
+        the model to perform any necessary cleanup before exit.
+ """ + print('Cleaning up...') diff --git a/examples/vision/segmentation/paddleseg/serving/models/preprocess/config.pbtxt b/examples/vision/segmentation/paddleseg/serving/models/preprocess/config.pbtxt new file mode 100644 index 000000000..01cb94869 --- /dev/null +++ b/examples/vision/segmentation/paddleseg/serving/models/preprocess/config.pbtxt @@ -0,0 +1,34 @@ +name: "preprocess" +backend: "python" + +input [ + { + name: "preprocess_input" + data_type: TYPE_UINT8 + dims: [-1, -1, -1, 3 ] + } +] + +output [ + { + name: "preprocess_output_1" + data_type: TYPE_FP32 + dims: [-1, 3, -1, -1 ] + }, + { + name: "preprocess_output_2" + data_type: TYPE_STRING + dims: [ -1] + } +] + +instance_group [ + { + # The number of instances is 1 + count: 1 + # Use CPU, GPU inference option is:KIND_GPU + kind: KIND_CPU + # The instance is deployed on the 0th GPU card + # gpus: [0] + } +] diff --git a/examples/vision/segmentation/paddleseg/serving/models/runtime/1/README.md b/examples/vision/segmentation/paddleseg/serving/models/runtime/1/README.md new file mode 100644 index 000000000..1e5d914b4 --- /dev/null +++ b/examples/vision/segmentation/paddleseg/serving/models/runtime/1/README.md @@ -0,0 +1,5 @@ +# Runtime Directory + +This directory holds the model files. +Paddle models must be model.pdmodel and model.pdiparams files. +ONNX models must be model.onnx files. diff --git a/examples/vision/segmentation/paddleseg/serving/models/runtime/config.pbtxt b/examples/vision/segmentation/paddleseg/serving/models/runtime/config.pbtxt new file mode 100644 index 000000000..bd145c590 --- /dev/null +++ b/examples/vision/segmentation/paddleseg/serving/models/runtime/config.pbtxt @@ -0,0 +1,60 @@ +# optional, If name is specified it must match the name of the model repository directory containing the model. +name: "runtime" +backend: "fastdeploy" + +# Input configuration of the model +input [ + { + # input name + name: "x" + # input type such as TYPE_FP32、TYPE_UINT8、TYPE_INT8、TYPE_INT16、TYPE_INT32、TYPE_INT64、TYPE_FP16、TYPE_STRING + data_type: TYPE_FP32 + # input shape + dims: [-1, 3, -1, -1 ] + } +] + +# The output of the model is configured in the same format as the input +output [ + { + name: "argmax_0.tmp_0" + data_type: TYPE_INT32 + dims: [ -1, -1, -1 ] + } +] + +# Number of instances of the model +instance_group [ + { + # The number of instances is 1 + count: 1 + # Use GPU, CPU inference option is:KIND_CPU + kind: KIND_GPU + # The instance is deployed on the 0th GPU card + gpus: [0] + } +] + +optimization { + execution_accelerators { + gpu_execution_accelerator : [ { + # use TRT engine + name: "paddle", + #name: "tensorrt", + # use fp16 on TRT engine + parameters { key: "precision" value: "trt_fp32" } + }, + { + name: "min_shape" + parameters { key: "x" value: "1 3 256 256" } + }, + { + name: "opt_shape" + parameters { key: "x" value: "1 3 1024 1024" } + }, + { + name: "max_shape" + parameters { key: "x" value: "16 3 2048 2048" } + } + ] +}}