From e2487817847e4bebc4f326214f8a88ede6d8184e Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Fri, 22 Jul 2022 09:49:55 +0800 Subject: [PATCH] Add NanoDet-Plus Model support (#32) * update .gitignore * Added checking for cmake include dir * fixed missing trt_backend option bug when init from trt * remove un-need data layout and add pre-check for dtype * changed RGB2BRG to BGR2RGB in ppcls model * add model_zoo yolov6 c++/python demo * fixed CMakeLists.txt typos * update yolov6 cpp/README.md * add yolox c++/pybind and model_zoo demo * move some helpers to private * fixed CMakeLists.txt typos * add normalize with alpha and beta * add version notes for yolov5/yolov6/yolox * add copyright to yolov5.cc * revert normalize * fixed some bugs in yolox * Add NanoDet-Plus Model support Co-authored-by: Jason --- examples/CMakeLists.txt | 14 + examples/vision/rangilyu_nanodet_plus.cc | 53 +++ fastdeploy/vision.h | 1 + fastdeploy/vision/__init__.py | 1 + fastdeploy/vision/megvii/__init__.py | 14 +- fastdeploy/vision/rangilyu/__init__.py | 105 ++++++ fastdeploy/vision/rangilyu/nanodet_plus.cc | 355 ++++++++++++++++++ fastdeploy/vision/rangilyu/nanodet_plus.h | 101 +++++ fastdeploy/vision/rangilyu/rangilyu_pybind.cc | 41 ++ fastdeploy/vision/vision_pybind.cc | 2 + model_zoo/vision/nanodet_plus/README.md | 46 +++ model_zoo/vision/nanodet_plus/api.md | 71 ++++ .../vision/nanodet_plus/cpp/CMakeLists.txt | 17 + model_zoo/vision/nanodet_plus/cpp/README.md | 30 ++ .../vision/nanodet_plus/cpp/nanodet_plus.cc | 40 ++ model_zoo/vision/nanodet_plus/nanodet_plus.py | 23 ++ 16 files changed, 907 insertions(+), 7 deletions(-) create mode 100644 examples/vision/rangilyu_nanodet_plus.cc create mode 100644 fastdeploy/vision/rangilyu/__init__.py create mode 100644 fastdeploy/vision/rangilyu/nanodet_plus.cc create mode 100644 fastdeploy/vision/rangilyu/nanodet_plus.h create mode 100644 fastdeploy/vision/rangilyu/rangilyu_pybind.cc create mode 100644 model_zoo/vision/nanodet_plus/README.md create mode 100644 model_zoo/vision/nanodet_plus/api.md create mode 100644 model_zoo/vision/nanodet_plus/cpp/CMakeLists.txt create mode 100644 model_zoo/vision/nanodet_plus/cpp/README.md create mode 100644 model_zoo/vision/nanodet_plus/cpp/nanodet_plus.cc create mode 100644 model_zoo/vision/nanodet_plus/nanodet_plus.py diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 112193c86..31ca40af3 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,3 +1,17 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + function(add_fastdeploy_executable FIELD CC_FILE) # temp target name/file var in function scope set(TEMP_TARGET_FILE ${CC_FILE}) diff --git a/examples/vision/rangilyu_nanodet_plus.cc b/examples/vision/rangilyu_nanodet_plus.cc new file mode 100644 index 000000000..91dcd604e --- /dev/null +++ b/examples/vision/rangilyu_nanodet_plus.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +int main() { + namespace vis = fastdeploy::vision; + + std::string model_file = "../resources/models/nanodet-plus-m_320.onnx"; + std::string img_path = "../resources/images/bus.jpg"; + std::string vis_path = + "../resources/outputs/rangilyu_nanodet_plus_vis_result.jpg"; + + auto model = vis::rangilyu::NanoDetPlus(model_file); + if (!model.Initialized()) { + std::cerr << "Init Failed! Model: " << model_file << std::endl; + return -1; + } else { + std::cout << "Init Done! Model:" << model_file << std::endl; + } + model.EnableDebug(); + + cv::Mat im = cv::imread(img_path); + cv::Mat vis_im = im.clone(); + + vis::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Prediction Failed." << std::endl; + return -1; + } else { + std::cout << "Prediction Done!" << std::endl; + } + + // 输出预测框结果 + std::cout << res.Str() << std::endl; + + // 可视化预测结果 + vis::Visualize::VisDetection(&vis_im, res); + cv::imwrite(vis_path, vis_im); + std::cout << "Detect Done! Saved: " << vis_path << std::endl; + return 0; +} diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h index d539482a7..b7836ca46 100644 --- a/fastdeploy/vision.h +++ b/fastdeploy/vision.h @@ -19,6 +19,7 @@ #include "fastdeploy/vision/meituan/yolov6.h" #include "fastdeploy/vision/ppcls/model.h" #include "fastdeploy/vision/ppdet/ppyoloe.h" +#include "fastdeploy/vision/rangilyu/nanodet_plus.h" #include "fastdeploy/vision/ppseg/model.h" #include "fastdeploy/vision/ultralytics/yolov5.h" #include "fastdeploy/vision/wongkinyiu/yolor.h" diff --git a/fastdeploy/vision/__init__.py b/fastdeploy/vision/__init__.py index 08b0d6812..09be1fa1b 100644 --- a/fastdeploy/vision/__init__.py +++ b/fastdeploy/vision/__init__.py @@ -22,3 +22,4 @@ from . import meituan from . import megvii from . import visualize from . import wongkinyiu +from . import rangilyu diff --git a/fastdeploy/vision/megvii/__init__.py b/fastdeploy/vision/megvii/__init__.py index 67096e4fc..8f96c9742 100644 --- a/fastdeploy/vision/megvii/__init__.py +++ b/fastdeploy/vision/megvii/__init__.py @@ -28,8 +28,8 @@ class YOLOX(FastDeployModel): # 初始化后的option保存在self._runtime_option super(YOLOX, self).__init__(runtime_option) - self._model = C.vision.megvii.YOLOX( - model_file, params_file, self._runtime_option, model_format) + self._model = C.vision.megvii.YOLOX(model_file, params_file, + self._runtime_option, model_format) # 通过self.initialized判断整个模型的初始化是否成功 assert self.initialized, "YOLOX initialize failed." @@ -53,8 +53,8 @@ class YOLOX(FastDeployModel): @property def downsample_strides(self): - return self._model.downsample_strides - + return self._model.downsample_strides + @property def max_wh(self): return self._model.max_wh @@ -78,16 +78,16 @@ class YOLOX(FastDeployModel): @is_decode_exported.setter def is_decode_exported(self, value): assert isinstance( - value, + value, bool), "The value to set `is_decode_exported` must be type of bool." - self._model.max_wh = value + self._model.is_decode_exported = value @downsample_strides.setter def downsample_strides(self, value): assert isinstance( value, list), "The value to set `downsample_strides` must be type of list." - self._model.downsample_strides = value + self._model.downsample_strides = value @max_wh.setter def max_wh(self, value): diff --git a/fastdeploy/vision/rangilyu/__init__.py b/fastdeploy/vision/rangilyu/__init__.py new file mode 100644 index 000000000..f2e8ace9f --- /dev/null +++ b/fastdeploy/vision/rangilyu/__init__.py @@ -0,0 +1,105 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +import logging +from ... import FastDeployModel, Frontend +from ... import fastdeploy_main as C + + +class NanoDetPlus(FastDeployModel): + def __init__(self, + model_file, + params_file="", + runtime_option=None, + model_format=Frontend.ONNX): + # 调用基函数进行backend_option的初始化 + # 初始化后的option保存在self._runtime_option + super(NanoDetPlus, self).__init__(runtime_option) + + self._model = C.vision.rangilyu.NanoDetPlus( + model_file, params_file, self._runtime_option, model_format) + # 通过self.initialized判断整个模型的初始化是否成功 + assert self.initialized, "NanoDetPlus initialize failed." + + def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5): + return self._model.predict(input_image, conf_threshold, + nms_iou_threshold) + + # 一些跟NanoDetPlus模型有关的属性封装 + # 多数是预处理相关,可通过修改如model.size = [416, 416]改变预处理时resize的大小(前提是模型支持) + @property + def size(self): + return self._model.size + + @property + def padding_value(self): + return self._model.padding_value + + @property + def keep_ratio(self): + return self._model.keep_ratio + + @property + def downsample_strides(self): + return self._model.downsample_strides + + @property + def max_wh(self): + return self._model.max_wh + + @property + def reg_max(self): + return self._model.reg_max + + @size.setter + def size(self, wh): + assert isinstance(wh, [list, tuple]),\ + "The value to set `size` must be type of tuple or list." + assert len(wh) == 2,\ + "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format( + len(wh)) + self._model.size = wh + + @padding_value.setter + def padding_value(self, value): + assert isinstance( + value, + list), "The value to set `padding_value` must be type of list." + self._model.padding_value = value + + @keep_ratio.setter + def keep_ratio(self, value): + assert isinstance( + value, bool), "The value to set `keep_ratio` must be type of bool." + self._model.keep_ratio = value + + @downsample_strides.setter + def downsample_strides(self, value): + assert isinstance( + value, + list), "The value to set `downsample_strides` must be type of list." + self._model.downsample_strides = value + + @max_wh.setter + def max_wh(self, value): + assert isinstance( + value, float), "The value to set `max_wh` must be type of float." + self._model.max_wh = value + + @reg_max.setter + def reg_max(self, value): + assert isinstance( + value, int), "The value to set `reg_max` must be type of int." + self._model.reg_max = value diff --git a/fastdeploy/vision/rangilyu/nanodet_plus.cc b/fastdeploy/vision/rangilyu/nanodet_plus.cc new file mode 100644 index 000000000..678e131c4 --- /dev/null +++ b/fastdeploy/vision/rangilyu/nanodet_plus.cc @@ -0,0 +1,355 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/rangilyu/nanodet_plus.h" +#include "fastdeploy/utils/perf.h" +#include "fastdeploy/vision/utils/utils.h" + +namespace fastdeploy { + +namespace vision { + +namespace rangilyu { + +struct NanoDetPlusCenterPoint { + int grid0; + int grid1; + int stride; +}; + +void GenerateNanoDetPlusCenterPoints( + const std::vector& size, const std::vector& downsample_strides, + std::vector* center_points) { + // size: tuple of input (width, height), e.g (320, 320) + // downsample_strides: downsample strides in NanoDet and + // NanoDet-Plus, e.g (8, 16, 32, 64) + const int width = size[0]; + const int height = size[1]; + for (const auto& ds : downsample_strides) { + int num_grid_w = width / ds; + int num_grid_h = height / ds; + for (int g1 = 0; g1 < num_grid_h; ++g1) { + for (int g0 = 0; g0 < num_grid_w; ++g0) { + (*center_points).emplace_back(NanoDetPlusCenterPoint{g0, g1, ds}); + } + } + } +} + +void WrapAndResize(Mat* mat, std::vector size, std::vector color, + bool keep_ratio = false) { + // Reference: nanodet/data/transform/warp.py#L139 + // size: tuple of input (width, height) + // The default value of `keep_ratio` is `fasle` in + // `config/nanodet-plus-m-1.5x_320.yml` for both + // train and val processes. So, we just let this + // option default `false` according to the official + // implementation in NanoDet and NanoDet-Plus. + // Note, this function will apply a normal resize + // operation to input Mat if the keep_ratio option + // is fasle and the behavior will be the same as + // yolov5's letterbox if keep_ratio is true. + + // with keep_ratio = false (default) + if (!keep_ratio) { + int resize_h = size[1]; + int resize_w = size[0]; + if (resize_h != mat->Height() || resize_w != mat->Width()) { + Resize::Run(mat, resize_w, resize_h); + } + return; + } + // with keep_ratio = true, same as yolov5's letterbox + float r = std::min(size[1] * 1.0f / static_cast(mat->Height()), + size[0] * 1.0f / static_cast(mat->Width())); + + int resize_h = int(round(static_cast(mat->Height()) * r)); + int resize_w = int(round(static_cast(mat->Width()) * r)); + + if (resize_h != mat->Height() || resize_w != mat->Width()) { + Resize::Run(mat, resize_w, resize_h); + } + + int pad_w = size[0] - resize_w; + int pad_h = size[1] - resize_h; + if (pad_h > 0 || pad_w > 0) { + float half_h = pad_h * 1.0 / 2; + int top = int(round(half_h - 0.1)); + int bottom = int(round(half_h + 0.1)); + float half_w = pad_w * 1.0 / 2; + int left = int(round(half_w - 0.1)); + int right = int(round(half_w + 0.1)); + Pad::Run(mat, top, bottom, left, right, color); + } +} + +void GFLRegression(const float* logits, size_t reg_num, float* offset) { + // Hint: reg_num = reg_max + 1 + FDASSERT(((nullptr != logits) && (reg_num != 0)), + "NanoDetPlus: logits is nullptr or reg_num is 0 in GFLRegression."); + // softmax + float total_exp = 0.f; + std::vector softmax_probs(reg_num); + for (size_t i = 0; i < reg_num; ++i) { + softmax_probs[i] = std::exp(logits[i]); + total_exp += softmax_probs[i]; + } + for (size_t i = 0; i < reg_num; ++i) { + softmax_probs[i] = softmax_probs[i] / total_exp; + } + // gfl regression -> offset + for (size_t i = 0; i < reg_num; ++i) { + (*offset) += static_cast(i) * softmax_probs[i]; + } +} + +NanoDetPlus::NanoDetPlus(const std::string& model_file, + const std::string& params_file, + const RuntimeOption& custom_option, + const Frontend& model_format) { + if (model_format == Frontend::ONNX) { + valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端 + valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端 + } else { + valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; + valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT}; + } + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + initialized = Initialize(); +} + +bool NanoDetPlus::Initialize() { + // parameters for preprocess + size = {320, 320}; + padding_value = {0.0f, 0.0f, 0.0f}; + keep_ratio = false; + downsample_strides = {8, 16, 32, 64}; + max_wh = 4096.0f; + reg_max = 7; + + if (!InitRuntime()) { + FDERROR << "Failed to initialize fastdeploy backend." << std::endl; + return false; + } + // Check if the input shape is dynamic after Runtime already initialized. + is_dynamic_input_ = false; + auto shape = InputInfoOfRuntime(0).shape; + for (int i = 0; i < shape.size(); ++i) { + // if height or width is dynamic + if (i >= 2 && shape[i] <= 0) { + is_dynamic_input_ = true; + break; + } + } + return true; +} + +bool NanoDetPlus::Preprocess( + Mat* mat, FDTensor* output, + std::map>* im_info) { + // NanoDet-Plus preprocess steps + // 1. WrapAndResize + // 2. HWC->CHW + // 3. Normalize or Convert (keep BGR order) + WrapAndResize(mat, size, padding_value, keep_ratio); + // Record output shape of preprocessed image + (*im_info)["output_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + + // Compute `result = mat * alpha + beta` directly by channel + // Reference: /config/nanodet-plus-m-1.5x_320.yml#L89 + // from mean: [103.53, 116.28, 123.675], std: [57.375, 57.12, 58.395] + // x' = (x - mean) / std to x'= x * alpha + beta. + // e.g alpha[0] = 0.017429f = 1.0f / 57.375f + // e.g beta[0] = -103.53f * 0.0174291f + std::vector alpha = {0.017429f, 0.017507f, 0.017125f}; + std::vector beta = {-103.53f * 0.0174291f, -116.28f * 0.0175070f, + -123.675f * 0.0171247f}; // BGR order + Convert::Run(mat, alpha, beta); + + HWC2CHW::Run(mat); + Cast::Run(mat, "float"); + mat->ShareWithTensor(output); + output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c + return true; +} + +bool NanoDetPlus::Postprocess( + FDTensor& infer_result, DetectionResult* result, + const std::map>& im_info, + float conf_threshold, float nms_iou_threshold) { + FDASSERT(infer_result.shape[0] == 1, "Only support batch =1 now."); + result->Clear(); + result->Reserve(infer_result.shape[1]); + if (infer_result.dtype != FDDataType::FP32) { + FDERROR << "Only support post process with float32 data." << std::endl; + return false; + } + // generate center points with dowmsample strides + std::vector center_points; + GenerateNanoDetPlusCenterPoints(size, downsample_strides, ¢er_points); + + // infer_result shape might look like (1,2125,112) + const int num_cls_reg = infer_result.shape[2]; // e.g 112 + const int num_classes = num_cls_reg - (reg_max + 1) * 4; // e.g 80 + float* data = static_cast(infer_result.Data()); + for (size_t i = 0; i < infer_result.shape[1]; ++i) { + float* scores = data + i * num_cls_reg; + float* max_class_score = std::max_element(scores, scores + num_classes); + float confidence = (*max_class_score); + // filter boxes by conf_threshold + if (confidence <= conf_threshold) { + continue; + } + int32_t label_id = std::distance(scores, max_class_score); + // fetch i-th center point + float grid0 = static_cast(center_points.at(i).grid0); + float grid1 = static_cast(center_points.at(i).grid1); + float downsample_stride = static_cast(center_points.at(i).stride); + // apply gfl regression to get offsets (l,t,r,b) + float* logits = data + i * num_cls_reg + num_classes; // 32|44... + std::vector offsets(4); + for (size_t j = 0; j < 4; ++j) { + GFLRegression(logits + j * (reg_max + 1), reg_max + 1, &offsets[j]); + } + // convert from offsets to [x1, y1, x2, y2] + float l = offsets[0]; // left + float t = offsets[1]; // top + float r = offsets[2]; // right + float b = offsets[3]; // bottom + + float x1 = (grid0 - l) * downsample_stride; // cx - l x1 + float y1 = (grid1 - t) * downsample_stride; // cy - t y1 + float x2 = (grid0 + r) * downsample_stride; // cx + r x2 + float y2 = (grid1 + b) * downsample_stride; // cy + b y2 + + result->boxes.emplace_back( + std::array{x1 + label_id * max_wh, y1 + label_id * max_wh, + x2 + label_id * max_wh, y2 + label_id * max_wh}); + // label_id * max_wh for multi classes NMS + result->label_ids.push_back(label_id); + result->scores.push_back(confidence); + } + utils::NMS(result, nms_iou_threshold); + + // scale the boxes to the origin image shape + auto iter_out = im_info.find("output_shape"); + auto iter_ipt = im_info.find("input_shape"); + FDASSERT(iter_out != im_info.end() && iter_ipt != im_info.end(), + "Cannot find input_shape or output_shape from im_info."); + float out_h = iter_out->second[0]; + float out_w = iter_out->second[1]; + float ipt_h = iter_ipt->second[0]; + float ipt_w = iter_ipt->second[1]; + // without keep_ratio + if (!keep_ratio) { + // x' = (x / out_w) * ipt_w = x / (out_w / ipt_w) + // y' = (y / out_h) * ipt_h = y / (out_h / ipt_h) + float r_w = out_w / ipt_w; + float r_h = out_h / ipt_h; + for (size_t i = 0; i < result->boxes.size(); ++i) { + int32_t label_id = (result->label_ids)[i]; + // clip box + result->boxes[i][0] = result->boxes[i][0] - max_wh * label_id; + result->boxes[i][1] = result->boxes[i][1] - max_wh * label_id; + result->boxes[i][2] = result->boxes[i][2] - max_wh * label_id; + result->boxes[i][3] = result->boxes[i][3] - max_wh * label_id; + result->boxes[i][0] = std::max(result->boxes[i][0] / r_w, 0.0f); + result->boxes[i][1] = std::max(result->boxes[i][1] / r_h, 0.0f); + result->boxes[i][2] = std::max(result->boxes[i][2] / r_w, 0.0f); + result->boxes[i][3] = std::max(result->boxes[i][3] / r_h, 0.0f); + result->boxes[i][0] = std::min(result->boxes[i][0], ipt_w - 1.0f); + result->boxes[i][1] = std::min(result->boxes[i][1], ipt_h - 1.0f); + result->boxes[i][2] = std::min(result->boxes[i][2], ipt_w - 1.0f); + result->boxes[i][3] = std::min(result->boxes[i][3], ipt_h - 1.0f); + } + return true; + } + // with keep_ratio + float r = std::min(out_h / ipt_h, out_w / ipt_w); + float pad_h = (out_h - ipt_h * r) / 2; + float pad_w = (out_w - ipt_w * r) / 2; + for (size_t i = 0; i < result->boxes.size(); ++i) { + int32_t label_id = (result->label_ids)[i]; + // clip box + result->boxes[i][0] = result->boxes[i][0] - max_wh * label_id; + result->boxes[i][1] = result->boxes[i][1] - max_wh * label_id; + result->boxes[i][2] = result->boxes[i][2] - max_wh * label_id; + result->boxes[i][3] = result->boxes[i][3] - max_wh * label_id; + result->boxes[i][0] = std::max((result->boxes[i][0] - pad_w) / r, 0.0f); + result->boxes[i][1] = std::max((result->boxes[i][1] - pad_h) / r, 0.0f); + result->boxes[i][2] = std::max((result->boxes[i][2] - pad_w) / r, 0.0f); + result->boxes[i][3] = std::max((result->boxes[i][3] - pad_h) / r, 0.0f); + result->boxes[i][0] = std::min(result->boxes[i][0], ipt_w - 1.0f); + result->boxes[i][1] = std::min(result->boxes[i][1], ipt_h - 1.0f); + result->boxes[i][2] = std::min(result->boxes[i][2], ipt_w - 1.0f); + result->boxes[i][3] = std::min(result->boxes[i][3], ipt_h - 1.0f); + } + return true; +} + +bool NanoDetPlus::Predict(cv::Mat* im, DetectionResult* result, + float conf_threshold, float nms_iou_threshold) { +#ifdef FASTDEPLOY_DEBUG + TIMERECORD_START(0) +#endif + + Mat mat(*im); + std::vector input_tensors(1); + + std::map> im_info; + + // Record the shape of image and the shape of preprocessed image + im_info["input_shape"] = {static_cast(mat.Height()), + static_cast(mat.Width())}; + im_info["output_shape"] = {static_cast(mat.Height()), + static_cast(mat.Width())}; + + if (!Preprocess(&mat, &input_tensors[0], &im_info)) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } + +#ifdef FASTDEPLOY_DEBUG + TIMERECORD_END(0, "Preprocess") + TIMERECORD_START(1) +#endif + + input_tensors[0].name = InputInfoOfRuntime(0).name; + std::vector output_tensors; + if (!Infer(input_tensors, &output_tensors)) { + FDERROR << "Failed to inference." << std::endl; + return false; + } +#ifdef FASTDEPLOY_DEBUG + TIMERECORD_END(1, "Inference") + TIMERECORD_START(2) +#endif + + if (!Postprocess(output_tensors[0], result, im_info, conf_threshold, + nms_iou_threshold)) { + FDERROR << "Failed to post process." << std::endl; + return false; + } + +#ifdef FASTDEPLOY_DEBUG + TIMERECORD_END(2, "Postprocess") +#endif + return true; +} + +} // namespace rangilyu +} // namespace vision +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/vision/rangilyu/nanodet_plus.h b/fastdeploy/vision/rangilyu/nanodet_plus.h new file mode 100644 index 000000000..4184aa18e --- /dev/null +++ b/fastdeploy/vision/rangilyu/nanodet_plus.h @@ -0,0 +1,101 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fastdeploy/fastdeploy_model.h" +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" + +namespace fastdeploy { + +namespace vision { + +namespace rangilyu { + +class FASTDEPLOY_DECL NanoDetPlus : public FastDeployModel { + public: + // 当model_format为ONNX时,无需指定params_file + // 当model_format为Paddle时,则需同时指定model_file & params_file + NanoDetPlus(const std::string& model_file, + const std::string& params_file = "", + const RuntimeOption& custom_option = RuntimeOption(), + const Frontend& model_format = Frontend::ONNX); + + // 定义模型的名称 + std::string ModelName() const { return "RangiLyu/nanodet"; } + + // 模型预测接口,即用户调用的接口 + // im 为用户的输入数据,目前对于CV均定义为cv::Mat + // result 为模型预测的输出结构体 + // conf_threshold 为后处理的参数 + // nms_iou_threshold 为后处理的参数 + virtual bool Predict(cv::Mat* im, DetectionResult* result, + float conf_threshold = 0.35f, + float nms_iou_threshold = 0.5f); + + // 以下为模型在预测时的一些参数,基本是前后处理所需 + // 用户在创建模型后,可根据模型的要求,以及自己的需求 + // 对参数进行修改 + // tuple of input size (width, height), e.g (320, 320) + std::vector size; + // padding value, size should be same with Channels + std::vector padding_value; + // keep aspect ratio or not when perform resize operation. + // This option is set as `false` by default in NanoDet-Plus. + bool keep_ratio; + // downsample strides for NanoDet-Plus to generate anchors, will + // take (8, 16, 32, 64) as default values. + std::vector downsample_strides; + // for offseting the boxes by classes when using NMS, default 4096. + float max_wh; + // reg_max for GFL regression, default 7 + int reg_max; + + private: + // 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作 + bool Initialize(); + + // 输入图像预处理操作 + // Mat为FastDeploy定义的数据结构 + // FDTensor为预处理后的Tensor数据,传给后端进行推理 + // im_info为预处理过程保存的数据,在后处理中需要用到 + bool Preprocess(Mat* mat, FDTensor* output, + std::map>* im_info); + + // 后端推理结果后处理,输出给用户 + // infer_result 为后端推理后的输出Tensor + // result 为模型预测的结果 + // im_info 为预处理记录的信息,后处理用于还原box + // conf_threshold 后处理时过滤box的置信度阈值 + // nms_iou_threshold 后处理时NMS设定的iou阈值 + bool Postprocess(FDTensor& infer_result, DetectionResult* result, + const std::map>& im_info, + float conf_threshold, float nms_iou_threshold); + + // 查看输入是否为动态维度的 不建议直接使用 不同模型的逻辑可能不一致 + bool IsDynamicInput() const { return is_dynamic_input_; } + + // whether to inference with dynamic shape (e.g ONNX export with dynamic shape + // or not.) + // RangiLyu/nanodet official 'export_onnx.py' script will export static ONNX + // by default. + // This value will auto check by fastdeploy after the internal Runtime + // initialized. + bool is_dynamic_input_; +}; + +} // namespace rangilyu +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/rangilyu/rangilyu_pybind.cc b/fastdeploy/vision/rangilyu/rangilyu_pybind.cc new file mode 100644 index 000000000..70bde6005 --- /dev/null +++ b/fastdeploy/vision/rangilyu/rangilyu_pybind.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindRangiLyu(pybind11::module& m) { + auto rangilyu_module = + m.def_submodule("rangilyu", "https://github.com/RangiLyu/nanodet"); + pybind11::class_( + rangilyu_module, "NanoDetPlus") + .def(pybind11::init()) + .def("predict", + [](vision::rangilyu::NanoDetPlus& self, pybind11::array& data, + float conf_threshold, float nms_iou_threshold) { + auto mat = PyArrayToCvMat(data); + vision::DetectionResult res; + self.Predict(&mat, &res, conf_threshold, nms_iou_threshold); + return res; + }) + .def_readwrite("size", &vision::rangilyu::NanoDetPlus::size) + .def_readwrite("padding_value", + &vision::rangilyu::NanoDetPlus::padding_value) + .def_readwrite("keep_ratio", &vision::rangilyu::NanoDetPlus::keep_ratio) + .def_readwrite("downsample_strides", + &vision::rangilyu::NanoDetPlus::downsample_strides) + .def_readwrite("max_wh", &vision::rangilyu::NanoDetPlus::max_wh) + .def_readwrite("reg_max", &vision::rangilyu::NanoDetPlus::reg_max); +} +} // namespace fastdeploy diff --git a/fastdeploy/vision/vision_pybind.cc b/fastdeploy/vision/vision_pybind.cc index 22c4f0bc2..42fcebff4 100644 --- a/fastdeploy/vision/vision_pybind.cc +++ b/fastdeploy/vision/vision_pybind.cc @@ -23,6 +23,7 @@ void BindPPSeg(pybind11::module& m); void BindUltralytics(pybind11::module& m); void BindMeituan(pybind11::module& m); void BindMegvii(pybind11::module& m); +void BindRangiLyu(pybind11::module& m); #ifdef ENABLE_VISION_VISUALIZE void BindVisualize(pybind11::module& m); #endif @@ -56,6 +57,7 @@ void BindVision(pybind11::module& m) { BindWongkinyiu(m); BindMeituan(m); BindMegvii(m); + BindRangiLyu(m); #ifdef ENABLE_VISION_VISUALIZE BindVisualize(m); #endif diff --git a/model_zoo/vision/nanodet_plus/README.md b/model_zoo/vision/nanodet_plus/README.md new file mode 100644 index 000000000..164f7691f --- /dev/null +++ b/model_zoo/vision/nanodet_plus/README.md @@ -0,0 +1,46 @@ +# NanoDetPlus部署示例 + +当前支持模型版本为:[NanoDetPlus v1.0.0-alpha-1](https://github.com/RangiLyu/nanodet/releases/tag/v1.0.0-alpha-1) + +本文档说明如何进行[NanoDetPlus](https://github.com/RangiLyu/nanodet)的快速部署推理。本目录结构如下 +``` +. +├── cpp # C++ 代码目录 +│   ├── CMakeLists.txt # C++ 代码编译CMakeLists文件 +│   ├── README.md # C++ 代码编译部署文档 +│   └── nanodet_plus.cc # C++ 示例代码 +├── README.md # YOLOX 部署文档 +└── nanodet_plus.py # Python示例代码 +``` + +## 安装FastDeploy + +使用如下命令安装FastDeploy,注意到此处安装的是`vision-cpu`,也可根据需求安装`vision-gpu` +``` +# 安装fastdeploy-python工具 +pip install fastdeploy-python + +# 安装vision-cpu模块 +fastdeploy install vision-cpu +``` + +## Python部署 + +执行如下代码即会自动下载NanoDetPlus模型和测试图片 +``` +python nanodet_plus.py +``` + +执行完成后会将可视化结果保存在本地`vis_result.jpg`,同时输出检测结果如下 +``` +DetectionResult: [xmin, ymin, xmax, ymax, score, label_id] +5.710144,220.634033, 807.854370, 724.089111, 0.825635, 5 +45.646439,393.694061, 229.267044, 903.998413, 0.818263, 0 +218.289322,402.268829, 342.083252, 861.766479, 0.709301, 0 +698.587036,325.627197, 809.000000, 876.990967, 0.630235, 0 +``` + +## 其它文档 + +- [C++部署](./cpp/README.md) +- [NanoDetPlus API文档](./api.md) diff --git a/model_zoo/vision/nanodet_plus/api.md b/model_zoo/vision/nanodet_plus/api.md new file mode 100644 index 000000000..b428e39df --- /dev/null +++ b/model_zoo/vision/nanodet_plus/api.md @@ -0,0 +1,71 @@ +# NanoDetPlus API说明 + +## Python API + +### NanoDetPlus类 +``` +fastdeploy.vision.rangilyu.NanoDetPlus(model_file, params_file=None, runtime_option=None, model_format=fd.Frontend.ONNX) +``` +NanoDetPlus模型加载和初始化,当model_format为`fd.Frontend.ONNX`时,只需提供model_file,如`nanodet-plus-m_320.onnx`;当model_format为`fd.Frontend.PADDLE`时,则需同时提供model_file和params_file。 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(Frontend): 模型格式 + +#### predict函数 +> ``` +> NanoDetPlus.predict(image_data, conf_threshold=0.35, nms_iou_threshold=0.5) +> ``` +> 模型预测结口,输入图像直接输出检测结果。 +> +> **参数** +> +> > * **image_data**(np.ndarray): 输入数据,注意需为HWC,BGR格式 +> > * **conf_threshold**(float): 检测框置信度过滤阈值 +> > * **nms_iou_threshold**(float): NMS处理过程中iou阈值 + +示例代码参考[nanodet_plus.py](./nanodet_plus.py) + + +## C++ API + +### NanoDetPlus类 +``` +fastdeploy::vision::rangilyu::NanoDetPlus( + const string& model_file, + const string& params_file = "", + const RuntimeOption& runtime_option = RuntimeOption(), + const Frontend& model_format = Frontend::ONNX) +``` +NanoDetPlus模型加载和初始化,当model_format为`Frontend::ONNX`时,只需提供model_file,如`nanodet-plus-m_320.onnx`;当model_format为`Frontend::PADDLE`时,则需同时提供model_file和params_file。 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(Frontend): 模型格式 + +#### Predict函数 +> ``` +> NanoDetPlus::Predict(cv::Mat* im, DetectionResult* result, +> float conf_threshold = 0.35, +> float nms_iou_threshold = 0.5) +> ``` +> 模型预测接口,输入图像直接输出检测结果。 +> +> **参数** +> +> > * **im**: 输入图像,注意需为HWC,BGR格式 +> > * **result**: 检测结果,包括检测框,各个框的置信度 +> > * **conf_threshold**: 检测框置信度过滤阈值 +> > * **nms_iou_threshold**: NMS处理过程中iou阈值 + +示例代码参考[cpp/nanodet_plus.cc](cpp/nanodet_plus.cc) + +## 其它API使用 + +- [模型部署RuntimeOption配置](../../../docs/api/runtime_option.md) diff --git a/model_zoo/vision/nanodet_plus/cpp/CMakeLists.txt b/model_zoo/vision/nanodet_plus/cpp/CMakeLists.txt new file mode 100644 index 000000000..7a78ef9e4 --- /dev/null +++ b/model_zoo/vision/nanodet_plus/cpp/CMakeLists.txt @@ -0,0 +1,17 @@ +PROJECT(nanodet_plus_demo C CXX) +CMAKE_MINIMUM_REQUIRED(VERSION 3.16) + +# 在低版本ABI环境中,通过如下代码进行兼容性编译 +# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) + +# 指定下载解压后的fastdeploy库路径 +set(FASTDEPLOY_INSTALL_DIR ${PROJECT_SOURCE_DIR}/fastdeploy-linux-x64-0.0.3/) + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) + +# 添加FastDeploy依赖头文件 +include_directories(${FASTDEPLOY_INCS}) + +add_executable(nanodet_plus_demo ${PROJECT_SOURCE_DIR}/nanodet_plus.cc) +# 添加FastDeploy库依赖 +target_link_libraries(nanodet_plus_demo ${FASTDEPLOY_LIBS}) diff --git a/model_zoo/vision/nanodet_plus/cpp/README.md b/model_zoo/vision/nanodet_plus/cpp/README.md new file mode 100644 index 000000000..03dc65a0a --- /dev/null +++ b/model_zoo/vision/nanodet_plus/cpp/README.md @@ -0,0 +1,30 @@ +# 编译NanoDetPlus示例 + +当前支持模型版本为:[NanoDetPlus v1.0.0-alpha-1](https://github.com/RangiLyu/nanodet/releases/tag/v1.0.0-alpha-1) + +``` +# 下载和解压预测库 +wget https://bj.bcebos.com/paddle2onnx/fastdeploy/fastdeploy-linux-x64-0.0.3.tgz +tar xvf fastdeploy-linux-x64-0.0.3.tgz + +# 编译示例代码 +mkdir build & cd build +cmake .. +make -j + +# 下载模型和图片 +wget https://github.com/RangiLyu/nanodet/releases/download/v1.0.0-alpha-1/nanodet-plus-m_320.onnx +wget https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg + +# 执行 +./nanodet_plus_demo +``` + +执行完后可视化的结果保存在本地`vis_result.jpg`,同时会将检测框输出在终端,如下所示 +``` +DetectionResult: [xmin, ymin, xmax, ymax, score, label_id] +5.710144,220.634033, 807.854370, 724.089111, 0.825635, 5 +45.646439,393.694061, 229.267044, 903.998413, 0.818263, 0 +218.289322,402.268829, 342.083252, 861.766479, 0.709301, 0 +698.587036,325.627197, 809.000000, 876.990967, 0.630235, 0 +``` diff --git a/model_zoo/vision/nanodet_plus/cpp/nanodet_plus.cc b/model_zoo/vision/nanodet_plus/cpp/nanodet_plus.cc new file mode 100644 index 000000000..b252bf6f8 --- /dev/null +++ b/model_zoo/vision/nanodet_plus/cpp/nanodet_plus.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +int main() { + namespace vis = fastdeploy::vision; + auto model = vis::rangilyu::NanoDetPlus("nanodet-plus-m_320.onnx"); + if (!model.Initialized()) { + std::cerr << "Init Failed." << std::endl; + return -1; + } + cv::Mat im = cv::imread("bus.jpg"); + cv::Mat vis_im = im.clone(); + + vis::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Prediction Failed." << std::endl; + return -1; + } + + // 输出预测框结果 + std::cout << res.Str() << std::endl; + + // 可视化预测结果 + vis::Visualize::VisDetection(&vis_im, res); + cv::imwrite("vis_result.jpg", vis_im); + return 0; +} diff --git a/model_zoo/vision/nanodet_plus/nanodet_plus.py b/model_zoo/vision/nanodet_plus/nanodet_plus.py new file mode 100644 index 000000000..4101d2040 --- /dev/null +++ b/model_zoo/vision/nanodet_plus/nanodet_plus.py @@ -0,0 +1,23 @@ +import fastdeploy as fd +import cv2 + +# 下载模型和测试图片 +model_url = "https://github.com/RangiLyu/nanodet/releases/download/v1.0.0-alpha-1/nanodet-plus-m_320.onnx" +test_jpg_url = "https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg" +fd.download(model_url, ".", show_progress=True) +fd.download(test_jpg_url, ".", show_progress=True) + +# 加载模型 +model = fd.vision.rangilyu.NanoDetPlus("nanodet-plus-m_320.onnx") + +# 预测图片 +im = cv2.imread("bus.jpg") +result = model.predict(im, conf_threshold=0.35, nms_iou_threshold=0.5) + +# 可视化结果 +fd.vision.visualize.vis_detection(im, result) +cv2.imwrite("vis_result.jpg", im) + +# 输出预测结果 +print(result) +print(model.runtime_option)