diff --git a/examples/vision/detection/fastestdet/cpp/CMakeLists.txt b/examples/vision/detection/fastestdet/cpp/CMakeLists.txt new file mode 100644 index 000000000..9ba668762 --- /dev/null +++ b/examples/vision/detection/fastestdet/cpp/CMakeLists.txt @@ -0,0 +1,14 @@ +PROJECT(infer_demo C CXX) +CMAKE_MINIMUM_REQUIRED (VERSION 3.10) + +# Specifies the path to the fastdeploy library after you have downloaded it +option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.") + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) + +# Include the FastDeploy dependency header file +include_directories(${FASTDEPLOY_INCS}) + +add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc) +# Add the FastDeploy library dependency +target_link_libraries(infer_demo ${FASTDEPLOY_LIBS}) diff --git a/examples/vision/detection/fastestdet/cpp/README.md b/examples/vision/detection/fastestdet/cpp/README.md new file mode 100644 index 000000000..bf2d01394 --- /dev/null +++ b/examples/vision/detection/fastestdet/cpp/README.md @@ -0,0 +1,87 @@ +# FastestDet C++部署示例 + +本目录下提供`infer.cc`快速完成FastestDet在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) +- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) + +以Linux上CPU推理为例,在本目录执行如下命令即可完成编译测试 + +```bash +mkdir build +cd build +wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-1.0.3.tgz +tar xvf fastdeploy-linux-x64-1.0.3.tgz +cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-1.0.3 +make -j + +#下载官方转换好的FastestDet模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/FastestDet.onnx +wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg + + +# CPU推理 +./infer_demo FastestDet.onnx 000000014439.jpg 0 +# GPU推理 +./infer_demo FastestDet.onnx 000000014439.jpg 1 +# GPU上TensorRT推理 +./infer_demo FastestDet.onnx 000000014439.jpg 2 +``` + +运行完成可视化结果如下图所示 + + + +以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考: +- [如何在Windows中使用FastDeploy C++ SDK](../../../../../docs/cn/faq/use_sdk_on_windows.md) + +## FastestDet C++接口 + +### FastestDet类 + +```c++ +fastdeploy::vision::detection::FastestDet( + const string& model_file, + const string& params_file = "", + const RuntimeOption& runtime_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::ONNX) +``` + +FastestDet模型加载和初始化,其中model_file为导出的ONNX模型格式。 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径,当模型格式为ONNX时,此参数传入空字符串即可 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(ModelFormat): 模型格式,默认为ONNX格式 + +#### Predict函数 + +> ```c++ +> FastestDet::Predict(cv::Mat* im, DetectionResult* result, +> float conf_threshold = 0.65, +> float nms_iou_threshold = 0.45) +> ``` +> +> 模型预测接口,输入图像直接输出检测结果。 +> +> **参数** +> +> > * **im**: 输入图像,注意需为HWC,BGR格式 +> > * **result**: 检测结果,包括检测框,各个框的置信度, DetectionResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/) +> > * **conf_threshold**: 检测框置信度过滤阈值 +> > * **nms_iou_threshold**: NMS处理过程中iou阈值 + +### 类成员变量 +#### 预处理参数 +用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果 + +> > * **size**(vector<int>): 通过此参数修改预处理过程中resize的大小,包含两个整型元素,表示[width, height], 默认值为[352, 352] + +- [模型介绍](../../) +- [Python部署](../python) +- [视觉模型预测结果](../../../../../docs/api/vision_results/) +- [如何切换模型推理后端引擎](../../../../../docs/cn/faq/how_to_change_backend.md) diff --git a/examples/vision/detection/fastestdet/cpp/infer.cc b/examples/vision/detection/fastestdet/cpp/infer.cc new file mode 100644 index 000000000..71dd862a2 --- /dev/null +++ b/examples/vision/detection/fastestdet/cpp/infer.cc @@ -0,0 +1,105 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +void CpuInfer(const std::string& model_file, const std::string& image_file) { + auto model = fastdeploy::vision::detection::FastestDet(model_file); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + std::cout << res.Str() << std::endl; + + auto vis_im = fastdeploy::vision::VisDetection(im, res); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void GpuInfer(const std::string& model_file, const std::string& image_file) { + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + auto model = fastdeploy::vision::detection::FastestDet(model_file, "", option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + std::cout << res.Str() << std::endl; + + auto vis_im = fastdeploy::vision::VisDetection(im, res); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void TrtInfer(const std::string& model_file, const std::string& image_file) { + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + option.UseTrtBackend(); + option.SetTrtInputShape("images", {1, 3, 352, 352}); + auto model = fastdeploy::vision::detection::FastestDet(model_file, "", option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + std::cout << res.Str() << std::endl; + + auto vis_im = fastdeploy::vision::VisDetection(im, res); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::cout << "Usage: infer_demo path/to/model path/to/image run_option, " + "e.g ./infer_model ./FastestDet.onnx ./test.jpeg 0" + << std::endl; + std::cout << "The data type of run_option is int, 0: run with cpu; 1: run " + "with gpu; 2: run with gpu and use tensorrt backend." + << std::endl; + return -1; + } + + if (std::atoi(argv[3]) == 0) { + CpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 1) { + GpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 2) { + TrtInfer(argv[1], argv[2]); + } + return 0; +} diff --git a/examples/vision/detection/fastestdet/python/README.md b/examples/vision/detection/fastestdet/python/README.md new file mode 100644 index 000000000..000bf05cc --- /dev/null +++ b/examples/vision/detection/fastestdet/python/README.md @@ -0,0 +1,74 @@ +# FastestDet Python部署示例 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) +- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) + +本目录下提供`infer.py`快速完成FastestDet在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成 + +```bash +#下载部署示例代码 +git clone https://github.com/PaddlePaddle/FastDeploy.git +cd examples/vision/detection/fastestdet/python/ + +#下载fastestdet模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/FastestDet.onnx +wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg + +# CPU推理 +python infer.py --model FastestDet.onnx --image 000000014439.jpg --device cpu +# GPU推理 +python infer.py --model FastestDet.onnx --image 000000014439.jpg --device gpu +# GPU上使用TensorRT推理 +python infer.py --model FastestDet.onnx --image 000000014439.jpg --device gpu --use_trt True +``` + +运行完成可视化结果如下图所示 + + + +## FastestDet Python接口 + +```python +fastdeploy.vision.detection.FastestDet(model_file, params_file=None, runtime_option=None, model_format=ModelFormat.ONNX) +``` + +FastestDet模型加载和初始化,其中model_file为导出的ONNX模型格式 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径,当模型格式为ONNX格式时,此参数无需设定 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(ModelFormat): 模型格式,默认为ONNX + +### predict函数 + +> ```python +> FastestDet.predict(image_data) +> ``` +> +> 模型预测接口,输入图像直接输出检测结果。 +> +> **参数** +> +> > * **image_data**(np.ndarray): 输入数据,注意需为HWC,BGR格式 + +> **返回** +> +> > 返回`fastdeploy.vision.DetectionResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/) + +### 类成员属性 +#### 预处理参数 +用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果 + +> > * **size**(list[int]): 通过此参数修改预处理过程中resize的大小,包含两个整型元素,表示[width, height], 默认值为[352, 352] + + +## 其它文档 + +- [FastestDet 模型介绍](..) +- [FastestDet C++部署](../cpp) +- [模型预测结果说明](../../../../../docs/api/vision_results/) +- [如何切换模型推理后端引擎](../../../../../docs/cn/faq/how_to_change_backend.md) diff --git a/examples/vision/detection/fastestdet/python/infer.py b/examples/vision/detection/fastestdet/python/infer.py new file mode 100644 index 000000000..ad734b4d7 --- /dev/null +++ b/examples/vision/detection/fastestdet/python/infer.py @@ -0,0 +1,51 @@ +import fastdeploy as fd +import cv2 + + +def parse_arguments(): + import argparse + import ast + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", required=True, help="Path of FastestDet onnx model.") + parser.add_argument( + "--image", required=True, help="Path of test image file.") + parser.add_argument( + "--device", + type=str, + default='cpu', + help="Type of inference device, support 'cpu' or 'gpu'.") + parser.add_argument( + "--use_trt", + type=ast.literal_eval, + default=False, + help="Wether to use tensorrt.") + return parser.parse_args() + + +def build_option(args): + option = fd.RuntimeOption() + + if args.device.lower() == "gpu": + option.use_gpu() + + if args.use_trt: + option.use_trt_backend() + option.set_trt_input_shape("images", [1, 3, 352, 352]) + return option + + +args = parse_arguments() + +# Configure runtime and load model +runtime_option = build_option(args) +model = fd.vision.detection.FastestDet(args.model, runtime_option=runtime_option) + +# Predict picture detection results +im = cv2.imread(args.image) +result = model.predict(im) + +# Visualization of prediction results +vis_im = fd.vision.vis_detection(im, result) +cv2.imwrite("visualized_result.jpg", vis_im) +print("Visualized result save in ./visualized_result.jpg") diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h index f5e4d0624..ef2fc90a6 100644 --- a/fastdeploy/vision.h +++ b/fastdeploy/vision.h @@ -22,6 +22,7 @@ #include "fastdeploy/vision/detection/contrib/scaledyolov4.h" #include "fastdeploy/vision/detection/contrib/yolor.h" #include "fastdeploy/vision/detection/contrib/yolov5/yolov5.h" +#include "fastdeploy/vision/detection/contrib/fastestdet/fastestdet.h" #include "fastdeploy/vision/detection/contrib/yolov5lite.h" #include "fastdeploy/vision/detection/contrib/yolov6.h" #include "fastdeploy/vision/detection/contrib/yolov7/yolov7.h" diff --git a/fastdeploy/vision/detection/contrib/fastestdet/fastestdet.cc b/fastdeploy/vision/detection/contrib/fastestdet/fastestdet.cc new file mode 100644 index 000000000..2bef9f38b --- /dev/null +++ b/fastdeploy/vision/detection/contrib/fastestdet/fastestdet.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/detection/contrib/fastestdet/fastestdet.h" + +namespace fastdeploy { +namespace vision { +namespace detection { + +FastestDet::FastestDet(const std::string& model_file, const std::string& params_file, + const RuntimeOption& custom_option, + const ModelFormat& model_format) { + if (model_format == ModelFormat::ONNX) { + valid_cpu_backends = {Backend::OPENVINO, Backend::ORT}; + valid_gpu_backends = {Backend::ORT, Backend::TRT}; + } else { + valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::LITE}; + valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT}; + } + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + initialized = Initialize(); +} + +bool FastestDet::Initialize() { + if (!InitRuntime()) { + FDERROR << "Failed to initialize fastdeploy backend." << std::endl; + return false; + } + return true; +} + +bool FastestDet::Predict(const cv::Mat& im, DetectionResult* result) { + std::vector results; + if (!BatchPredict({im}, &results)) { + return false; + } + *result = std::move(results[0]); + return true; +} + +bool FastestDet::BatchPredict(const std::vector& images, std::vector* results) { + std::vector>> ims_info; + std::vector fd_images = WrapMat(images); + + if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &ims_info)) { + FDERROR << "Failed to preprocess the input image." << std::endl; + return false; + } + + reused_input_tensors_[0].name = InputInfoOfRuntime(0).name; + if (!Infer(reused_input_tensors_, &reused_output_tensors_)) { + FDERROR << "Failed to inference by runtime." << std::endl; + return false; + } + + if (!postprocessor_.Run(reused_output_tensors_, results, ims_info)) { + FDERROR << "Failed to postprocess the inference results by runtime." << std::endl; + return false; + } + return true; +} + +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/fastestdet/fastestdet.h b/fastdeploy/vision/detection/contrib/fastestdet/fastestdet.h new file mode 100644 index 000000000..9bd6e07df --- /dev/null +++ b/fastdeploy/vision/detection/contrib/fastestdet/fastestdet.h @@ -0,0 +1,76 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. //NOLINT +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fastdeploy/fastdeploy_model.h" +#include "fastdeploy/vision/detection/contrib/fastestdet/preprocessor.h" +#include "fastdeploy/vision/detection/contrib/fastestdet/postprocessor.h" + +namespace fastdeploy { +namespace vision { +namespace detection { +/*! @brief FastestDet model object used when to load a FastestDet model exported by FastestDet. + */ +class FASTDEPLOY_DECL FastestDet : public FastDeployModel { + public: + /** \brief Set path of model file and the configuration of runtime. + * + * \param[in] model_file Path of model file, e.g ./fastestdet.onnx + * \param[in] params_file Path of parameter file, e.g ppyoloe/model.pdiparams, if the model format is ONNX, this parameter will be ignored + * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in "valid_cpu_backends" + * \param[in] model_format Model format of the loaded model, default is ONNX format + */ + FastestDet(const std::string& model_file, const std::string& params_file = "", + const RuntimeOption& custom_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::ONNX); + + std::string ModelName() const { return "fastestdet"; } + + /** \brief Predict the detection result for an input image + * + * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format + * \param[in] result The output detection result will be writen to this structure + * \return true if the prediction successed, otherwise false + */ + virtual bool Predict(const cv::Mat& img, DetectionResult* result); + + /** \brief Predict the detection results for a batch of input images + * + * \param[in] imgs, The input image list, each element comes from cv::imread() + * \param[in] results The output detection result list + * \return true if the prediction successed, otherwise false + */ + virtual bool BatchPredict(const std::vector& imgs, + std::vector* results); + + /// Get preprocessor reference of FastestDet + virtual FastestDetPreprocessor& GetPreprocessor() { + return preprocessor_; + } + + /// Get postprocessor reference of FastestDet + virtual FastestDetPostprocessor& GetPostprocessor() { + return postprocessor_; + } + + protected: + bool Initialize(); + FastestDetPreprocessor preprocessor_; + FastestDetPostprocessor postprocessor_; +}; + +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/fastestdet/fastestdet_pybind.cc b/fastdeploy/vision/detection/contrib/fastestdet/fastestdet_pybind.cc new file mode 100644 index 000000000..4ed494134 --- /dev/null +++ b/fastdeploy/vision/detection/contrib/fastestdet/fastestdet_pybind.cc @@ -0,0 +1,85 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindFastestDet(pybind11::module& m) { + pybind11::class_( + m, "FastestDetPreprocessor") + .def(pybind11::init<>()) + .def("run", [](vision::detection::FastestDetPreprocessor& self, std::vector& im_list) { + std::vector images; + for (size_t i = 0; i < im_list.size(); ++i) { + images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); + } + std::vector outputs; + std::vector>> ims_info; + if (!self.Run(&images, &outputs, &ims_info)) { + throw std::runtime_error("raise Exception('Failed to preprocess the input data in FastestDetPreprocessor.')"); + } + for (size_t i = 0; i < outputs.size(); ++i) { + outputs[i].StopSharing(); + } + return make_pair(outputs, ims_info); + }) + .def_property("size", &vision::detection::FastestDetPreprocessor::GetSize, &vision::detection::FastestDetPreprocessor::SetSize); + + pybind11::class_( + m, "FastestDetPostprocessor") + .def(pybind11::init<>()) + .def("run", [](vision::detection::FastestDetPostprocessor& self, std::vector& inputs, + const std::vector>>& ims_info) { + std::vector results; + if (!self.Run(inputs, &results, ims_info)) { + throw std::runtime_error("raise Exception('Failed to postprocess the runtime result in FastestDetPostprocessor.')"); + } + return results; + }) + .def("run", [](vision::detection::FastestDetPostprocessor& self, std::vector& input_array, + const std::vector>>& ims_info) { + std::vector results; + std::vector inputs; + PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); + if (!self.Run(inputs, &results, ims_info)) { + throw std::runtime_error("raise Exception('Failed to postprocess the runtime result in FastestDetPostprocessor.')"); + } + return results; + }) + .def_property("conf_threshold", &vision::detection::FastestDetPostprocessor::GetConfThreshold, &vision::detection::FastestDetPostprocessor::SetConfThreshold) + .def_property("nms_threshold", &vision::detection::FastestDetPostprocessor::GetNMSThreshold, &vision::detection::FastestDetPostprocessor::SetNMSThreshold); + + pybind11::class_(m, "FastestDet") + .def(pybind11::init()) + .def("predict", + [](vision::detection::FastestDet& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + vision::DetectionResult res; + self.Predict(mat, &res); + return res; + }) + .def("batch_predict", [](vision::detection::FastestDet& self, std::vector& data) { + std::vector images; + for (size_t i = 0; i < data.size(); ++i) { + images.push_back(PyArrayToCvMat(data[i])); + } + std::vector results; + self.BatchPredict(images, &results); + return results; + }) + .def_property_readonly("preprocessor", &vision::detection::FastestDet::GetPreprocessor) + .def_property_readonly("postprocessor", &vision::detection::FastestDet::GetPostprocessor); +} +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/fastestdet/postprocessor.cc b/fastdeploy/vision/detection/contrib/fastestdet/postprocessor.cc new file mode 100644 index 000000000..447a16c8a --- /dev/null +++ b/fastdeploy/vision/detection/contrib/fastestdet/postprocessor.cc @@ -0,0 +1,132 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/detection/contrib/fastestdet/postprocessor.h" +#include "fastdeploy/vision/utils/utils.h" + +namespace fastdeploy { +namespace vision { +namespace detection { + +FastestDetPostprocessor::FastestDetPostprocessor() { + conf_threshold_ = 0.65; + nms_threshold_ = 0.45; +} +float FastestDetPostprocessor::Sigmoid(float x) { + return 1.0f / (1.0f + exp(-x)); +} + +float FastestDetPostprocessor::Tanh(float x) { + return 2.0f / (1.0f + exp(-2 * x)) - 1; +} + +bool FastestDetPostprocessor::Run( + const std::vector &tensors, std::vector *results, + const std::vector>> &ims_info) { + int batch = 1; + + results->resize(batch); + + for (size_t bs = 0; bs < batch; ++bs) { + + (*results)[bs].Clear(); + // output (1,85,22,22) CHW + const float* output = reinterpret_cast(tensors[0].Data()) + bs * tensors[0].shape[1] * tensors[0].shape[2] * tensors[0].shape[3]; + int output_h = tensors[0].shape[2]; // out map height + int output_w = tensors[0].shape[3]; // out map weight + auto iter_out = ims_info[bs].find("output_shape"); + auto iter_ipt = ims_info[bs].find("input_shape"); + FDASSERT(iter_out != ims_info[bs].end() && iter_ipt != ims_info[bs].end(), + "Cannot find input_shape or output_shape from im_info."); + float ipt_h = iter_ipt->second[0]; + float ipt_w = iter_ipt->second[1]; + + // handle output boxes from out map + for (int h = 0; h < output_h; h++) { + for (int w = 0; w < output_w; w++) { + // object score + int obj_score_index = (h * output_w) + w; + float obj_score = output[obj_score_index]; + + // find max class + int category = 0; + float max_score = 0.0f; + int class_num = tensors[0].shape[1]-5; + for (size_t i = 0; i < class_num; i++) { + obj_score_index =((5 + i) * output_h * output_w) + (h * output_w) + w; + float cls_score = output[obj_score_index]; + if (cls_score > max_score) { + max_score = cls_score; + category = i; + } + } + float score = pow(max_score, 0.4) * pow(obj_score, 0.6); + + // score threshold + if (score <= conf_threshold_) { + continue; + } + if (score > conf_threshold_) { + // handle box x y w h + int x_offset_index = (1 * output_h * output_w) + (h * output_w) + w; + int y_offset_index = (2 * output_h * output_w) + (h * output_w) + w; + int box_width_index = (3 * output_h * output_w) + (h * output_w) + w; + int box_height_index = (4 * output_h * output_w) + (h * output_w) + w; + + float x_offset = Tanh(output[x_offset_index]); + float y_offset = Tanh(output[y_offset_index]); + float box_width = Sigmoid(output[box_width_index]); + float box_height = Sigmoid(output[box_height_index]); + + float cx = (w + x_offset) / output_w; + float cy = (h + y_offset) / output_h; + + // convert from [x, y, w, h] to [x1, y1, x2, y2] + (*results)[bs].boxes.emplace_back(std::array{ + cx - box_width / 2.0f, + cy - box_height / 2.0f, + cx + box_width / 2.0f, + cy + box_height / 2.0f}); + (*results)[bs].label_ids.push_back(category); + (*results)[bs].scores.push_back(score); + } + } + } + if ((*results)[bs].boxes.size() == 0) { + return true; + } + + // scale boxes to origin shape + for (size_t i = 0; i < (*results)[bs].boxes.size(); ++i) { + (*results)[bs].boxes[i][0] = ((*results)[bs].boxes[i][0]) * ipt_w; + (*results)[bs].boxes[i][1] = ((*results)[bs].boxes[i][1]) * ipt_h; + (*results)[bs].boxes[i][2] = ((*results)[bs].boxes[i][2]) * ipt_w; + (*results)[bs].boxes[i][3] = ((*results)[bs].boxes[i][3]) * ipt_h; + } + //NMS + utils::NMS(&((*results)[bs]), nms_threshold_); + //clip box + for (size_t i = 0; i < (*results)[bs].boxes.size(); ++i) { + (*results)[bs].boxes[i][0] = std::max((*results)[bs].boxes[i][0], 0.0f); + (*results)[bs].boxes[i][1] = std::max((*results)[bs].boxes[i][1], 0.0f); + (*results)[bs].boxes[i][2] = std::min((*results)[bs].boxes[i][2], ipt_w); + (*results)[bs].boxes[i][3] = std::min((*results)[bs].boxes[i][3], ipt_h); + } + } + return true; +} + +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/fastestdet/postprocessor.h b/fastdeploy/vision/detection/contrib/fastestdet/postprocessor.h new file mode 100644 index 000000000..c576aee20 --- /dev/null +++ b/fastdeploy/vision/detection/contrib/fastestdet/postprocessor.h @@ -0,0 +1,67 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" + +namespace fastdeploy { +namespace vision { + +namespace detection { +/*! @brief Postprocessor object for FastestDet serials model. + */ +class FASTDEPLOY_DECL FastestDetPostprocessor { + public: + /** \brief Create a postprocessor instance for FastestDet serials model + */ + FastestDetPostprocessor(); + + /** \brief Process the result of runtime and fill to DetectionResult structure + * + * \param[in] tensors The inference result from runtime + * \param[in] result The output result of detection + * \param[in] ims_info The shape info list, record input_shape and output_shape + * \return true if the postprocess successed, otherwise false + */ + bool Run(const std::vector& tensors, + std::vector* results, + const std::vector>>& ims_info); + + /// Set conf_threshold, default 0.65 + void SetConfThreshold(const float& conf_threshold) { + conf_threshold_ = conf_threshold; + } + + /// Get conf_threshold, default 0.65 + float GetConfThreshold() const { return conf_threshold_; } + + /// Set nms_threshold, default 0.45 + void SetNMSThreshold(const float& nms_threshold) { + nms_threshold_ = nms_threshold; + } + + /// Get nms_threshold, default 0.45 + float GetNMSThreshold() const { return nms_threshold_; } + + protected: + float conf_threshold_; + float nms_threshold_; + float Sigmoid(float x); + float Tanh(float x); +}; + +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/fastestdet/preprocessor.cc b/fastdeploy/vision/detection/contrib/fastestdet/preprocessor.cc new file mode 100644 index 000000000..f4ff11e8f --- /dev/null +++ b/fastdeploy/vision/detection/contrib/fastestdet/preprocessor.cc @@ -0,0 +1,81 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/detection/contrib/fastestdet/preprocessor.h" +#include "fastdeploy/function/concat.h" + +namespace fastdeploy { +namespace vision { +namespace detection { + +FastestDetPreprocessor::FastestDetPreprocessor() { + size_ = {352, 352}; //{h,w} +} + +bool FastestDetPreprocessor::Preprocess(FDMat* mat, FDTensor* output, + std::map>* im_info) { + // Record the shape of image and the shape of preprocessed image + (*im_info)["input_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + + // process after image load + double ratio = (size_[0] * 1.0) / std::max(static_cast(mat->Height()), + static_cast(mat->Width())); + + // fastestdet's preprocess steps + // 1. resize + // 2. convert_and_permute(swap_rb=false) + Resize::Run(mat, size_[0], size_[1]); //resize + std::vector alpha = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f}; + std::vector beta = {0.0f, 0.0f, 0.0f}; + //convert to float and HWC2CHW + ConvertAndPermute::Run(mat, alpha, beta, false); + + // Record output shape of preprocessed image + (*im_info)["output_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + + mat->ShareWithTensor(output); + output->ExpandDim(0); // reshape to n, h, w, c + return true; +} + +bool FastestDetPreprocessor::Run(std::vector* images, std::vector* outputs, + std::vector>>* ims_info) { + if (images->size() == 0) { + FDERROR << "The size of input images should be greater than 0." << std::endl; + return false; + } + ims_info->resize(images->size()); + outputs->resize(1); + // Concat all the preprocessed data to a batch tensor + std::vector tensors(images->size()); + for (size_t i = 0; i < images->size(); ++i) { + if (!Preprocess(&(*images)[i], &tensors[i], &(*ims_info)[i])) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } + } + + if (tensors.size() == 1) { + (*outputs)[0] = std::move(tensors[0]); + } else { + function::Concat(tensors, &((*outputs)[0]), 0); + } + return true; +} + +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/fastestdet/preprocessor.h b/fastdeploy/vision/detection/contrib/fastestdet/preprocessor.h new file mode 100644 index 000000000..8166f6198 --- /dev/null +++ b/fastdeploy/vision/detection/contrib/fastestdet/preprocessor.h @@ -0,0 +1,57 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" + +namespace fastdeploy { +namespace vision { + +namespace detection { +/*! @brief Preprocessor object for FastestDet serials model. + */ +class FASTDEPLOY_DECL FastestDetPreprocessor { + public: + /** \brief Create a preprocessor instance for FastestDet serials model + */ + FastestDetPreprocessor(); + + /** \brief Process the input image and prepare input tensors for runtime + * + * \param[in] images The input image data list, all the elements are returned by cv::imread() + * \param[in] outputs The output tensors which will feed in runtime + * \param[in] ims_info The shape info list, record input_shape and output_shape + * \return true if the preprocess successed, otherwise false + */ + bool Run(std::vector* images, std::vector* outputs, + std::vector>>* ims_info); + + /// Set target size, tuple of (width, height), default size = {352, 352} + void SetSize(const std::vector& size) { size_ = size; } + + /// Get target size, tuple of (width, height), default size = {352, 352} + std::vector GetSize() const { return size_; } + + protected: + bool Preprocess(FDMat* mat, FDTensor* output, + std::map>* im_info); + + // target size, tuple of (width, height), default size = {352, 352} + std::vector size_; +}; + +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/detection_pybind.cc b/fastdeploy/vision/detection/detection_pybind.cc index 9d585e18c..80bdff859 100644 --- a/fastdeploy/vision/detection/detection_pybind.cc +++ b/fastdeploy/vision/detection/detection_pybind.cc @@ -22,6 +22,7 @@ void BindYOLOR(pybind11::module& m); void BindYOLOv6(pybind11::module& m); void BindYOLOv5Lite(pybind11::module& m); void BindYOLOv5(pybind11::module& m); +void BindFastestDet(pybind11::module& m); void BindYOLOX(pybind11::module& m); void BindNanoDetPlus(pybind11::module& m); void BindPPDet(pybind11::module& m); @@ -39,6 +40,7 @@ void BindDetection(pybind11::module& m) { BindYOLOv6(detection_module); BindYOLOv5Lite(detection_module); BindYOLOv5(detection_module); + BindFastestDet(detection_module); BindYOLOX(detection_module); BindNanoDetPlus(detection_module); BindYOLOv7End2EndTRT(detection_module); diff --git a/python/fastdeploy/vision/detection/__init__.py b/python/fastdeploy/vision/detection/__init__.py index afd1cd8ce..70d00bcdb 100755 --- a/python/fastdeploy/vision/detection/__init__.py +++ b/python/fastdeploy/vision/detection/__init__.py @@ -19,6 +19,7 @@ from .contrib.scaled_yolov4 import ScaledYOLOv4 from .contrib.nanodet_plus import NanoDetPlus from .contrib.yolox import YOLOX from .contrib.yolov5 import * +from .contrib.fastestdet import * from .contrib.yolov5lite import YOLOv5Lite from .contrib.yolov6 import YOLOv6 from .contrib.yolov7end2end_trt import YOLOv7End2EndTRT diff --git a/python/fastdeploy/vision/detection/contrib/fastestdet.py b/python/fastdeploy/vision/detection/contrib/fastestdet.py new file mode 100644 index 000000000..2f11ed43d --- /dev/null +++ b/python/fastdeploy/vision/detection/contrib/fastestdet.py @@ -0,0 +1,149 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +import logging +from .... import FastDeployModel, ModelFormat +from .... import c_lib_wrap as C + + +class FastestDetPreprocessor: + def __init__(self): + """Create a preprocessor for FastestDet + """ + self._preprocessor = C.vision.detection.FastestDetPreprocessor() + + def run(self, input_ims): + """Preprocess input images for FastestDet + + :param: input_ims: (list of numpy.ndarray)The input image + :return: list of FDTensor + """ + return self._preprocessor.run(input_ims) + + @property + def size(self): + """ + Argument for image preprocessing step, the preprocess image size, tuple of (width, height), default size = [352, 352] + """ + return self._preprocessor.size + + @size.setter + def size(self, wh): + assert isinstance(wh, (list, tuple)),\ + "The value to set `size` must be type of tuple or list." + assert len(wh) == 2,\ + "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format( + len(wh)) + self._preprocessor.size = wh + + +class FastestDetPostprocessor: + def __init__(self): + """Create a postprocessor for FastestDet + """ + self._postprocessor = C.vision.detection.FastestDetPostprocessor() + + def run(self, runtime_results, ims_info): + """Postprocess the runtime results for FastestDet + + :param: runtime_results: (list of FDTensor)The output FDTensor results from runtime + :param: ims_info: (list of dict)Record input_shape and output_shape + :return: list of DetectionResult(If the runtime_results is predict by batched samples, the length of this list equals to the batch size) + """ + return self._postprocessor.run(runtime_results, ims_info) + + @property + def conf_threshold(self): + """ + confidence threshold for postprocessing, default is 0.65 + """ + return self._postprocessor.conf_threshold + + @property + def nms_threshold(self): + """ + nms threshold for postprocessing, default is 0.45 + """ + return self._postprocessor.nms_threshold + + @conf_threshold.setter + def conf_threshold(self, conf_threshold): + assert isinstance(conf_threshold, float),\ + "The value to set `conf_threshold` must be type of float." + self._postprocessor.conf_threshold = conf_threshold + + @nms_threshold.setter + def nms_threshold(self, nms_threshold): + assert isinstance(nms_threshold, float),\ + "The value to set `nms_threshold` must be type of float." + self._postprocessor.nms_threshold = nms_threshold + + +class FastestDet(FastDeployModel): + def __init__(self, + model_file, + params_file="", + runtime_option=None, + model_format=ModelFormat.ONNX): + """Load a FastestDet model exported by FastestDet. + + :param model_file: (str)Path of model file, e.g ./FastestDet.onnx + :param params_file: (str)Path of parameters file, e.g yolox/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string + :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU + :param model_format: (fastdeploy.ModelForamt)Model format of the loaded model + """ + + super(FastestDet, self).__init__(runtime_option) + + assert model_format == ModelFormat.ONNX, "FastestDet only support model format of ModelFormat.ONNX now." + self._model = C.vision.detection.FastestDet( + model_file, params_file, self._runtime_option, model_format) + + assert self.initialized, "FastestDet initialize failed." + + def predict(self, input_image): + """Detect an input image + + :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format + :return: DetectionResult + """ + assert input_image is not None, "Input image is None." + return self._model.predict(input_image) + + def batch_predict(self, images): + assert len(images) == 1,"FastestDet is only support 1 image in batch_predict" + """Classify a batch of input image + + :param im: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format + :return list of DetectionResult + """ + + return self._model.batch_predict(images) + + @property + def preprocessor(self): + """Get FastestDetPreprocessor object of the loaded model + + :return FastestDetPreprocessor + """ + return self._model.preprocessor + + @property + def postprocessor(self): + """Get FastestDetPostprocessor object of the loaded model + + :return FastestDetPostprocessor + """ + return self._model.postprocessor diff --git a/tests/models/test_fastestdet.py b/tests/models/test_fastestdet.py new file mode 100644 index 000000000..0934b173a --- /dev/null +++ b/tests/models/test_fastestdet.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fastdeploy import ModelFormat +import fastdeploy as fd +import cv2 +import os +import pickle +import numpy as np +import runtime_config as rc + + +def test_detection_fastestdet(): + model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/FastestDet.onnx" + input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg" + input_url2 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000570688.jpg" + result_url1 = "https://bj.bcebos.com/paddlehub/fastdeploy/fastestdet_result1.pkl" + fd.download(model_url, "resources") + fd.download(input_url1, "resources") + fd.download(input_url2, "resources") + fd.download(result_url1, "resources") + + model_file = "resources/FastestDet.onnx" + model = fd.vision.detection.FastestDet( + model_file, runtime_option=rc.test_option) + + with open("resources/fastestdet_result1.pkl", "rb") as f: + expect1 = pickle.load(f) + + # compare diff + im1 = cv2.imread("./resources/000000014439.jpg") + print(expect1) + for i in range(3): + # test single predict + result1 = model.predict(im1) + + diff_boxes_1 = np.fabs( + np.array(result1.boxes) - np.array(expect1["boxes"])) + + diff_label_1 = np.fabs( + np.array(result1.label_ids) - np.array(expect1["label_ids"])) + diff_scores_1 = np.fabs( + np.array(result1.scores) - np.array(expect1["scores"])) + + print(diff_boxes_1.max(), diff_boxes_1.mean()) + assert diff_boxes_1.max( + ) < 1e-04, "There's difference in detection boxes 1." + assert diff_label_1.max( + ) < 1e-04, "There's difference in detection label 1." + assert diff_scores_1.max( + ) < 1e-05, "There's difference in detection score 1." + +def test_detection_fastestdet_runtime(): + model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/FastestDet.onnx" + input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg" + result_url1 = "https://bj.bcebos.com/paddlehub/fastdeploy/fastestdet_result1.pkl" + fd.download(model_url, "resources") + fd.download(input_url1, "resources") + fd.download(result_url1, "resources") + + model_file = "resources/FastestDet.onnx" + + preprocessor = fd.vision.detection.FastestDetPreprocessor() + postprocessor = fd.vision.detection.FastestDetPostprocessor() + + rc.test_option.set_model_path(model_file, model_format=ModelFormat.ONNX) + rc.test_option.use_openvino_backend() + runtime = fd.Runtime(rc.test_option) + + with open("resources/fastestdet_result1.pkl", "rb") as f: + expect1 = pickle.load(f) + + # compare diff + im1 = cv2.imread("./resources/000000014439.jpg") + + for i in range(3): + # test runtime + input_tensors, ims_info = preprocessor.run([im1.copy()]) + output_tensors = runtime.infer({"input.1": input_tensors[0]}) + results = postprocessor.run(output_tensors, ims_info) + result1 = results[0] + + diff_boxes_1 = np.fabs( + np.array(result1.boxes) - np.array(expect1["boxes"])) + diff_label_1 = np.fabs( + np.array(result1.label_ids) - np.array(expect1["label_ids"])) + diff_scores_1 = np.fabs( + np.array(result1.scores) - np.array(expect1["scores"])) + + assert diff_boxes_1.max( + ) < 1e-04, "There's difference in detection boxes 1." + assert diff_label_1.max( + ) < 1e-04, "There's difference in detection label 1." + assert diff_scores_1.max( + ) < 1e-05, "There's difference in detection score 1." + + +if __name__ == "__main__": + test_detection_fastestdet() + test_detection_fastestdet_runtime() \ No newline at end of file