diff --git a/cmake/paddle_inference.cmake b/cmake/paddle_inference.cmake
old mode 100644
new mode 100755
diff --git a/examples/vision/detection/yolov5seg/README.md b/examples/vision/detection/yolov5seg/README.md
new file mode 100644
index 000000000..e35838c23
--- /dev/null
+++ b/examples/vision/detection/yolov5seg/README.md
@@ -0,0 +1,27 @@
+# YOLOv5Seg Ready-to-Deploy Model
+
+- The YOLOv5Seg v7.0 deployment is based on [YOLOv5](https://github.com/ultralytics/yolov5/tree/v7.0) and its [COCO pretrained models](https://github.com/ultralytics/yolov5/releases/tag/v7.0).
+  - (1) The *.onnx files provided by the [official repository](https://github.com/ultralytics/yolov5/releases/tag/v7.0) can be deployed directly;
+  - (2) A YOLOv5Seg v7.0 model trained on your own data can be deployed after being exported to an ONNX file with `export.py` from [YOLOv5](https://github.com/ultralytics/yolov5), as sketched in the next section.
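+
+## Export the ONNX Model
+
+A minimal export sketch, assuming a YOLOv5 v7.0 checkout; the weights path is a placeholder for your own trained checkpoint, and the flags follow the upstream YOLOv5 export interface:
+
+```bash
+# run inside the YOLOv5 v7.0 repository
+python export.py --weights yolov5s-seg.pt --include onnx
+```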
+
+## Download Pretrained ONNX Models
+
+For convenience, the exported YOLOv5Seg models are listed below for direct download. (The accuracy in the table comes from the official repository.)
+
+| Model | Size | Accuracy | Notes |
+|:---------------------------------------------------------------- |:----- |:----- |:----- |
+| [YOLOv5n-seg](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5n-seg.onnx) | 7.7MB | 27.6% | This model file comes from [YOLOv5](https://github.com/ultralytics/yolov5), GPL-3.0 License |
+| [YOLOv5s-seg](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s-seg.onnx) | 30MB | 37.6% | This model file comes from [YOLOv5](https://github.com/ultralytics/yolov5), GPL-3.0 License |
+| [YOLOv5m-seg](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5m-seg.onnx) | 84MB | 45.0% | This model file comes from [YOLOv5](https://github.com/ultralytics/yolov5), GPL-3.0 License |
+| [YOLOv5l-seg](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5l-seg.onnx) | 183MB | 49.0% | This model file comes from [YOLOv5](https://github.com/ultralytics/yolov5), GPL-3.0 License |
+| [YOLOv5x-seg](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5x-seg.onnx) | 339MB | 50.7% | This model file comes from [YOLOv5](https://github.com/ultralytics/yolov5), GPL-3.0 License |
+
+
+## Detailed Deployment Documents
+
+- [Python Deployment](python)
+- [C++ Deployment](cpp)
+
+## Release Notes
+
+- This document and code are written based on [YOLOv5 v7.0](https://github.com/ultralytics/yolov5/tree/v7.0)
diff --git a/examples/vision/detection/yolov5seg/cpp/CMakeLists.txt b/examples/vision/detection/yolov5seg/cpp/CMakeLists.txt
new file mode 100644
index 000000000..6610d04d2
--- /dev/null
+++ b/examples/vision/detection/yolov5seg/cpp/CMakeLists.txt
@@ -0,0 +1,14 @@
+PROJECT(infer_demo C CXX)
+CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
+
+# Specify the fastdeploy library path after downloading and decompression
+option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
+
+include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
+
+# Add FastDeploy dependent header files
+include_directories(${FASTDEPLOY_INCS})
+
+add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
+# Add FastDeploy library dependencies
+target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})
diff --git a/examples/vision/detection/yolov5seg/cpp/README.md b/examples/vision/detection/yolov5seg/cpp/README.md
new file mode 100644
index 000000000..486d36fbf
--- /dev/null
+++ b/examples/vision/detection/yolov5seg/cpp/README.md
@@ -0,0 +1,74 @@
+# YOLOv5Seg C++ Deployment Example
+
+This directory provides `infer.cc` to quickly finish deploying YOLOv5Seg on CPU/GPU, and on GPU with TensorRT acceleration.
+
+Before deployment, confirm the following two steps:
+
+- 1. The software and hardware environment meets the requirements, see [FastDeploy Environment Requirements](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+- 2. Download the prebuilt deployment library and samples code according to your development environment, see [FastDeploy Prebuilt Libraries](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+
+Taking CPU inference on Linux as an example, run the following commands in this directory to compile and test. FastDeploy version 1.0.3 or above (x.x.x>=1.0.3) is required to support this model.
+
+```bash
+mkdir build
+cd build
+# Download the FastDeploy prebuilt library; choose a suitable version from the `FastDeploy Prebuilt Libraries` mentioned above
+wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz
+tar xvf fastdeploy-linux-x64-x.x.x.tgz
+cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x
+make -j
+
+# 1. Download the officially converted YOLOv5Seg ONNX model file and a test image
+wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s-seg.onnx
+wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
+
+# CPU inference
+./infer_demo yolov5s-seg.onnx 000000014439.jpg 0
+# GPU inference
+./infer_demo yolov5s-seg.onnx 000000014439.jpg 1
+# TensorRT inference on GPU
+./infer_demo yolov5s-seg.onnx 000000014439.jpg 2
+```
+The visualized result after running is shown below
+
+
+The above commands only apply to Linux or MacOS. For the usage of the SDK on Windows, refer to:
+- [How to use the FastDeploy C++ SDK on Windows](../../../../../docs/cn/faq/use_sdk_on_windows.md)
+
+## YOLOv5Seg C++ Interface
+
+### YOLOv5Seg Class
+
+```c++
+fastdeploy::vision::detection::YOLOv5Seg(
+        const string& model_file,
+        const string& params_file = "",
+        const RuntimeOption& runtime_option = RuntimeOption(),
+        const ModelFormat& model_format = ModelFormat::ONNX)
+```
+
+YOLOv5Seg model loading and initialization, where model_file is the exported ONNX model.
+
+**Parameters**
+
+> * **model_file**(str): Model file path
+> * **params_file**(str): Parameter file path; pass an empty string when the model format is ONNX
+> * **runtime_option**(RuntimeOption): Backend inference configuration; None by default, i.e. the default configuration is used
+> * **model_format**(ModelFormat): Model format, ONNX by default
+
+#### Predict Function
+
+```c++
+YOLOv5Seg::Predict(const cv::Mat& img, DetectionResult* result)
+```
+
+**Parameters**
+
+> > * **img**: Input image; note that it must be in HWC, BGR format
+> > * **result**: Detection result, including the detection boxes and the confidence of each box; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for the description of DetectionResult
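+
+#### Tuning the Postprocessor
+
+Thresholds can be adjusted through the postprocessor before calling Predict. A minimal sketch based on the setters declared in `postprocessor.h` of this PR; the values shown are only illustrative:
+
+```c++
+auto model = fastdeploy::vision::detection::YOLOv5Seg("yolov5s-seg.onnx");
+// defaults are 0.25 (confidence) and 0.5 (NMS IoU)
+model.GetPostprocessor().SetConfThreshold(0.3f);
+model.GetPostprocessor().SetNMSThreshold(0.45f);
+```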
+
+## Related Documents
+
+- [Model Introduction](../../)
+- [Python Deployment](../python)
+- [Vision Model Prediction Results](../../../../../docs/api/vision_results/)
+- [How to switch the model inference backend](../../../../../docs/cn/faq/how_to_change_backend.md)
diff --git a/examples/vision/detection/yolov5seg/cpp/infer.cc b/examples/vision/detection/yolov5seg/cpp/infer.cc
new file mode 100644
index 000000000..c28907028
--- /dev/null
+++ b/examples/vision/detection/yolov5seg/cpp/infer.cc
@@ -0,0 +1,105 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+
+void CpuInfer(const std::string& model_file, const std::string& image_file) {
+  auto model = fastdeploy::vision::detection::YOLOv5Seg(model_file);
+  if (!model.Initialized()) {
+    std::cerr << "Failed to initialize." << std::endl;
+    return;
+  }
+
+  auto im = cv::imread(image_file);
+
+  fastdeploy::vision::DetectionResult res;
+  if (!model.Predict(im, &res)) {
+    std::cerr << "Failed to predict." << std::endl;
+    return;
+  }
+  std::cout << res.Str() << std::endl;
+
+  auto vis_im = fastdeploy::vision::VisDetection(im, res);
+  cv::imwrite("vis_result.jpg", vis_im);
+  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
+}
+
+void GpuInfer(const std::string& model_file, const std::string& image_file) {
+  auto option = fastdeploy::RuntimeOption();
+  option.UseGpu();
+  auto model = fastdeploy::vision::detection::YOLOv5Seg(model_file, "", option);
+  if (!model.Initialized()) {
+    std::cerr << "Failed to initialize." << std::endl;
+    return;
+  }
+
+  auto im = cv::imread(image_file);
+
+  fastdeploy::vision::DetectionResult res;
+  if (!model.Predict(im, &res)) {
+    std::cerr << "Failed to predict." << std::endl;
+    return;
+  }
+  std::cout << res.Str() << std::endl;
+
+  auto vis_im = fastdeploy::vision::VisDetection(im, res);
+  cv::imwrite("vis_result.jpg", vis_im);
+  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
+}
+
+void TrtInfer(const std::string& model_file, const std::string& image_file) {
+  auto option = fastdeploy::RuntimeOption();
+  option.UseGpu();
+  option.UseTrtBackend();
+  option.SetTrtInputShape("images", {1, 3, 640, 640});
+  auto model = fastdeploy::vision::detection::YOLOv5Seg(model_file, "", option);
+  if (!model.Initialized()) {
+    std::cerr << "Failed to initialize." << std::endl;
+    return;
+  }
+
+  auto im = cv::imread(image_file);
+
+  fastdeploy::vision::DetectionResult res;
+  if (!model.Predict(im, &res)) {
+    std::cerr << "Failed to predict." << std::endl;
+    return;
+  }
+  std::cout << res.Str() << std::endl;
+
+  auto vis_im = fastdeploy::vision::VisDetection(im, res);
+  cv::imwrite("vis_result.jpg", vis_im);
+  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
+}
+
+int main(int argc, char* argv[]) {
+  if (argc < 4) {
+    std::cout << "Usage: infer_demo path/to/model path/to/image run_option, "
+                 "e.g ./infer_demo ./yolov5s-seg.onnx ./test.jpeg 0"
+              << std::endl;
+    std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
+                 "with gpu; 2: run with gpu and use tensorrt backend."
+              << std::endl;
+    return -1;
+  }
+
+  if (std::atoi(argv[3]) == 0) {
+    CpuInfer(argv[1], argv[2]);
+  } else if (std::atoi(argv[3]) == 1) {
+    GpuInfer(argv[1], argv[2]);
+  } else if (std::atoi(argv[3]) == 2) {
+    TrtInfer(argv[1], argv[2]);
+  }
+  return 0;
+}
diff --git a/examples/vision/detection/yolov5seg/python/README.md b/examples/vision/detection/yolov5seg/python/README.md
new file mode 100644
index 000000000..e09014dec
--- /dev/null
+++ b/examples/vision/detection/yolov5seg/python/README.md
@@ -0,0 +1,67 @@
+# YOLOv5Seg Python Deployment Example
+
+Before deployment, confirm the following two steps:
+
+- 1. The software and hardware environment meets the requirements, see [FastDeploy Environment Requirements](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+- 2. Install the FastDeploy Python whl package, see [FastDeploy Python Installation](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+
+This directory provides `infer.py` to quickly finish deploying YOLOv5Seg on CPU/GPU, and on GPU with TensorRT acceleration. Run the following script to finish it:
+
+```bash
+# Download the deployment example code
+git clone https://github.com/PaddlePaddle/FastDeploy.git
+cd FastDeploy/examples/vision/detection/yolov5seg/python/
+
+# Download the yolov5seg model file and a test image
+wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s-seg.onnx
+wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
+
+# CPU inference
+python infer.py --model yolov5s-seg.onnx --image 000000014439.jpg --device cpu
+# GPU inference
+python infer.py --model yolov5s-seg.onnx --image 000000014439.jpg --device gpu
+# TensorRT inference on GPU
+python infer.py --model yolov5s-seg.onnx --image 000000014439.jpg --device gpu --use_trt True
+```
+
+The visualized result after running is shown below
+
+
+## YOLOv5Seg Python Interface
+
+```python
+fastdeploy.vision.detection.YOLOv5Seg(model_file, params_file=None, runtime_option=None, model_format=ModelFormat.ONNX)
+```
+
+YOLOv5Seg model loading and initialization, where model_file is the exported ONNX model.
+
+**Parameters**
+
+> * **model_file**(str): Model file path
+> * **params_file**(str): Parameter file path; when the model is in ONNX format, this parameter does not need to be set
+> * **runtime_option**(RuntimeOption): Backend inference configuration; None by default, i.e. the default configuration is used
+> * **model_format**(ModelFormat): Model format, ONNX by default
+
+### predict Function
+
+```python
+YOLOv5Seg.predict(image_data)
+```
+
+Model prediction interface: takes an image as input and directly returns the detection results.
+
+**Parameters**
+
+> > * **image_data**(np.ndarray): Input data; note that it must be in HWC, BGR format
+
+**Return**
+
+> > Returns a `fastdeploy.vision.DetectionResult` structure; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for its description
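+
+### Tuning and Reading Masks
+
+The pre/postprocessor parameters are exposed as properties, and each instance mask is stored as a flat uint8 buffer together with its shape. A minimal sketch; the values are only illustrative, and the reshape mirrors the accompanying unit test:
+
+```python
+model.preprocessor.size = [640, 640]      # input size, default [640, 640]
+model.postprocessor.conf_threshold = 0.3  # default 0.25
+result = model.predict(im)
+
+import numpy as np
+mask0 = np.array(result.masks[0].data).reshape(result.masks[0].shape)
+```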
+
+## Other Documents
+
+- [YOLOv5Seg Model Introduction](..)
+- [YOLOv5Seg C++ Deployment](../cpp)
+- [Description of the model prediction results](../../../../../docs/api/vision_results/)
+- [How to switch the model inference backend](../../../../../docs/cn/faq/how_to_change_backend.md)
diff --git a/examples/vision/detection/yolov5seg/python/infer.py b/examples/vision/detection/yolov5seg/python/infer.py
new file mode 100644
index 000000000..34f9b7f14
--- /dev/null
+++ b/examples/vision/detection/yolov5seg/python/infer.py
@@ -0,0 +1,56 @@
+import fastdeploy as fd
+import cv2
+import os
+
+
+def parse_arguments():
+    import argparse
+    import ast
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model", default=None, help="Path of yolov5seg model.")
+    parser.add_argument(
+        "--image", default=None, help="Path of test image file.")
+    parser.add_argument(
+        "--device",
+        type=str,
+        default='cpu',
+        help="Type of inference device, support 'cpu' or 'gpu'.")
+    parser.add_argument(
+        "--use_trt",
+        type=ast.literal_eval,
+        default=False,
+        help="Whether to use tensorrt.")
+    return parser.parse_args()
+
+
+def build_option(args):
+    option = fd.RuntimeOption()
+    if args.device.lower() == "gpu":
+        option.use_gpu()
+
+    if args.use_trt:
+        option.use_trt_backend()
+        option.set_trt_input_shape("images", [1, 3, 640, 640])
+    return option
+
+
+args = parse_arguments()
+
+# Configure runtime, load model
+runtime_option = build_option(args)
+model = fd.vision.detection.YOLOv5Seg(
+    args.model, runtime_option=runtime_option)
+
+# Predict image
+if args.image is None:
+    image = fd.utils.get_detection_test_image()
+else:
+    image = args.image
+im = cv2.imread(image)
+result = model.predict(im)
+
+# Visualization
+vis_im = fd.vision.vis_detection(im, result)
+cv2.imwrite("visualized_result.jpg", vis_im)
+print("Visualized result saved in ./visualized_result.jpg")
diff --git a/fastdeploy/runtime/backends/paddle/paddle_backend.h b/fastdeploy/runtime/backends/paddle/paddle_backend.h
old mode 100644
new mode 100755
diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h
old mode 100644
new mode 100755
index 0714a9766..867de58cb
--- a/fastdeploy/vision.h
+++ b/fastdeploy/vision.h
@@ -22,6 +22,7 @@
 #include "fastdeploy/vision/detection/contrib/scaledyolov4.h"
 #include "fastdeploy/vision/detection/contrib/yolor.h"
 #include "fastdeploy/vision/detection/contrib/yolov5/yolov5.h"
+#include "fastdeploy/vision/detection/contrib/yolov5seg/yolov5seg.h"
 #include "fastdeploy/vision/detection/contrib/fastestdet/fastestdet.h"
 #include "fastdeploy/vision/detection/contrib/yolov5lite.h"
 #include "fastdeploy/vision/detection/contrib/yolov6.h"
diff --git a/fastdeploy/vision/common/result.cc b/fastdeploy/vision/common/result.cc
index 9fc01e565..446a39699 100755
--- a/fastdeploy/vision/common/result.cc
+++ b/fastdeploy/vision/common/result.cc
@@ -48,7 +48,7 @@ void Mask::Reserve(int size) { data.reserve(size); }
 void Mask::Resize(int size) { data.resize(size); }
 
 void Mask::Clear() {
-  std::vector<int32_t>().swap(data);
+  std::vector<uint8_t>().swap(data);
   std::vector<int64_t>().swap(shape);
 }
diff --git a/fastdeploy/vision/common/result.h b/fastdeploy/vision/common/result.h
index b6ff1fbf7..c68f6d4cf 100755
--- a/fastdeploy/vision/common/result.h
+++ b/fastdeploy/vision/common/result.h
@@ -67,7 +67,7 @@ struct FASTDEPLOY_DECL ClassifyResult : public BaseResult {
  */
 struct FASTDEPLOY_DECL Mask : public BaseResult {
   /// Mask data buffer
-  std::vector<int32_t> data;
+  std::vector<uint8_t> data;
   /// Shape of mask
   std::vector<int64_t> shape;  // (H,W)
   ResultType type = ResultType::MASK;
@@ -107,7 +107,7 @@ struct FASTDEPLOY_DECL DetectionResult : public BaseResult {
   /** \brief For instance segmentation model, `masks` is the predicted mask for all the detected objects
   */
   std::vector<Mask> masks;
-  //// Shows if the DetectionResult has mask
+  /// Shows if the DetectionResult has mask
   bool contain_masks = false;
 
   ResultType type = ResultType::DETECTION;
diff --git a/fastdeploy/vision/detection/contrib/yolov5seg/postprocessor.cc b/fastdeploy/vision/detection/contrib/yolov5seg/postprocessor.cc
new file mode 100755
index 000000000..50bcaba5c
--- /dev/null
+++ b/fastdeploy/vision/detection/contrib/yolov5seg/postprocessor.cc
@@ -0,0 +1,217 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "fastdeploy/vision/detection/contrib/yolov5seg/postprocessor.h" +#include "fastdeploy/vision/utils/utils.h" + +namespace fastdeploy { +namespace vision { +namespace detection { + +YOLOv5SegPostprocessor::YOLOv5SegPostprocessor() { + conf_threshold_ = 0.25; + nms_threshold_ = 0.5; + mask_threshold_ = 0.5; + multi_label_ = true; + max_wh_ = 7680.0; + mask_nums_ = 32; +} + +bool YOLOv5SegPostprocessor::Run( + const std::vector& tensors, std::vector* results, + const std::vector>>& ims_info) { + int batch = tensors[0].shape[0]; + + results->resize(batch); + + for (size_t bs = 0; bs < batch; ++bs) { + // store mask information + std::vector> mask_embeddings; + (*results)[bs].Clear(); + if (multi_label_) { + (*results)[bs].Reserve(tensors[0].shape[1] * + (tensors[0].shape[2] - mask_nums_ - 5)); + } else { + (*results)[bs].Reserve(tensors[0].shape[1]); + } + if (tensors[0].dtype != FDDataType::FP32) { + FDERROR << "Only support post process with float32 data." << std::endl; + return false; + } + const float* data = reinterpret_cast(tensors[0].Data()) + + bs * tensors[0].shape[1] * tensors[0].shape[2]; + for (size_t i = 0; i < tensors[0].shape[1]; ++i) { + int s = i * tensors[0].shape[2]; + float cls_conf = data[s + 4]; + float confidence = data[s + 4]; + std::vector mask_embedding( + data + s + tensors[0].shape[2] - mask_nums_, + data + s + tensors[0].shape[2]); + for (size_t k = 0; k < mask_embedding.size(); ++k) { + mask_embedding[k] *= cls_conf; + } + if (multi_label_) { + for (size_t j = 5; j < tensors[0].shape[2] - mask_nums_; ++j) { + confidence = data[s + 4]; + const float* class_score = data + s + j; + confidence *= (*class_score); + // filter boxes by conf_threshold + if (confidence <= conf_threshold_) { + continue; + } + int32_t label_id = std::distance(data + s + 5, class_score); + + // convert from [x, y, w, h] to [x1, y1, x2, y2] + (*results)[bs].boxes.emplace_back(std::array{ + data[s] - data[s + 2] / 2.0f + label_id * max_wh_, + data[s + 1] - data[s + 3] / 2.0f + label_id * max_wh_, + data[s + 0] + data[s + 2] / 2.0f + label_id * max_wh_, + data[s + 1] + data[s + 3] / 2.0f + label_id * max_wh_}); + (*results)[bs].label_ids.push_back(label_id); + (*results)[bs].scores.push_back(confidence); + // TODO(wangjunjie06): No zero copy + mask_embeddings.push_back(mask_embedding); + } + } else { + const float* max_class_score = std::max_element( + data + s + 5, data + s + tensors[0].shape[2] - mask_nums_); + confidence *= (*max_class_score); + // filter boxes by conf_threshold + if (confidence <= conf_threshold_) { + continue; + } + int32_t label_id = std::distance(data + s + 5, max_class_score); + // convert from [x, y, w, h] to [x1, y1, x2, y2] + (*results)[bs].boxes.emplace_back(std::array{ + data[s] - data[s + 2] / 2.0f + label_id * max_wh_, + data[s + 1] - data[s + 3] / 2.0f + label_id * max_wh_, + data[s + 0] + data[s + 2] / 2.0f + label_id * max_wh_, + data[s + 1] + data[s + 3] / 2.0f + label_id * max_wh_}); + (*results)[bs].label_ids.push_back(label_id); + (*results)[bs].scores.push_back(confidence); + mask_embeddings.push_back(mask_embedding); + } + } + + if ((*results)[bs].boxes.size() == 0) { + return true; + } + // get box index after nms + std::vector index; + utils::NMS(&((*results)[bs]), nms_threshold_, &index); + + // deal with mask + // step1: MatMul, (box_nums * 32) x (32 * 160 * 160) = box_nums * 160 * 160 + // step2: Sigmoid + // step3: Resize to original image size + // step4: Select pixels greater than threshold and crop + (*results)[bs].contain_masks = 
true; + (*results)[bs].masks.resize((*results)[bs].boxes.size()); + const float* data_mask = + reinterpret_cast(tensors[1].Data()) + + bs * tensors[1].shape[1] * tensors[1].shape[2] * tensors[1].shape[3]; + cv::Mat mask_proto = + cv::Mat(tensors[1].shape[1], tensors[1].shape[2] * tensors[1].shape[3], + CV_32FC(1), const_cast(data_mask)); + // vector to cv::Mat for MatMul + // after push_back, Mat of m*n becomes (m + 1) * n + cv::Mat mask_proposals; + for (size_t i = 0; i < index.size(); ++i) { + mask_proposals.push_back(cv::Mat(mask_embeddings[index[i]]).t()); + } + cv::Mat matmul_result = (mask_proposals * mask_proto).t(); + cv::Mat masks = matmul_result.reshape( + (*results)[bs].boxes.size(), {static_cast(tensors[1].shape[2]), + static_cast(tensors[1].shape[3])}); + // split for boxes nums + std::vector mask_channels; + cv::split(masks, mask_channels); + + // scale the boxes to the origin image shape + auto iter_out = ims_info[bs].find("output_shape"); + auto iter_ipt = ims_info[bs].find("input_shape"); + FDASSERT(iter_out != ims_info[bs].end() && iter_ipt != ims_info[bs].end(), + "Cannot find input_shape or output_shape from im_info."); + float out_h = iter_out->second[0]; + float out_w = iter_out->second[1]; + float ipt_h = iter_ipt->second[0]; + float ipt_w = iter_ipt->second[1]; + float scale = std::min(out_h / ipt_h, out_w / ipt_w); + float pad_h = (out_h - ipt_h * scale) / 2; + float pad_w = (out_w - ipt_w * scale) / 2; + // for mask + float pad_h_mask = (float)pad_h / out_h * tensors[1].shape[2]; + float pad_w_mask = (float)pad_w / out_w * tensors[1].shape[3]; + for (size_t i = 0; i < (*results)[bs].boxes.size(); ++i) { + int32_t label_id = ((*results)[bs].label_ids)[i]; + // clip box + (*results)[bs].boxes[i][0] = + (*results)[bs].boxes[i][0] - max_wh_ * label_id; + (*results)[bs].boxes[i][1] = + (*results)[bs].boxes[i][1] - max_wh_ * label_id; + (*results)[bs].boxes[i][2] = + (*results)[bs].boxes[i][2] - max_wh_ * label_id; + (*results)[bs].boxes[i][3] = + (*results)[bs].boxes[i][3] - max_wh_ * label_id; + (*results)[bs].boxes[i][0] = + std::max(((*results)[bs].boxes[i][0] - pad_w) / scale, 0.0f); + (*results)[bs].boxes[i][1] = + std::max(((*results)[bs].boxes[i][1] - pad_h) / scale, 0.0f); + (*results)[bs].boxes[i][2] = + std::max(((*results)[bs].boxes[i][2] - pad_w) / scale, 0.0f); + (*results)[bs].boxes[i][3] = + std::max(((*results)[bs].boxes[i][3] - pad_h) / scale, 0.0f); + (*results)[bs].boxes[i][0] = std::min((*results)[bs].boxes[i][0], ipt_w); + (*results)[bs].boxes[i][1] = std::min((*results)[bs].boxes[i][1], ipt_h); + (*results)[bs].boxes[i][2] = std::min((*results)[bs].boxes[i][2], ipt_w); + (*results)[bs].boxes[i][3] = std::min((*results)[bs].boxes[i][3], ipt_h); + // deal with mask + cv::Mat dest, mask; + // sigmoid + cv::exp(-mask_channels[i], dest); + dest = 1.0 / (1.0 + dest); + // crop mask for feature map + int x1 = static_cast(pad_w_mask); + int y1 = static_cast(pad_h_mask); + int x2 = static_cast(tensors[1].shape[3] - pad_w_mask); + int y2 = static_cast(tensors[1].shape[2] - pad_h_mask); + cv::Rect roi(x1, y1, x2 - x1, y2 - y1); + dest = dest(roi); + cv::resize(dest, mask, cv::Size(ipt_w, ipt_h), 0, 0, cv::INTER_LINEAR); + // crop mask for source img + int x1_src = static_cast(round((*results)[bs].boxes[i][0])); + int y1_src = static_cast(round((*results)[bs].boxes[i][1])); + int x2_src = static_cast(round((*results)[bs].boxes[i][2])); + int y2_src = static_cast(round((*results)[bs].boxes[i][3])); + cv::Rect roi_src(x1_src, y1_src, x2_src - x1_src, y2_src - 
y1_src); + mask = mask(roi_src); + mask = mask > mask_threshold_; + // save mask in DetectionResult + int keep_mask_h = y2_src - y1_src; + int keep_mask_w = x2_src - x1_src; + int keep_mask_numel = keep_mask_h * keep_mask_w; + (*results)[bs].masks[i].Resize(keep_mask_numel); + (*results)[bs].masks[i].shape = {keep_mask_h, keep_mask_w}; + uint8_t* keep_mask_ptr = + reinterpret_cast((*results)[bs].masks[i].Data()); + std::memcpy(keep_mask_ptr, reinterpret_cast(mask.ptr()), + keep_mask_numel * sizeof(uint8_t)); + } + } + return true; +} + +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/yolov5seg/postprocessor.h b/fastdeploy/vision/detection/contrib/yolov5seg/postprocessor.h new file mode 100755 index 000000000..24f078542 --- /dev/null +++ b/fastdeploy/vision/detection/contrib/yolov5seg/postprocessor.h @@ -0,0 +1,79 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" + +namespace fastdeploy { +namespace vision { + +namespace detection { +/*! @brief Postprocessor object for YOLOv5Seg serials model. 
+ */ +class FASTDEPLOY_DECL YOLOv5SegPostprocessor { + public: + /** \brief Create a postprocessor instance for YOLOv5Seg serials model + */ + YOLOv5SegPostprocessor(); + + /** \brief Process the result of runtime and fill to DetectionResult structure + * + * \param[in] tensors The inference result from runtime + * \param[in] result The output result of detection + * \param[in] ims_info The shape info list, record input_shape and output_shape + * \return true if the postprocess successed, otherwise false + */ + bool Run(const std::vector& tensors, + std::vector* results, + const std::vector>>& ims_info); + + /// Set conf_threshold, default 0.25 + void SetConfThreshold(const float& conf_threshold) { + conf_threshold_ = conf_threshold; + } + + /// Get conf_threshold, default 0.25 + float GetConfThreshold() const { return conf_threshold_; } + + /// Set nms_threshold, default 0.5 + void SetNMSThreshold(const float& nms_threshold) { + nms_threshold_ = nms_threshold; + } + + /// Get nms_threshold, default 0.5 + float GetNMSThreshold() const { return nms_threshold_; } + + /// Set multi_label, set true for eval, default true + void SetMultiLabel(bool multi_label) { + multi_label_ = multi_label; + } + + /// Get multi_label, default true + bool GetMultiLabel() const { return multi_label_; } + + protected: + float conf_threshold_; + float nms_threshold_; + bool multi_label_; + float max_wh_; + // channel nums of masks + int mask_nums_; + // mask threshold + float mask_threshold_; +}; + +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/yolov5seg/preprocessor.cc b/fastdeploy/vision/detection/contrib/yolov5seg/preprocessor.cc new file mode 100644 index 000000000..b880ed337 --- /dev/null +++ b/fastdeploy/vision/detection/contrib/yolov5seg/preprocessor.cc @@ -0,0 +1,116 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fastdeploy/vision/detection/contrib/yolov5seg/preprocessor.h" +#include "fastdeploy/function/concat.h" + +namespace fastdeploy { +namespace vision { +namespace detection { + +YOLOv5SegPreprocessor::YOLOv5SegPreprocessor() { + size_ = {640, 640}; + padding_value_ = {114.0, 114.0, 114.0}; + is_mini_pad_ = false; + is_no_pad_ = false; + is_scale_up_ = true; + stride_ = 32; + max_wh_ = 7680.0; +} + +void YOLOv5SegPreprocessor::LetterBox(FDMat* mat) { + float scale = + std::min(size_[1] * 1.0 / mat->Height(), size_[0] * 1.0 / mat->Width()); + if (!is_scale_up_) { + scale = std::min(scale, 1.0f); + } + + int resize_h = int(round(mat->Height() * scale)); + int resize_w = int(round(mat->Width() * scale)); + + int pad_w = size_[0] - resize_w; + int pad_h = size_[1] - resize_h; + if (is_mini_pad_) { + pad_h = pad_h % stride_; + pad_w = pad_w % stride_; + } else if (is_no_pad_) { + pad_h = 0; + pad_w = 0; + resize_h = size_[1]; + resize_w = size_[0]; + } + if (std::fabs(scale - 1.0f) > 1e-06) { + Resize::Run(mat, resize_w, resize_h); + } + if (pad_h > 0 || pad_w > 0) { + float half_h = pad_h * 1.0 / 2; + int top = int(round(half_h - 0.1)); + int bottom = int(round(half_h + 0.1)); + float half_w = pad_w * 1.0 / 2; + int left = int(round(half_w - 0.1)); + int right = int(round(half_w + 0.1)); + Pad::Run(mat, top, bottom, left, right, padding_value_); + } +} + +bool YOLOv5SegPreprocessor::Preprocess(FDMat* mat, FDTensor* output, + std::map>* im_info) { + // Record the shape of image and the shape of preprocessed image + (*im_info)["input_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + // yolov5seg's preprocess steps + // 1. letterbox + // 2. convert_and_permute(swap_rb=true) + LetterBox(mat); + std::vector alpha = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f}; + std::vector beta = {0.0f, 0.0f, 0.0f}; + ConvertAndPermute::Run(mat, alpha, beta, true); + + // Record output shape of preprocessed image + (*im_info)["output_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + + mat->ShareWithTensor(output); + output->ExpandDim(0); // reshape to n, h, w, c + return true; +} + +bool YOLOv5SegPreprocessor::Run(std::vector* images, std::vector* outputs, + std::vector>>* ims_info) { + if (images->size() == 0) { + FDERROR << "The size of input images should be greater than 0." << std::endl; + return false; + } + ims_info->resize(images->size()); + outputs->resize(1); + // Concat all the preprocessed data to a batch tensor + std::vector tensors(images->size()); + for (size_t i = 0; i < images->size(); ++i) { + if (!Preprocess(&(*images)[i], &tensors[i], &(*ims_info)[i])) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } + } + + if (tensors.size() == 1) { + (*outputs)[0] = std::move(tensors[0]); + } else { + function::Concat(tensors, &((*outputs)[0]), 0); + } + return true; +} + +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/yolov5seg/preprocessor.h b/fastdeploy/vision/detection/contrib/yolov5seg/preprocessor.h new file mode 100644 index 000000000..241bdda6b --- /dev/null +++ b/fastdeploy/vision/detection/contrib/yolov5seg/preprocessor.h @@ -0,0 +1,113 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" + +namespace fastdeploy { +namespace vision { + +namespace detection { +/*! @brief Preprocessor object for YOLOv5Seg serials model. + */ +class FASTDEPLOY_DECL YOLOv5SegPreprocessor { + public: + /** \brief Create a preprocessor instance for YOLOv5Seg serials model + */ + YOLOv5SegPreprocessor(); + + /** \brief Process the input image and prepare input tensors for runtime + * + * \param[in] images The input image data list, all the elements are returned by cv::imread() + * \param[in] outputs The output tensors which will feed in runtime + * \param[in] ims_info The shape info list, record input_shape and output_shape + * \return true if the preprocess successed, otherwise false + */ + bool Run(std::vector* images, std::vector* outputs, + std::vector>>* ims_info); + + /// Set target size, tuple of (width, height), default size = {640, 640} + void SetSize(const std::vector& size) { size_ = size; } + + /// Get target size, tuple of (width, height), default size = {640, 640} + std::vector GetSize() const { return size_; } + + /// Set padding value, size should be the same as channels + void SetPaddingValue(const std::vector& padding_value) { + padding_value_ = padding_value; + } + + /// Get padding value, size should be the same as channels + std::vector GetPaddingValue() const { return padding_value_; } + + /// Set is_scale_up, if is_scale_up is false, the input image only + /// can be zoom out, the maximum resize scale cannot exceed 1.0, default true + void SetScaleUp(bool is_scale_up) { + is_scale_up_ = is_scale_up; + } + + /// Get is_scale_up, default true + bool GetScaleUp() const { return is_scale_up_; } + + /// Set is_mini_pad, pad to the minimum rectange + /// which height and width is times of stride + void SetMiniPad(bool is_mini_pad) { + is_mini_pad_ = is_mini_pad; + } + + /// Get is_mini_pad, default false + bool GetMiniPad() const { return is_mini_pad_; } + + /// Set padding stride, only for mini_pad mode + void SetStride(int stride) { + stride_ = stride; + } + + /// Get padding stride, default 32 + bool GetStride() const { return stride_; } + + protected: + bool Preprocess(FDMat* mat, FDTensor* output, + std::map>* im_info); + + void LetterBox(FDMat* mat); + + // target size, tuple of (width, height), default size = {640, 640} + std::vector size_; + + // padding value, size should be the same as channels + std::vector padding_value_; + + // only pad to the minimum rectange which height and width is times of stride + bool is_mini_pad_; + + // while is_mini_pad = false and is_no_pad = true, + // will resize the image to the set size + bool is_no_pad_; + + // if is_scale_up is false, the input image only can be zoom out, + // the maximum resize scale cannot exceed 1.0 + bool is_scale_up_; + + // padding stride, for is_mini_pad + int stride_; + + // for offseting the boxes by classes when using NMS + float max_wh_; +}; + +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git 
diff --git a/fastdeploy/vision/detection/contrib/yolov5seg/yolov5seg.cc b/fastdeploy/vision/detection/contrib/yolov5seg/yolov5seg.cc
new file mode 100644
index 000000000..716c8d253
--- /dev/null
+++ b/fastdeploy/vision/detection/contrib/yolov5seg/yolov5seg.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/detection/contrib/yolov5seg/yolov5seg.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace detection {
+
+YOLOv5Seg::YOLOv5Seg(const std::string& model_file,
+                     const std::string& params_file,
+                     const RuntimeOption& custom_option,
+                     const ModelFormat& model_format) {
+  if (model_format == ModelFormat::ONNX) {
+    valid_cpu_backends = {Backend::OPENVINO, Backend::ORT};
+    valid_gpu_backends = {Backend::ORT, Backend::TRT};
+  } else {
+    valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::LITE};
+    valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
+  }
+  runtime_option = custom_option;
+  runtime_option.model_format = model_format;
+  runtime_option.model_file = model_file;
+  runtime_option.params_file = params_file;
+  initialized = Initialize();
+}
+
+bool YOLOv5Seg::Initialize() {
+  if (!InitRuntime()) {
+    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
+    return false;
+  }
+  return true;
+}
+
+bool YOLOv5Seg::Predict(const cv::Mat& im, DetectionResult* result) {
+  std::vector<DetectionResult> results;
+  if (!BatchPredict({im}, &results)) {
+    return false;
+  }
+  *result = std::move(results[0]);
+  return true;
+}
+
+bool YOLOv5Seg::BatchPredict(const std::vector<cv::Mat>& images,
+                             std::vector<DetectionResult>* results) {
+  std::vector<std::map<std::string, std::array<float, 2>>> ims_info;
+  std::vector<FDMat> fd_images = WrapMat(images);
+
+  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &ims_info)) {
+    FDERROR << "Failed to preprocess the input image." << std::endl;
+    return false;
+  }
+
+  reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
+  if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
+    FDERROR << "Failed to inference by runtime." << std::endl;
+    return false;
+  }
+
+  if (!postprocessor_.Run(reused_output_tensors_, results, ims_info)) {
+    FDERROR << "Failed to postprocess the inference results by runtime."
+            << std::endl;
+    return false;
+  }
+
+  return true;
+}
+
+}  // namespace detection
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/contrib/yolov5seg/yolov5seg.h b/fastdeploy/vision/detection/contrib/yolov5seg/yolov5seg.h
new file mode 100755
index 000000000..ca4549957
--- /dev/null
+++ b/fastdeploy/vision/detection/contrib/yolov5seg/yolov5seg.h
@@ -0,0 +1,76 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.  //NOLINT
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/detection/contrib/yolov5seg/preprocessor.h"
+#include "fastdeploy/vision/detection/contrib/yolov5seg/postprocessor.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace detection {
+/*! @brief YOLOv5Seg model object used to load a YOLOv5Seg model exported by YOLOv5.
+ */
+class FASTDEPLOY_DECL YOLOv5Seg : public FastDeployModel {
+ public:
+  /** \brief Set the path of the model file and the runtime configuration.
+   *
+   * \param[in] model_file Path of the model file, e.g ./yolov5seg.onnx
+   * \param[in] params_file Path of the parameter file, e.g ppyoloe/model.pdiparams; if the model format is ONNX, this parameter will be ignored
+   * \param[in] custom_option RuntimeOption for inference; by default CPU is used, with the backend chosen from "valid_cpu_backends"
+   * \param[in] model_format Model format of the loaded model, ONNX by default
+   */
+  YOLOv5Seg(const std::string& model_file, const std::string& params_file = "",
+            const RuntimeOption& custom_option = RuntimeOption(),
+            const ModelFormat& model_format = ModelFormat::ONNX);
+
+  std::string ModelName() const { return "yolov5seg"; }
+
+  /** \brief Predict the detection result for an input image
+   *
+   * \param[in] img The input image data, which comes from cv::imread() and is a 3-D array with layout HWC, BGR format
+   * \param[in] result The output detection result will be written to this structure
+   * \return true if the prediction succeeds, otherwise false
+   */
+  virtual bool Predict(const cv::Mat& img, DetectionResult* result);
+
+  /** \brief Predict the detection results for a batch of input images
+   *
+   * \param[in] imgs The input image list, each element comes from cv::imread()
+   * \param[in] results The output detection result list
+   * \return true if the prediction succeeds, otherwise false
+   */
+  virtual bool BatchPredict(const std::vector<cv::Mat>& imgs,
+                            std::vector<DetectionResult>* results);
+
+  /// Get the preprocessor reference of YOLOv5Seg
+  virtual YOLOv5SegPreprocessor& GetPreprocessor() {
+    return preprocessor_;
+  }
+
+  /// Get the postprocessor reference of YOLOv5Seg
+  virtual YOLOv5SegPostprocessor& GetPostprocessor() {
+    return postprocessor_;
+  }
+
+ protected:
+  bool Initialize();
+  YOLOv5SegPreprocessor preprocessor_;
+  YOLOv5SegPostprocessor postprocessor_;
+};
+
+}  // namespace detection
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/contrib/yolov5seg/yolov5seg_pybind.cc b/fastdeploy/vision/detection/contrib/yolov5seg/yolov5seg_pybind.cc
new file mode 100755
index 000000000..0306c7b02
--- /dev/null
+++ b/fastdeploy/vision/detection/contrib/yolov5seg/yolov5seg_pybind.cc
@@ -0,0 +1,90 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/pybind/main.h"
+
+namespace fastdeploy {
+void BindYOLOv5Seg(pybind11::module& m) {
+  pybind11::class_<vision::detection::YOLOv5SegPreprocessor>(
+      m, "YOLOv5SegPreprocessor")
+      .def(pybind11::init<>())
+      .def("run",
+           [](vision::detection::YOLOv5SegPreprocessor& self,
+              std::vector<pybind11::array>& im_list) {
+             std::vector<vision::FDMat> images;
+             for (size_t i = 0; i < im_list.size(); ++i) {
+               images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
+             }
+             std::vector<FDTensor> outputs;
+             std::vector<std::map<std::string, std::array<float, 2>>> ims_info;
+             if (!self.Run(&images, &outputs, &ims_info)) {
+               throw std::runtime_error(
+                   "Failed to preprocess the input data in "
+                   "YOLOv5SegPreprocessor.");
+             }
+             for (size_t i = 0; i < outputs.size(); ++i) {
+               outputs[i].StopSharing();
+             }
+             return make_pair(outputs, ims_info);
+           })
+      .def_property("size", &vision::detection::YOLOv5SegPreprocessor::GetSize,
+                    &vision::detection::YOLOv5SegPreprocessor::SetSize)
+      .def_property("padding_value",
+                    &vision::detection::YOLOv5SegPreprocessor::GetPaddingValue,
+                    &vision::detection::YOLOv5SegPreprocessor::SetPaddingValue)
+      .def_property("is_scale_up",
+                    &vision::detection::YOLOv5SegPreprocessor::GetScaleUp,
+                    &vision::detection::YOLOv5SegPreprocessor::SetScaleUp)
+      .def_property("is_mini_pad",
+                    &vision::detection::YOLOv5SegPreprocessor::GetMiniPad,
+                    &vision::detection::YOLOv5SegPreprocessor::SetMiniPad)
+      .def_property("stride",
+                    &vision::detection::YOLOv5SegPreprocessor::GetStride,
+                    &vision::detection::YOLOv5SegPreprocessor::SetStride);
+
+  pybind11::class_<vision::detection::YOLOv5SegPostprocessor>(
+      m, "YOLOv5SegPostprocessor")
+      .def(pybind11::init<>())
+      .def("run",
+           [](vision::detection::YOLOv5SegPostprocessor& self,
+              std::vector<FDTensor>& inputs,
+              const std::vector<std::map<std::string, std::array<float, 2>>>&
+                  ims_info) {
+             std::vector<vision::DetectionResult> results;
+             if (!self.Run(inputs, &results, ims_info)) {
+               throw std::runtime_error(
+                   "Failed to postprocess the runtime result in "
+                   "YOLOv5SegPostprocessor.");
+             }
+             return results;
+           })
+      .def("run",
+           [](vision::detection::YOLOv5SegPostprocessor& self,
+              std::vector<pybind11::array>& input_array,
+              const std::vector<std::map<std::string, std::array<float, 2>>>&
+                  ims_info) {
+             std::vector<vision::DetectionResult> results;
+             std::vector<FDTensor> inputs;
+             PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
+             if (!self.Run(inputs, &results, ims_info)) {
+               throw std::runtime_error(
+                   "Failed to postprocess the runtime result in "
+                   "YOLOv5SegPostprocessor.");
+             }
+             return results;
+           })
+      .def_property(
+          "conf_threshold",
+          &vision::detection::YOLOv5SegPostprocessor::GetConfThreshold,
+          &vision::detection::YOLOv5SegPostprocessor::SetConfThreshold)
+      .def_property(
+          "nms_threshold",
+          &vision::detection::YOLOv5SegPostprocessor::GetNMSThreshold,
+          &vision::detection::YOLOv5SegPostprocessor::SetNMSThreshold)
+      .def_property("multi_label",
+                    &vision::detection::YOLOv5SegPostprocessor::GetMultiLabel,
+                    &vision::detection::YOLOv5SegPostprocessor::SetMultiLabel);
+
+  pybind11::class_<vision::detection::YOLOv5Seg, FastDeployModel>(m,
+                                                                  "YOLOv5Seg")
+      .def(pybind11::init<std::string, std::string, RuntimeOption,
+                          ModelFormat>())
+      .def("predict",
+           [](vision::detection::YOLOv5Seg& self, pybind11::array& data) {
+             auto mat = PyArrayToCvMat(data);
+             vision::DetectionResult res;
+             self.Predict(mat, &res);
+             return res;
+           })
+      .def("batch_predict",
+           [](vision::detection::YOLOv5Seg& self,
+              std::vector<pybind11::array>& data) {
+             std::vector<cv::Mat> images;
+             for (size_t i = 0; i < data.size(); ++i) {
+               images.push_back(PyArrayToCvMat(data[i]));
+             }
+             std::vector<vision::DetectionResult> results;
+             self.BatchPredict(images, &results);
+             return results;
+           })
+      .def_property_readonly("preprocessor",
+                             &vision::detection::YOLOv5Seg::GetPreprocessor)
+      .def_property_readonly("postprocessor",
+                             &vision::detection::YOLOv5Seg::GetPostprocessor);
+}
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/detection_pybind.cc b/fastdeploy/vision/detection/detection_pybind.cc
old mode 100644
new mode 100755
index 80bdff859..b46f229ae
--- a/fastdeploy/vision/detection/detection_pybind.cc
+++ b/fastdeploy/vision/detection/detection_pybind.cc
@@ -22,6 +22,7 @@ void BindYOLOR(pybind11::module& m);
 void BindYOLOv6(pybind11::module& m);
 void BindYOLOv5Lite(pybind11::module& m);
 void BindYOLOv5(pybind11::module& m);
+void BindYOLOv5Seg(pybind11::module& m);
 void BindFastestDet(pybind11::module& m);
 void BindYOLOX(pybind11::module& m);
 void BindNanoDetPlus(pybind11::module& m);
@@ -40,6 +41,7 @@ void BindDetection(pybind11::module& m) {
   BindYOLOv6(detection_module);
   BindYOLOv5Lite(detection_module);
   BindYOLOv5(detection_module);
+  BindYOLOv5Seg(detection_module);
   BindFastestDet(detection_module);
   BindYOLOX(detection_module);
   BindNanoDetPlus(detection_module);
diff --git a/fastdeploy/vision/detection/ppdet/postprocessor.cc b/fastdeploy/vision/detection/ppdet/postprocessor.cc
old mode 100644
new mode 100755
index a453c4d74..e65e5941b
--- a/fastdeploy/vision/detection/ppdet/postprocessor.cc
+++ b/fastdeploy/vision/detection/ppdet/postprocessor.cc
@@ -32,30 +32,30 @@ bool PaddleDetPostprocessor::ProcessMask(
   int64_t out_mask_h = shape[1];
   int64_t out_mask_w = shape[2];
   int64_t out_mask_numel = shape[1] * shape[2];
-  const int32_t* data = reinterpret_cast<const int32_t*>(tensor.CpuData());
+  const uint8_t* data = reinterpret_cast<const uint8_t*>(tensor.CpuData());
   int index = 0;
 
   for (int i = 0; i < results->size(); ++i) {
     (*results)[i].contain_masks = true;
     (*results)[i].masks.resize((*results)[i].boxes.size());
     for (int j = 0; j < (*results)[i].boxes.size(); ++j) {
-      int x1 = static_cast<int>((*results)[i].boxes[j][0]);
-      int y1 = static_cast<int>((*results)[i].boxes[j][1]);
-      int x2 = static_cast<int>((*results)[i].boxes[j][2]);
-      int y2 = static_cast<int>((*results)[i].boxes[j][3]);
+      int x1 = static_cast<int>(round((*results)[i].boxes[j][0]));
+      int y1 = static_cast<int>(round((*results)[i].boxes[j][1]));
+      int x2 = static_cast<int>(round((*results)[i].boxes[j][2]));
+      int y2 = static_cast<int>(round((*results)[i].boxes[j][3]));
       int keep_mask_h = y2 - y1;
       int keep_mask_w = x2 - x1;
       int keep_mask_numel = keep_mask_h * keep_mask_w;
       (*results)[i].masks[j].Resize(keep_mask_numel);
       (*results)[i].masks[j].shape = {keep_mask_h, keep_mask_w};
-      const int32_t* current_ptr = data + index * out_mask_numel;
+      const uint8_t* current_ptr = data + index * out_mask_numel;
 
-      int32_t* keep_mask_ptr =
-          reinterpret_cast<int32_t*>((*results)[i].masks[j].Data());
+      uint8_t* keep_mask_ptr =
+          reinterpret_cast<uint8_t*>((*results)[i].masks[j].Data());
       for (int row = y1; row < y2; ++row) {
-        size_t keep_nbytes_in_col = keep_mask_w * sizeof(int32_t);
-        const int32_t* out_row_start_ptr = current_ptr + row * out_mask_w + x1;
-        int32_t* keep_row_start_ptr = keep_mask_ptr + (row - y1) * keep_mask_w;
+        size_t keep_nbytes_in_col = keep_mask_w * sizeof(uint8_t);
+        const uint8_t* out_row_start_ptr = current_ptr + row * out_mask_w + x1;
+        uint8_t* keep_row_start_ptr = keep_mask_ptr + (row - y1) * keep_mask_w;
         std::memcpy(keep_row_start_ptr, out_row_start_ptr, keep_nbytes_in_col);
       }
       index += 1;
diff --git a/fastdeploy/vision/utils/nms.cc b/fastdeploy/vision/utils/nms.cc
index 900acf84d..e206ff8a7 100644
--- a/fastdeploy/vision/utils/nms.cc
+++ b/fastdeploy/vision/utils/nms.cc
@@ -21,7 +21,19 @@ namespace utils {
 // The implementation refers to
 // https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/cpp/src/utils.cc
-void NMS(DetectionResult* result, float iou_threshold) {
+void NMS(DetectionResult* result, float iou_threshold,
+         std::vector<int>* index) {
+  // get sorted score indices
+  std::vector<int> sorted_indices;
+  if (index != nullptr) {
+    std::map<float, int, std::greater<float>> score_map;
+    for (size_t i = 0; i < result->scores.size(); ++i) {
+      score_map.insert(std::pair<float, int>(result->scores[i], i));
+    }
+    for (auto iter : score_map) {
+      sorted_indices.push_back(iter.second);
+    }
+  }
   utils::SortDetectionResult(result);
 
   std::vector<float> area_of_boxes(result->boxes.size());
@@ -63,6 +75,9 @@ void NMS(DetectionResult* result, float iou_threshold) {
       result->boxes.emplace_back(backup.boxes[i]);
       result->scores.push_back(backup.scores[i]);
       result->label_ids.push_back(backup.label_ids[i]);
+      if (index != nullptr) {
+        index->push_back(sorted_indices[i]);
+      }
     }
   }
diff --git a/fastdeploy/vision/utils/utils.h b/fastdeploy/vision/utils/utils.h
index 1590922d8..c36d8d036 100644
--- a/fastdeploy/vision/utils/utils.h
+++ b/fastdeploy/vision/utils/utils.h
@@ -59,7 +59,8 @@ std::vector<int32_t> TopKIndices(const T* array, int array_size, int topk) {
   return res;
 }
 
-void NMS(DetectionResult* output, float iou_threshold = 0.5);
+void NMS(DetectionResult* output, float iou_threshold = 0.5,
+         std::vector<int>* index = nullptr);
 
 void NMS(FaceDetectionResult* result, float iou_threshold = 0.5);
diff --git a/fastdeploy/vision/vision_pybind.cc b/fastdeploy/vision/vision_pybind.cc
old mode 100644
new mode 100755
index 0bd2f0067..22f7581be
--- a/fastdeploy/vision/vision_pybind.cc
+++ b/fastdeploy/vision/vision_pybind.cc
@@ -46,7 +46,7 @@ void BindVision(pybind11::module& m) {
                  "vision::Mask pickle with invalid state!");
           vision::Mask m;
-          m.data = t[0].cast<std::vector<int32_t>>();
+          m.data = t[0].cast<std::vector<uint8_t>>();
           m.shape = t[1].cast<std::vector<int64_t>>();
           return m;
diff --git a/fastdeploy/vision/visualize/detection.cc b/fastdeploy/vision/visualize/detection.cc
index e8180cafe..d03c9da43 100644
--- a/fastdeploy/vision/visualize/detection.cc
+++ b/fastdeploy/vision/visualize/detection.cc
@@ -39,10 +39,10 @@ cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
     if (result.scores[i] < score_threshold) {
       continue;
     }
-    int x1 = static_cast<int>(result.boxes[i][0]);
-    int y1 = static_cast<int>(result.boxes[i][1]);
-    int x2 = static_cast<int>(result.boxes[i][2]);
-    int y2 = static_cast<int>(result.boxes[i][3]);
+    int x1 = static_cast<int>(round(result.boxes[i][0]));
+    int y1 = static_cast<int>(round(result.boxes[i][1]));
+    int x2 = static_cast<int>(round(result.boxes[i][2]));
+    int y2 = static_cast<int>(round(result.boxes[i][3]));
     int box_h = y2 - y1;
     int box_w = x2 - x1;
     int c0 = color_map[3 * result.label_ids[i] + 0];
@@ -54,7 +54,7 @@ cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
     if (score.size() > 4) {
       score = score.substr(0, 4);
     }
-    std::string text = id + "," + score;
+    std::string text = id + ", " + score;
     int font = cv::FONT_HERSHEY_SIMPLEX;
     cv::Size text_size = cv::getTextSize(text, font, font_size, 1, nullptr);
     cv::Point origin;
@@ -68,10 +68,10 @@ cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
       int mask_h = static_cast<int>(result.masks[i].shape[0]);
       int mask_w = static_cast<int>(result.masks[i].shape[1]);
       // non-const pointer for cv::Mat constructor
-      int32_t* mask_raw_data = const_cast<int32_t*>(
-          static_cast<const int32_t*>(result.masks[i].Data()));
+      uint8_t* mask_raw_data = const_cast<uint8_t*>(
+          static_cast<const uint8_t*>(result.masks[i].Data()));
       // only reference to mask data (zero copy)
-      cv::Mat mask(mask_h, mask_w, CV_32SC1, mask_raw_data);
+      cv::Mat mask(mask_h, mask_w, CV_8UC1, mask_raw_data);
       if ((mask_h != box_h) || (mask_w != box_w)) {
         cv::resize(mask, mask, cv::Size(box_w, box_h));
       }
@@ -79,7 +79,7 @@ cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
       int mc0 = 255 - c0 >= 127 ? 255 - c0 : 127;
       int mc1 = 255 - c1 >= 127 ? 255 - c1 : 127;
      int mc2 = 255 - c2 >= 127 ? 255 - c2 : 127;
-      int32_t* mask_data = reinterpret_cast<int32_t*>(mask.data);
+      uint8_t* mask_data = reinterpret_cast<uint8_t*>(mask.data);
       // inplace blending (zero copy)
       uchar* vis_im_data = static_cast<uchar*>(vis_im.data);
       for (size_t i = y1; i < y2; ++i) {
diff --git a/python/fastdeploy/vision/detection/__init__.py b/python/fastdeploy/vision/detection/__init__.py
index 70d00bcdb..cfa19bfb7 100755
--- a/python/fastdeploy/vision/detection/__init__.py
+++ b/python/fastdeploy/vision/detection/__init__.py
@@ -19,6 +19,7 @@ from .contrib.scaled_yolov4 import ScaledYOLOv4
 from .contrib.nanodet_plus import NanoDetPlus
 from .contrib.yolox import YOLOX
 from .contrib.yolov5 import *
+from .contrib.yolov5seg import *
 from .contrib.fastestdet import *
 from .contrib.yolov5lite import YOLOv5Lite
 from .contrib.yolov6 import YOLOv6
diff --git a/python/fastdeploy/vision/detection/contrib/yolov5seg.py b/python/fastdeploy/vision/detection/contrib/yolov5seg.py
new file mode 100644
index 000000000..a7c35bf68
--- /dev/null
+++ b/python/fastdeploy/vision/detection/contrib/yolov5seg.py
@@ -0,0 +1,219 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import logging
+from .... import FastDeployModel, ModelFormat
+from .... import c_lib_wrap as C
+
+
+class YOLOv5SegPreprocessor:
+    def __init__(self):
+        """Create a preprocessor for YOLOv5Seg
+        """
+        self._preprocessor = C.vision.detection.YOLOv5SegPreprocessor()
+
+    def run(self, input_ims):
+        """Preprocess input images for YOLOv5Seg
+
+        :param: input_ims: (list of numpy.ndarray)The input images
+        :return: list of FDTensor
+        """
+        return self._preprocessor.run(input_ims)
+
+    @property
+    def size(self):
+        """
+        Argument for image preprocessing step, the preprocess image size, tuple of (width, height), default size = [640, 640]
+        """
+        return self._preprocessor.size
+
+    @property
+    def padding_value(self):
+        """
+        Padding value for preprocessing, default [114.0, 114.0, 114.0]
+        """
+        # padding value, size should be the same as channels
+        return self._preprocessor.padding_value
+
+    @property
+    def is_scale_up(self):
+        """
+        is_scale_up for preprocessing; the input image can only be zoomed out, the maximum resize scale cannot exceed 1.0, default True
+        """
+        return self._preprocessor.is_scale_up
+
+    @property
+    def is_mini_pad(self):
+        """
+        is_mini_pad for preprocessing; pad to the minimum rectangle whose height and width are multiples of stride, default False
+        """
+        return self._preprocessor.is_mini_pad
+
+    @property
+    def stride(self):
+        """
+        stride for preprocessing, only for mini_pad mode, default 32
+        """
+        return self._preprocessor.stride
+
+    @size.setter
+    def size(self, wh):
+        assert isinstance(wh, (list, tuple)),\
+            "The value to set `size` must be type of tuple or list."
+        assert len(wh) == 2,\
+            "The value to set `size` must contain 2 elements meaning [width, height], but now it contains {} elements.".format(
+            len(wh))
+        self._preprocessor.size = wh
+
+    @padding_value.setter
+    def padding_value(self, value):
+        assert isinstance(
+            value,
+            list), "The value to set `padding_value` must be type of list."
+        self._preprocessor.padding_value = value
+
+    @is_scale_up.setter
+    def is_scale_up(self, value):
+        assert isinstance(
+            value,
+            bool), "The value to set `is_scale_up` must be type of bool."
+        self._preprocessor.is_scale_up = value
+
+    @is_mini_pad.setter
+    def is_mini_pad(self, value):
+        assert isinstance(
+            value,
+            bool), "The value to set `is_mini_pad` must be type of bool."
+        self._preprocessor.is_mini_pad = value
+
+    @stride.setter
+    def stride(self, value):
+        assert isinstance(
+            value, int), "The value to set `stride` must be type of int."
+        self._preprocessor.stride = value
+
+
+class YOLOv5SegPostprocessor:
+    def __init__(self):
+        """Create a postprocessor for YOLOv5Seg
+        """
+        self._postprocessor = C.vision.detection.YOLOv5SegPostprocessor()
+
+    def run(self, runtime_results, ims_info):
+        """Postprocess the runtime results for YOLOv5Seg
+
+        :param: runtime_results: (list of FDTensor)The output FDTensor results from runtime
+        :param: ims_info: (list of dict)Record input_shape and output_shape
+        :return: list of DetectionResult(If the runtime_results is predicted by batched samples, the length of this list equals the batch size)
+        """
+        return self._postprocessor.run(runtime_results, ims_info)
+
+    @property
+    def conf_threshold(self):
+        """
+        confidence threshold for postprocessing, default is 0.25
+        """
+        return self._postprocessor.conf_threshold
+
+    @property
+    def nms_threshold(self):
+        """
+        nms threshold for postprocessing, default is 0.5
+        """
+        return self._postprocessor.nms_threshold
+
+    @property
+    def multi_label(self):
+        """
+        multi_label for postprocessing; set True for eval, default is True
+        """
+        return self._postprocessor.multi_label
+
+    @conf_threshold.setter
+    def conf_threshold(self, conf_threshold):
+        assert isinstance(conf_threshold, float),\
+            "The value to set `conf_threshold` must be type of float."
+        self._postprocessor.conf_threshold = conf_threshold
+
+    @nms_threshold.setter
+    def nms_threshold(self, nms_threshold):
+        assert isinstance(nms_threshold, float),\
+            "The value to set `nms_threshold` must be type of float."
+        self._postprocessor.nms_threshold = nms_threshold
+
+    @multi_label.setter
+    def multi_label(self, value):
+        assert isinstance(
+            value,
+            bool), "The value to set `multi_label` must be type of bool."
+        self._postprocessor.multi_label = value
+
+
+class YOLOv5Seg(FastDeployModel):
+    def __init__(self,
+                 model_file,
+                 params_file="",
+                 runtime_option=None,
+                 model_format=ModelFormat.ONNX):
+        """Load a YOLOv5Seg model exported by YOLOv5.
+
+        :param model_file: (str)Path of model file, e.g ./yolov5s-seg.onnx
+        :param params_file: (str)Path of parameters file, e.g yolox/model.pdiparams; if the model_format is ModelFormat.ONNX, this param will be ignored and can be set as an empty string
+        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for the inference of this model; if it's None, the default backend on CPU will be used
+        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
+        """
+        super(YOLOv5Seg, self).__init__(runtime_option)
+
+        self._model = C.vision.detection.YOLOv5Seg(
+            model_file, params_file, self._runtime_option, model_format)
+        assert self.initialized, "YOLOv5Seg initialize failed."
+    def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
+        """Detect an input image
+
+        :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
+        :param conf_threshold: confidence threshold for postprocessing, default 0.25
+        :param nms_iou_threshold: IoU threshold for NMS, default 0.5
+        :return: DetectionResult
+        """
+
+        self.postprocessor.conf_threshold = conf_threshold
+        self.postprocessor.nms_threshold = nms_iou_threshold
+        return self._model.predict(input_image)
+
+    def batch_predict(self, images):
+        """Detect a batch of input images
+
+        :param images: (list of numpy.ndarray)The input image list, each element is a 3-D array with layout HWC, BGR format
+        :return: list of DetectionResult
+        """
+
+        return self._model.batch_predict(images)
+
+    @property
+    def preprocessor(self):
+        """Get the YOLOv5SegPreprocessor object of the loaded model
+
+        :return: YOLOv5SegPreprocessor
+        """
+        return self._model.preprocessor
+
+    @property
+    def postprocessor(self):
+        """Get the YOLOv5SegPostprocessor object of the loaded model
+
+        :return: YOLOv5SegPostprocessor
+        """
+        return self._model.postprocessor
diff --git a/tests/models/test_mask_rcnn.py b/tests/models/test_mask_rcnn.py
index 8cd0a614e..0bc0fcc05 100755
--- a/tests/models/test_mask_rcnn.py
+++ b/tests/models/test_mask_rcnn.py
@@ -61,10 +61,6 @@ def test_detection_mask_rcnn():
     ) < 1e-04, "There's diff in label_ids."
 
 
-# result = model.predict(im1)
-# with open("mask_rcnn_baseline.pkl", "wb") as f:
-#     pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)
-
 def test_detection_mask_rcnn1():
     model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/mask_rcnn_r50_1x_coco.tgz"
     input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
@@ -79,18 +75,22 @@ def test_detection_mask_rcnn1():
     config_file = os.path.join(model_path, "infer_cfg.yml")
     preprocessor = fd.vision.detection.PaddleDetPreprocessor(config_file)
     postprocessor = fd.vision.detection.PaddleDetPostprocessor()
-
+
     option = rc.test_option
     option.set_model_path(model_file, params_file)
     option.use_paddle_infer_backend()
-    runtime = fd.Runtime(option);
+    runtime = fd.Runtime(option)
 
     # compare diff
     im1 = cv2.imread("./resources/000000014439.jpg")
     for i in range(2):
         im1 = cv2.imread("./resources/000000014439.jpg")
         input_tensors = preprocessor.run([im1])
-        output_tensors = runtime.infer({"image": input_tensors[0], "scale_factor": input_tensors[1], "im_shape": input_tensors[2]})
+        output_tensors = runtime.infer({
+            "image": input_tensors[0],
+            "scale_factor": input_tensors[1],
+            "im_shape": input_tensors[2]
+        })
         results = postprocessor.run(output_tensors)
         result = results[0]
@@ -114,6 +114,7 @@ def test_detection_mask_rcnn1():
     assert diff_label_ids[scores > score_threshold].max(
     ) < 1e-04, "There's diff in label_ids."
 
+
 if __name__ == "__main__":
     test_detection_mask_rcnn()
     test_detection_mask_rcnn1()
diff --git a/tests/models/test_yolov5seg.py b/tests/models/test_yolov5seg.py
new file mode 100644
index 000000000..8eb88411f
--- /dev/null
+++ b/tests/models/test_yolov5seg.py
@@ -0,0 +1,220 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from fastdeploy import ModelFormat
+import fastdeploy as fd
+import cv2
+import os
+import pickle
+import numpy as np
+import runtime_config as rc
+
+
+def test_detection_yolov5seg():
+    model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s-seg.onnx"
+    input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
+    input_url2 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000570688.jpg"
+    result_url1 = "https://bj.bcebos.com/paddlehub/fastdeploy/yolov5seg_result1.pkl"
+    result_url2 = "https://bj.bcebos.com/paddlehub/fastdeploy/yolov5seg_result2.pkl"
+    fd.download(model_url, "resources")
+    fd.download(input_url1, "resources")
+    fd.download(input_url2, "resources")
+    fd.download(result_url1, "resources")
+    fd.download(result_url2, "resources")
+
+    model_file = "resources/yolov5s-seg.onnx"
+    rc.test_option.use_ort_backend()
+    model = fd.vision.detection.YOLOv5Seg(
+        model_file, runtime_option=rc.test_option)
+
+    with open("resources/yolov5seg_result1.pkl", "rb") as f:
+        expect1 = pickle.load(f)
+
+    with open("resources/yolov5seg_result2.pkl", "rb") as f:
+        expect2 = pickle.load(f)
+
+    # compare diff
+    im1 = cv2.imread("./resources/000000014439.jpg")
+    im2 = cv2.imread("./resources/000000570688.jpg")
+
+    for i in range(3):
+        # test single predict
+        result1 = model.predict(im1)
+        result2 = model.predict(im2)
+
+        diff_boxes_1 = np.fabs(
+            np.array(result1.boxes) - np.array(expect1["boxes"]))
+        diff_boxes_2 = np.fabs(
+            np.array(result2.boxes) - np.array(expect2["boxes"]))
+
+        diff_label_1 = np.fabs(
+            np.array(result1.label_ids) - np.array(expect1["label_ids"]))
+        diff_label_2 = np.fabs(
+            np.array(result2.label_ids) - np.array(expect2["label_ids"]))
+
+        diff_scores_1 = np.fabs(
+            np.array(result1.scores) - np.array(expect1["scores"]))
+        diff_scores_2 = np.fabs(
+            np.array(result2.scores) - np.array(expect2["scores"]))
+
+        # for masks
+        for j in range(np.array(result1.boxes).shape[0]):
+            result_mask_1 = np.array(result1.masks[j].data).reshape(
+                result1.masks[j].shape)
+            diff_mask_1 = np.fabs(result_mask_1 -
+                                  np.array(expect1["mask_" + str(j)]))
+            nonzero_nums = np.count_nonzero(diff_mask_1)
+            nonzero_ratio = nonzero_nums / (diff_mask_1.shape[0] *
+                                            diff_mask_1.shape[1])
+            assert nonzero_ratio < 1e-02, "The ratio of differing pixels in mask1 exceeds 1%."
+
+        for k in range(np.array(result2.boxes).shape[0]):
+            result_mask_2 = np.array(result2.masks[k].data).reshape(
+                result2.masks[k].shape)
+            diff_mask_2 = np.fabs(result_mask_2 -
+                                  np.array(expect2["mask_" + str(k)]))
+            nonzero_nums = np.count_nonzero(diff_mask_2)
+            nonzero_ratio = nonzero_nums / (diff_mask_2.shape[0] *
+                                            diff_mask_2.shape[1])
+            assert nonzero_ratio < 1e-02, "The ratio of differing pixels in mask2 exceeds 1%."
+
+        assert diff_boxes_1.max(
+        ) < 1e-01, "There's difference in detection boxes 1."
+        assert diff_label_1.max(
+        ) < 1e-02, "There's difference in detection label 1."
+        assert diff_scores_1.max(
+        ) < 1e-04, "There's difference in detection score 1."
+
+        assert diff_boxes_2.max(
+        ) < 1e-01, "There's difference in detection boxes 2."
+        assert diff_label_2.max(
+        ) < 1e-02, "There's difference in detection label 2."
+        assert diff_scores_2.max(
+        ) < 1e-04, "There's difference in detection score 2."
+
+        # test batch predict
+        results = model.batch_predict([im1, im2])
+        result1 = results[0]
+        result2 = results[1]
+
+        diff_boxes_1 = np.fabs(
+            np.array(result1.boxes) - np.array(expect1["boxes"]))
+        diff_boxes_2 = np.fabs(
+            np.array(result2.boxes) - np.array(expect2["boxes"]))
+
+        diff_label_1 = np.fabs(
+            np.array(result1.label_ids) - np.array(expect1["label_ids"]))
+        diff_label_2 = np.fabs(
+            np.array(result2.label_ids) - np.array(expect2["label_ids"]))
+
+        diff_scores_1 = np.fabs(
+            np.array(result1.scores) - np.array(expect1["scores"]))
+        diff_scores_2 = np.fabs(
+            np.array(result2.scores) - np.array(expect2["scores"]))
+
+        # for masks
+        for j in range(np.array(result1.boxes).shape[0]):
+            result_mask_1 = np.array(result1.masks[j].data).reshape(
+                result1.masks[j].shape)
+            diff_mask_1 = np.fabs(result_mask_1 -
+                                  np.array(expect1["mask_" + str(j)]))
+            nonzero_nums = np.count_nonzero(diff_mask_1)
+            nonzero_ratio = nonzero_nums / (diff_mask_1.shape[0] *
+                                            diff_mask_1.shape[1])
+            assert nonzero_ratio < 1e-02, "The ratio of differing pixels in mask1 exceeds 1%."
+
+        for k in range(np.array(result2.boxes).shape[0]):
+            result_mask_2 = np.array(result2.masks[k].data).reshape(
+                result2.masks[k].shape)
+            diff_mask_2 = np.fabs(result_mask_2 -
+                                  np.array(expect2["mask_" + str(k)]))
+            nonzero_nums = np.count_nonzero(diff_mask_2)
+            nonzero_ratio = nonzero_nums / (diff_mask_2.shape[0] *
+                                            diff_mask_2.shape[1])
+            assert nonzero_ratio < 1e-02, "The ratio of differing pixels in mask2 exceeds 1%."
+
+        assert diff_boxes_1.max(
+        ) < 1e-01, "There's difference in detection boxes 1."
+        assert diff_label_1.max(
+        ) < 1e-02, "There's difference in detection label 1."
+        assert diff_scores_1.max(
+        ) < 1e-03, "There's difference in detection score 1."
+
+        assert diff_boxes_2.max(
+        ) < 1e-01, "There's difference in detection boxes 2."
+        assert diff_label_2.max(
+        ) < 1e-02, "There's difference in detection label 2."
+        assert diff_scores_2.max(
+        ) < 1e-04, "There's difference in detection score 2."
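+
+
+# The runtime-level test below exercises YOLOv5SegPreprocessor and
+# YOLOv5SegPostprocessor against a bare fd.Runtime; note that the exported
+# ONNX graph expects its input tensor to be named "images".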
+def test_detection_yolov5seg_runtime():
+    model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s-seg.onnx"
+    input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
+    result_url1 = "https://bj.bcebos.com/paddlehub/fastdeploy/yolov5seg_result1.pkl"
+    fd.download(model_url, "resources")
+    fd.download(input_url1, "resources")
+    fd.download(result_url1, "resources")
+
+    model_file = "resources/yolov5s-seg.onnx"
+
+    preprocessor = fd.vision.detection.YOLOv5SegPreprocessor()
+    postprocessor = fd.vision.detection.YOLOv5SegPostprocessor()
+
+    rc.test_option.set_model_path(model_file, model_format=ModelFormat.ONNX)
+    rc.test_option.use_ort_backend()
+    runtime = fd.Runtime(rc.test_option)
+
+    with open("resources/yolov5seg_result1.pkl", "rb") as f:
+        expect1 = pickle.load(f)
+
+    # compare diff
+    im1 = cv2.imread("./resources/000000014439.jpg")
+
+    for i in range(3):
+        # test runtime
+        input_tensors, ims_info = preprocessor.run([im1.copy()])
+        output_tensors = runtime.infer({"images": input_tensors[0]})
+        results = postprocessor.run(output_tensors, ims_info)
+        result1 = results[0]
+
+        diff_boxes_1 = np.fabs(
+            np.array(result1.boxes) - np.array(expect1["boxes"]))
+        diff_label_1 = np.fabs(
+            np.array(result1.label_ids) - np.array(expect1["label_ids"]))
+        diff_scores_1 = np.fabs(
+            np.array(result1.scores) - np.array(expect1["scores"]))
+
+        # for masks
+        for j in range(np.array(result1.boxes).shape[0]):
+            result_mask_1 = np.array(result1.masks[j].data).reshape(
+                result1.masks[j].shape)
+            diff_mask_1 = np.fabs(result_mask_1 -
+                                  np.array(expect1["mask_" + str(j)]))
+            nonzero_nums = np.count_nonzero(diff_mask_1)
+            nonzero_ratio = nonzero_nums / (diff_mask_1.shape[0] *
+                                            diff_mask_1.shape[1])
+            assert nonzero_ratio < 1e-02, "The ratio of differing pixels in mask1 exceeds 1%."
+
+        assert diff_boxes_1.max(
+        ) < 1e-01, "There's difference in detection boxes 1."
+        assert diff_label_1.max(
+        ) < 1e-02, "There's difference in detection label 1."
+        assert diff_scores_1.max(
+        ) < 1e-04, "There's difference in detection score 1."
+
+
+if __name__ == "__main__":
+    test_detection_yolov5seg()
+    test_detection_yolov5seg_runtime()
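For reference, a minimal end-to-end sketch of the Python API added by this patch. It assumes the `yolov5s-seg.onnx` model and the test image from the README have already been downloaded to the working directory, and that `fd.vision.vis_detection` is available as in FastDeploy's other detection examples:

```python
import cv2
import fastdeploy as fd

# Load the ONNX model exported from YOLOv5 v7.0 (local file path is assumed).
model = fd.vision.detection.YOLOv5Seg("yolov5s-seg.onnx")

# predict() writes both thresholds into the shared postprocessor before running.
im = cv2.imread("000000014439.jpg")
result = model.predict(im, conf_threshold=0.25, nms_iou_threshold=0.5)

# DetectionResult carries boxes, scores, label_ids and per-instance masks.
print(result)

# Visualize and save; vis_detection is the helper used by other FastDeploy demos.
vis_im = fd.vision.vis_detection(im, result)
cv2.imwrite("visualized_result.jpg", vis_im)
```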