mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-13 12:23:55 +08:00
[Serving] Fixed preprocess&&postprocess in YOLOv5 Serving (#874)
* add onnx_ort_runtime demo * rm in requirements * support batch eval * fixed MattingResults bug * move assignment for DetectionResult * integrated x2paddle * add model convert readme * update readme * re-lint * add processor api * Add MattingResult Free * change valid_cpu_backends order * add ppocr benchmark * mv bs from 64 to 32 * fixed quantize.md * fixed quantize bugs * Add Monitor for benchmark * update mem monitor * Set trt_max_batch_size default 1 * fixed ocr benchmark bug * support yolov5 in serving * Fixed yolov5 serving * Fixed postprocess Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
1
examples/vision/classification/paddleclas/serving/models/postprocess/1/model.py
Normal file → Executable file
1
examples/vision/classification/paddleclas/serving/models/postprocess/1/model.py
Normal file → Executable file
@@ -84,7 +84,6 @@ class TritonPythonModel:
|
|||||||
be the same as `requests`
|
be the same as `requests`
|
||||||
"""
|
"""
|
||||||
responses = []
|
responses = []
|
||||||
# print("num:", len(requests), flush=True)
|
|
||||||
for request in requests:
|
for request in requests:
|
||||||
infer_outputs = pb_utils.get_input_tensor_by_name(
|
infer_outputs = pb_utils.get_input_tensor_by_name(
|
||||||
request, self.input_names[0])
|
request, self.input_names[0])
|
||||||
|
@@ -61,31 +61,7 @@ class TritonPythonModel:
|
|||||||
dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
|
dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
|
||||||
self.output_dtype.append(dtype)
|
self.output_dtype.append(dtype)
|
||||||
print("postprocess output names:", self.output_names)
|
print("postprocess output names:", self.output_names)
|
||||||
|
self.postprocessor_ = fd.vision.detection.YOLOv5Postprocessor()
|
||||||
def yolov5_postprocess(self, infer_outputs, im_infos):
|
|
||||||
"""
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
infer_outputs : numpy.array
|
|
||||||
Contains the batch of inference results
|
|
||||||
im_infos : numpy.array(b'{}')
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
numpy.array
|
|
||||||
yolov5 postprocess result
|
|
||||||
"""
|
|
||||||
results = []
|
|
||||||
for i_batch in range(len(im_infos)):
|
|
||||||
new_infer_output = infer_outputs[i_batch:i_batch + 1]
|
|
||||||
new_im_info = im_infos[i_batch].decode('utf-8').replace("'", '"')
|
|
||||||
new_im_info = json.loads(new_im_info)
|
|
||||||
|
|
||||||
result = fd.vision.detection.YOLOv5.postprocess(
|
|
||||||
[new_infer_output, ], new_im_info)
|
|
||||||
|
|
||||||
r_str = fd.vision.utils.fd_result_to_json(result)
|
|
||||||
results.append(r_str)
|
|
||||||
return np.array(results, dtype=np.object)
|
|
||||||
|
|
||||||
def execute(self, requests):
|
def execute(self, requests):
|
||||||
"""`execute` must be implemented in every Python model. `execute`
|
"""`execute` must be implemented in every Python model. `execute`
|
||||||
@@ -107,7 +83,6 @@ class TritonPythonModel:
|
|||||||
be the same as `requests`
|
be the same as `requests`
|
||||||
"""
|
"""
|
||||||
responses = []
|
responses = []
|
||||||
# print("num:", len(requests), flush=True)
|
|
||||||
for request in requests:
|
for request in requests:
|
||||||
infer_outputs = pb_utils.get_input_tensor_by_name(
|
infer_outputs = pb_utils.get_input_tensor_by_name(
|
||||||
request, self.input_names[0])
|
request, self.input_names[0])
|
||||||
@@ -115,10 +90,15 @@ class TritonPythonModel:
|
|||||||
self.input_names[1])
|
self.input_names[1])
|
||||||
infer_outputs = infer_outputs.as_numpy()
|
infer_outputs = infer_outputs.as_numpy()
|
||||||
im_infos = im_infos.as_numpy()
|
im_infos = im_infos.as_numpy()
|
||||||
|
for i in range(im_infos.shape[0]):
|
||||||
|
im_infos[i] = json.loads(im_infos[i].decode('utf-8').replace(
|
||||||
|
"'", '"'))
|
||||||
|
|
||||||
results = self.yolov5_postprocess(infer_outputs, im_infos)
|
results = self.postprocessor_.run([infer_outputs], im_infos)
|
||||||
|
r_str = fd.vision.utils.fd_result_to_json(results)
|
||||||
|
r_np = np.array(r_str, dtype=np.object)
|
||||||
|
|
||||||
out_tensor = pb_utils.Tensor(self.output_names[0], results)
|
out_tensor = pb_utils.Tensor(self.output_names[0], r_np)
|
||||||
inference_response = pb_utils.InferenceResponse(
|
inference_response = pb_utils.InferenceResponse(
|
||||||
output_tensors=[out_tensor, ])
|
output_tensors=[out_tensor, ])
|
||||||
responses.append(inference_response)
|
responses.append(inference_response)
|
||||||
|
@@ -61,21 +61,7 @@ class TritonPythonModel:
|
|||||||
dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
|
dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
|
||||||
self.output_dtype.append(dtype)
|
self.output_dtype.append(dtype)
|
||||||
print("preprocess output names:", self.output_names)
|
print("preprocess output names:", self.output_names)
|
||||||
|
self.preprocessor_ = fd.vision.detection.YOLOv5Preprocessor()
|
||||||
def yolov5_preprocess(self, input_data):
|
|
||||||
"""
|
|
||||||
Run YOLOv5 preprocessing on the Triton input and return the results.
|
|
||||||
"""
|
|
||||||
im_infos = []
|
|
||||||
pre_outputs = []
|
|
||||||
for i_batch in input_data:
|
|
||||||
pre_output, im_info = fd.vision.detection.YOLOv5.preprocess(
|
|
||||||
i_batch)
|
|
||||||
pre_outputs.append(pre_output)
|
|
||||||
im_infos.append(im_info)
|
|
||||||
im_infos = np.array(im_infos, dtype=np.object)
|
|
||||||
pre_outputs = np.concatenate(pre_outputs, axis=0)
|
|
||||||
return pre_outputs, im_infos
|
|
||||||
|
|
||||||
def execute(self, requests):
|
def execute(self, requests):
|
||||||
"""`execute` must be implemented in every Python model. `execute`
|
"""`execute` must be implemented in every Python model. `execute`
|
||||||
@@ -97,18 +83,21 @@ class TritonPythonModel:
|
|||||||
be the same as `requests`
|
be the same as `requests`
|
||||||
"""
|
"""
|
||||||
responses = []
|
responses = []
|
||||||
# print("num:", len(requests), flush=True)
|
|
||||||
for request in requests:
|
for request in requests:
|
||||||
data = pb_utils.get_input_tensor_by_name(request,
|
data = pb_utils.get_input_tensor_by_name(request,
|
||||||
self.input_names[0])
|
self.input_names[0])
|
||||||
data = data.as_numpy()
|
data = data.as_numpy()
|
||||||
outputs = self.yolov5_preprocess(data)
|
outputs, im_infos = self.preprocessor_.run(data)
|
||||||
output_tensors = []
|
|
||||||
for idx, output in enumerate(outputs):
|
# YOLOv5 preprocess has two outputs
|
||||||
output_tensors.append(
|
dlpack_tensor = outputs[0].to_dlpack()
|
||||||
pb_utils.Tensor(self.output_names[idx], output))
|
output_tensor_0 = pb_utils.Tensor.from_dlpack(self.output_names[0],
|
||||||
|
dlpack_tensor)
|
||||||
|
output_tensor_1 = pb_utils.Tensor(
|
||||||
|
self.output_names[1], np.array(
|
||||||
|
im_infos, dtype=np.object))
|
||||||
inference_response = pb_utils.InferenceResponse(
|
inference_response = pb_utils.InferenceResponse(
|
||||||
output_tensors=output_tensors)
|
output_tensors=[output_tensor_0, output_tensor_1])
|
||||||
responses.append(inference_response)
|
responses.append(inference_response)
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user