From 17e4dc6b5e6acd2a4521e899b60a03a4e6fd1e00 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 26 Jul 2022 11:16:01 +0800 Subject: [PATCH] Support remove multiclass_nms to enable ppyoloe to tensorrt (#40) * Add custom operator for onnxruntime ans fix paddle backend * Polish cmake files and runtime apis * Remove copy libraries * fix some issue * fix bug * fix bug * Support remove multiclass_nms to enable paddledetection run tensorrt * Support remove multiclass_nms to enable paddledetection run tensorrt * Support remove multiclass_nms to enable paddledetection run tensorrt * Support remove multiclass_nms to enable paddledetection run tensorrt * add common operator multiclassnms * fix compile problem Co-authored-by: root --- external/onnxruntime.cmake | 2 +- external/paddle2onnx.cmake | 2 +- fastdeploy/backends/backend.h | 4 +- fastdeploy/backends/common/multiclass_nms.cc | 224 ++++++++++++++++++ fastdeploy/backends/common/multiclass_nms.h | 45 ++++ fastdeploy/backends/ort/ops/multiclass_nms.cc | 3 - fastdeploy/backends/ort/ort_backend.cc | 14 +- fastdeploy/backends/ort/ort_backend.h | 4 + fastdeploy/backends/tensorrt/trt_backend.cc | 27 ++- fastdeploy/backends/tensorrt/trt_backend.h | 4 + fastdeploy/fastdeploy_runtime.cc | 13 +- fastdeploy/fastdeploy_runtime.h | 6 + fastdeploy/utils/utils.cc | 15 ++ fastdeploy/utils/utils.h | 3 + fastdeploy/vision/ppdet/__init__.py | 5 +- fastdeploy/vision/ppdet/ppdet_pybind.cc | 5 +- fastdeploy/vision/ppdet/ppyoloe.cc | 124 +++++++--- fastdeploy/vision/ppdet/ppyoloe.h | 16 +- 18 files changed, 460 insertions(+), 56 deletions(-) create mode 100644 fastdeploy/backends/common/multiclass_nms.cc create mode 100644 fastdeploy/backends/common/multiclass_nms.h diff --git a/external/onnxruntime.cmake b/external/onnxruntime.cmake index 01e7b8126..da2ce4368 100644 --- a/external/onnxruntime.cmake +++ b/external/onnxruntime.cmake @@ -27,7 +27,7 @@ set(ONNXRUNTIME_LIB_DIR CACHE PATH "onnxruntime lib directory." FORCE) set(CMAKE_BUILD_RPATH "${CMAKE_BUILD_RPATH}" "${ONNXRUNTIME_LIB_DIR}") -set(ONNXRUNTIME_VERSION "1.11.1") +set(ONNXRUNTIME_VERSION "1.12.0") set(ONNXRUNTIME_URL_PREFIX "https://bj.bcebos.com/paddle2onnx/libs/") if(WIN32) diff --git a/external/paddle2onnx.cmake b/external/paddle2onnx.cmake index fe1fcc8a8..97ba169ac 100644 --- a/external/paddle2onnx.cmake +++ b/external/paddle2onnx.cmake @@ -43,7 +43,7 @@ else() endif(WIN32) set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/paddle2onnx/libs/") -set(PADDLE2ONNX_VERSION "1.0.0rc1") +set(PADDLE2ONNX_VERSION "1.0.0rc2") if(WIN32) set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip") elseif(APPLE) diff --git a/fastdeploy/backends/backend.h b/fastdeploy/backends/backend.h index 240f40734..071b47eb1 100644 --- a/fastdeploy/backends/backend.h +++ b/fastdeploy/backends/backend.h @@ -18,7 +18,7 @@ #include #include #include - +#include "fastdeploy/backends/common/multiclass_nms.h" #include "fastdeploy/core/fd_tensor.h" namespace fastdeploy { @@ -45,4 +45,4 @@ class BaseBackend { std::vector* outputs) = 0; }; -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/fastdeploy/backends/common/multiclass_nms.cc b/fastdeploy/backends/common/multiclass_nms.cc new file mode 100644 index 000000000..c3d65ec7d --- /dev/null +++ b/fastdeploy/backends/common/multiclass_nms.cc @@ -0,0 +1,224 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/backends/common/multiclass_nms.h" +#include +#include "fastdeploy/core/fd_tensor.h" +#include "fastdeploy/utils/utils.h" + +namespace fastdeploy { +namespace backend { +template +bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +void GetMaxScoreIndex(const float* scores, const int& score_size, + const float& threshold, const int& top_k, + std::vector>* sorted_indices) { + for (size_t i = 0; i < score_size; ++i) { + if (scores[i] > threshold) { + sorted_indices->push_back(std::make_pair(scores[i], i)); + } + } + // Sort the score pair according to the scores in descending order + std::stable_sort(sorted_indices->begin(), sorted_indices->end(), + SortScorePairDescend); + // Keep top_k scores if needed. + if (top_k > -1 && top_k < static_cast(sorted_indices->size())) { + sorted_indices->resize(top_k); + } +} + +float BBoxArea(const float* box, const bool& normalized) { + if (box[2] < box[0] || box[3] < box[1]) { + // If coordinate values are is invalid + // (e.g. xmax < xmin or ymax < ymin), return 0. + return 0.f; + } else { + const float w = box[2] - box[0]; + const float h = box[3] - box[1]; + if (normalized) { + return w * h; + } else { + // If coordinate values are not within range [0, 1]. + return (w + 1) * (h + 1); + } + } +} + +float JaccardOverlap(const float* box1, const float* box2, + const bool& normalized) { + if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || + box2[3] < box1[1]) { + return 0.f; + } else { + const float inter_xmin = std::max(box1[0], box2[0]); + const float inter_ymin = std::max(box1[1], box2[1]); + const float inter_xmax = std::min(box1[2], box2[2]); + const float inter_ymax = std::min(box1[3], box2[3]); + float norm = normalized ? 0.0f : 1.0f; + float inter_w = inter_xmax - inter_xmin + norm; + float inter_h = inter_ymax - inter_ymin + norm; + const float inter_area = inter_w * inter_h; + const float bbox1_area = BBoxArea(box1, normalized); + const float bbox2_area = BBoxArea(box2, normalized); + return inter_area / (bbox1_area + bbox2_area - inter_area); + } +} + +void MultiClassNMS::FastNMS(const float* boxes, const float* scores, + const int& num_boxes, + std::vector* keep_indices) { + std::vector> sorted_indices; + GetMaxScoreIndex(scores, num_boxes, score_threshold, nms_top_k, + &sorted_indices); + + float adaptive_threshold = nms_threshold; + while (sorted_indices.size() != 0) { + const int idx = sorted_indices.front().second; + bool keep = true; + for (size_t k = 0; k < keep_indices->size(); ++k) { + if (!keep) { + break; + } + const int kept_idx = (*keep_indices)[k]; + float overlap = + JaccardOverlap(boxes + idx * 4, boxes + kept_idx * 4, normalized); + keep = overlap <= adaptive_threshold; + } + if (keep) { + keep_indices->push_back(idx); + } + sorted_indices.erase(sorted_indices.begin()); + if (keep && nms_eta<1.0 & adaptive_threshold> 0.5) { + adaptive_threshold *= nms_eta; + } + } +} + +int MultiClassNMS::NMSForEachSample( + const float* boxes, const float* scores, int num_boxes, int num_classes, + std::map>* keep_indices) { + for (int i = 0; i < num_classes; ++i) { + if (i == background_label) { + continue; + } + const float* score_for_class_i = scores + i * num_boxes; + FastNMS(boxes, score_for_class_i, num_boxes, &((*keep_indices)[i])); + } + int num_det = 0; + for (auto iter = keep_indices->begin(); iter != keep_indices->end(); ++iter) { + num_det += iter->second.size(); + } + + if (keep_top_k > -1 && num_det > keep_top_k) { + std::vector>> score_index_pairs; + for (const auto& it : *keep_indices) { + int label = it.first; + const float* current_score = scores + label * num_boxes; + auto& label_indices = it.second; + for (size_t j = 0; j < label_indices.size(); ++j) { + int idx = label_indices[j]; + score_index_pairs.push_back( + std::make_pair(current_score[idx], std::make_pair(label, idx))); + } + } + std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(), + SortScorePairDescend>); + score_index_pairs.resize(keep_top_k); + + std::map> new_indices; + for (size_t j = 0; j < score_index_pairs.size(); ++j) { + int label = score_index_pairs[j].second.first; + int idx = score_index_pairs[j].second.second; + new_indices[label].push_back(idx); + } + new_indices.swap(*keep_indices); + num_det = keep_top_k; + } + return num_det; +} + +void MultiClassNMS::Compute(const float* boxes_data, const float* scores_data, + const std::vector& boxes_dim, + const std::vector& scores_dim) { + int score_size = scores_dim.size(); + + int64_t batch_size = scores_dim[0]; + int64_t box_dim = boxes_dim[2]; + int64_t out_dim = box_dim + 2; + + int num_nmsed_out = 0; + FDASSERT(score_size == 3, "Require rank of input scores be 3, but now it's " + + std::to_string(score_size) + "."); + FDASSERT(boxes_dim[2] == 4, + "Require the 3-dimension of input boxes be 4, but now it's " + + std::to_string(boxes_dim[2]) + "."); + out_num_rois_data.resize(batch_size); + + std::vector>> all_indices; + for (size_t i = 0; i < batch_size; ++i) { + std::map> indices; // indices kept for each class + const float* current_boxes_ptr = + boxes_data + i * boxes_dim[1] * boxes_dim[2]; + const float* current_scores_ptr = + scores_data + i * scores_dim[1] * scores_dim[2]; + int num = NMSForEachSample(current_boxes_ptr, current_scores_ptr, + boxes_dim[1], scores_dim[1], &indices); + num_nmsed_out += num; + out_num_rois_data[i] = num; + all_indices.emplace_back(indices); + } + std::vector out_box_dims = {num_nmsed_out, 6}; + std::vector out_index_dims = {num_nmsed_out, 1}; + if (num_nmsed_out == 0) { + for (size_t i = 0; i < batch_size; ++i) { + out_num_rois_data[i] = 0; + } + return; + } + out_box_data.resize(num_nmsed_out * 6); + out_index_data.resize(num_nmsed_out); + + int count = 0; + for (size_t i = 0; i < batch_size; ++i) { + const float* current_boxes_ptr = + boxes_data + i * boxes_dim[1] * boxes_dim[2]; + const float* current_scores_ptr = + scores_data + i * scores_dim[1] * scores_dim[2]; + for (const auto& it : all_indices[i]) { + int label = it.first; + const auto& indices = it.second; + const float* current_scores_class_ptr = + current_scores_ptr + label * scores_dim[2]; + for (size_t j = 0; j < indices.size(); ++j) { + int start = count * 6; + out_box_data[start] = label; + out_box_data[start + 1] = current_scores_class_ptr[indices[j]]; + + out_box_data[start + 2] = current_boxes_ptr[indices[j] * 4]; + out_box_data[start + 3] = current_boxes_ptr[indices[j] * 4 + 1]; + out_box_data[start + 4] = current_boxes_ptr[indices[j] * 4 + 2]; + + out_box_data[start + 5] = current_boxes_ptr[indices[j] * 4 + 3]; + out_index_data[count] = i * boxes_dim[1] + indices[j]; + count += 1; + } + } + } +} +} // namespace backend +} // namespace fastdeploy diff --git a/fastdeploy/backends/common/multiclass_nms.h b/fastdeploy/backends/common/multiclass_nms.h new file mode 100644 index 000000000..48a3d9336 --- /dev/null +++ b/fastdeploy/backends/common/multiclass_nms.h @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +namespace fastdeploy { +namespace backend { +struct MultiClassNMS { + int64_t background_label = -1; + int64_t keep_top_k = -1; + float nms_eta; + float nms_threshold = 0.7; + int64_t nms_top_k; + bool normalized; + float score_threshold; + + std::vector out_num_rois_data; + std::vector out_index_data; + std::vector out_box_data; + void FastNMS(const float* boxes, const float* scores, const int& num_boxes, + std::vector* keep_indices); + int NMSForEachSample(const float* boxes, const float* scores, int num_boxes, + int num_classes, + std::map>* keep_indices); + void Compute(const float* boxes, const float* scores, + const std::vector& boxes_dim, + const std::vector& scores_dim); +}; +} // namespace backend + +} // namespace fastdeploy diff --git a/fastdeploy/backends/ort/ops/multiclass_nms.cc b/fastdeploy/backends/ort/ops/multiclass_nms.cc index 8c00dc7be..6f9f8f2a7 100644 --- a/fastdeploy/backends/ort/ops/multiclass_nms.cc +++ b/fastdeploy/backends/ort/ops/multiclass_nms.cc @@ -253,8 +253,5 @@ void MultiClassNmsKernel::GetAttribute(const OrtKernelInfo* info) { nms_top_k = ort_.KernelInfoGetAttribute(info, "nms_top_k"); normalized = ort_.KernelInfoGetAttribute(info, "normalized"); score_threshold = ort_.KernelInfoGetAttribute(info, "score_threshold"); - std::cout << background_label << " " << keep_top_k << " " << nms_eta << " " - << nms_threshold << " " << nms_top_k << " " << normalized << " " - << score_threshold << " " << std::endl; } } // namespace fastdeploy diff --git a/fastdeploy/backends/ort/ort_backend.cc b/fastdeploy/backends/ort/ort_backend.cc index f5d0bfdd9..27c746a9e 100644 --- a/fastdeploy/backends/ort/ort_backend.cc +++ b/fastdeploy/backends/ort/ort_backend.cc @@ -107,16 +107,26 @@ bool OrtBackend::InitFromPaddle(const std::string& model_file, #ifdef ENABLE_PADDLE_FRONTEND char* model_content_ptr; int model_content_size = 0; + + std::vector custom_ops; + for (auto& item : option.custom_op_info_) { + paddle2onnx::CustomOp op; + strcpy(op.op_name, item.first.c_str()); + strcpy(op.export_op_name, item.second.c_str()); + custom_ops.emplace_back(op); + } if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(), &model_content_ptr, &model_content_size, 11, true, - verbose, true, true, true)) { + verbose, true, true, true, custom_ops.data(), + custom_ops.size())) { FDERROR << "Error occured while export PaddlePaddle to ONNX format." << std::endl; return false; } + std::string onnx_model_proto(model_content_ptr, model_content_ptr + model_content_size); - delete model_content_ptr; + delete[] model_content_ptr; model_content_ptr = nullptr; return InitFromOnnx(onnx_model_proto, option, true); #else diff --git a/fastdeploy/backends/ort/ort_backend.h b/fastdeploy/backends/ort/ort_backend.h index 8556763e0..2dab03023 100644 --- a/fastdeploy/backends/ort/ort_backend.h +++ b/fastdeploy/backends/ort/ort_backend.h @@ -44,6 +44,10 @@ struct OrtBackendOption { int execution_mode = -1; bool use_gpu = false; int gpu_id = 0; + + // inside parameter, maybe remove next version + bool remove_multiclass_nms_ = false; + std::map custom_op_info_; }; class OrtBackend : public BaseBackend { diff --git a/fastdeploy/backends/tensorrt/trt_backend.cc b/fastdeploy/backends/tensorrt/trt_backend.cc index d050cc9f2..5fdf6ad0d 100644 --- a/fastdeploy/backends/tensorrt/trt_backend.cc +++ b/fastdeploy/backends/tensorrt/trt_backend.cc @@ -162,18 +162,41 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file, } #ifdef ENABLE_PADDLE_FRONTEND + std::vector custom_ops; + for (auto& item : option.custom_op_info_) { + paddle2onnx::CustomOp op; + std::strcpy(op.op_name, item.first.c_str()); + std::strcpy(op.export_op_name, item.second.c_str()); + custom_ops.emplace_back(op); + } char* model_content_ptr; int model_content_size = 0; if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(), &model_content_ptr, &model_content_size, 11, true, - verbose, true, true, true)) { + verbose, true, true, true, custom_ops.data(), + custom_ops.size())) { FDERROR << "Error occured while export PaddlePaddle to ONNX format." << std::endl; return false; } + + if (option.remove_multiclass_nms_) { + char* new_model = nullptr; + int new_model_size = 0; + if (!paddle2onnx::RemoveMultiClassNMS(model_content_ptr, model_content_size, + &new_model, &new_model_size)) { + FDERROR << "Try to remove MultiClassNMS failed." << std::endl; + return false; + } + delete[] model_content_ptr; + std::string onnx_model_proto(new_model, new_model + new_model_size); + delete[] new_model; + return InitFromOnnx(onnx_model_proto, option, true); + } + std::string onnx_model_proto(model_content_ptr, model_content_ptr + model_content_size); - delete model_content_ptr; + delete[] model_content_ptr; model_content_ptr = nullptr; return InitFromOnnx(onnx_model_proto, option, true); #else diff --git a/fastdeploy/backends/tensorrt/trt_backend.h b/fastdeploy/backends/tensorrt/trt_backend.h index 27e6e552b..b2555c576 100644 --- a/fastdeploy/backends/tensorrt/trt_backend.h +++ b/fastdeploy/backends/tensorrt/trt_backend.h @@ -50,6 +50,10 @@ struct TrtBackendOption { std::map> min_shape; std::map> opt_shape; std::string serialize_file = ""; + + // inside parameter, maybe remove next version + bool remove_multiclass_nms_ = false; + std::map custom_op_info_; }; std::vector toVec(const nvinfer1::Dims& dim); diff --git a/fastdeploy/fastdeploy_runtime.cc b/fastdeploy/fastdeploy_runtime.cc index 05af6e14e..66efd22e5 100644 --- a/fastdeploy/fastdeploy_runtime.cc +++ b/fastdeploy/fastdeploy_runtime.cc @@ -202,9 +202,6 @@ void RuntimeOption::SetTrtInputShape(const std::string& input_name, } else { trt_max_shape[input_name].assign(max_shape.begin(), max_shape.end()); } - FDINFO << trt_min_shape[input_name].size() << " " - << trt_opt_shape[input_name].size() << " " - << trt_max_shape[input_name].size() << std::endl; } void RuntimeOption::EnableTrtFP16() { trt_enable_fp16 = true; } @@ -295,6 +292,11 @@ void Runtime::CreateOrtBackend() { ort_option.execution_mode = option.ort_execution_mode; ort_option.use_gpu = (option.device == Device::GPU) ? true : false; ort_option.gpu_id = option.device_id; + + // TODO(jiangjiajun): inside usage, maybe remove this later + ort_option.remove_multiclass_nms_ = option.remove_multiclass_nms_; + ort_option.custom_op_info_ = option.custom_op_info_; + FDASSERT(option.model_format == Frontend::PADDLE || option.model_format == Frontend::ONNX, "OrtBackend only support model format of Frontend::PADDLE / " @@ -328,6 +330,11 @@ void Runtime::CreateTrtBackend() { trt_option.min_shape = option.trt_min_shape; trt_option.opt_shape = option.trt_opt_shape; trt_option.serialize_file = option.trt_serialize_file; + + // TODO(jiangjiajun): inside usage, maybe remove this later + trt_option.remove_multiclass_nms_ = option.remove_multiclass_nms_; + trt_option.custom_op_info_ = option.custom_op_info_; + FDASSERT(option.model_format == Frontend::PADDLE || option.model_format == Frontend::ONNX, "TrtBackend only support model format of Frontend::PADDLE / " diff --git a/fastdeploy/fastdeploy_runtime.h b/fastdeploy/fastdeploy_runtime.h index d0f01069f..cdd2b3e4b 100644 --- a/fastdeploy/fastdeploy_runtime.h +++ b/fastdeploy/fastdeploy_runtime.h @@ -124,6 +124,12 @@ struct FASTDEPLOY_DECL RuntimeOption { std::string model_file = ""; // Path of model file std::string params_file = ""; // Path of parameters file, can be empty Frontend model_format = Frontend::AUTOREC; // format of input model + + // inside parameters, only for inside usage + // remove multiclass_nms in Paddle2ONNX + bool remove_multiclass_nms_ = false; + // for Paddle2ONNX to export custom operators + std::map custom_op_info_; }; struct FASTDEPLOY_DECL Runtime { diff --git a/fastdeploy/utils/utils.cc b/fastdeploy/utils/utils.cc index dfe5326d1..3899bcf5e 100644 --- a/fastdeploy/utils/utils.cc +++ b/fastdeploy/utils/utils.cc @@ -31,4 +31,19 @@ FDLogger& FDLogger::operator<<(std::ostream& (*os)(std::ostream&)) { return *this; } +bool ReadBinaryFromFile(const std::string& file, std::string* contents) { + std::ifstream fin(file, std::ios::in | std::ios::binary); + if (!fin.is_open()) { + FDERROR << "Failed to open file: " << file << " to read." << std::endl; + return false; + } + fin.seekg(0, std::ios::end); + contents->clear(); + contents->resize(fin.tellg()); + fin.seekg(0, std::ios::beg); + fin.read(&(contents->at(0)), contents->size()); + fin.close(); + return true; +} + } // namespace fastdeploy diff --git a/fastdeploy/utils/utils.h b/fastdeploy/utils/utils.h index f427cd7a3..bde8e8d90 100644 --- a/fastdeploy/utils/utils.h +++ b/fastdeploy/utils/utils.h @@ -65,6 +65,9 @@ class FASTDEPLOY_DECL FDLogger { bool verbose_ = true; }; +FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file, + std::string* contents); + #ifndef __REL_FILE__ #define __REL_FILE__ __FILE__ #endif diff --git a/fastdeploy/vision/ppdet/__init__.py b/fastdeploy/vision/ppdet/__init__.py index 069c9baa6..93aad6405 100644 --- a/fastdeploy/vision/ppdet/__init__.py +++ b/fastdeploy/vision/ppdet/__init__.py @@ -33,7 +33,6 @@ class PPYOLOE(FastDeployModel): model_format) assert self.initialized, "PPYOLOE model initialize failed." - def predict(self, input_image, conf_threshold=0.5, nms_iou_threshold=0.7): + def predict(self, input_image): assert input_image is not None, "The input image data is None." - return self._model.predict(input_image, conf_threshold, - nms_iou_threshold) + return self._model.predict(input_image) diff --git a/fastdeploy/vision/ppdet/ppdet_pybind.cc b/fastdeploy/vision/ppdet/ppdet_pybind.cc index 1f85f1967..bd1fc4621 100644 --- a/fastdeploy/vision/ppdet/ppdet_pybind.cc +++ b/fastdeploy/vision/ppdet/ppdet_pybind.cc @@ -21,11 +21,10 @@ void BindPPDet(pybind11::module& m) { "PPYOLOE") .def(pybind11::init()) - .def("predict", [](vision::ppdet::PPYOLOE& self, pybind11::array& data, - float conf_threshold, float nms_iou_threshold) { + .def("predict", [](vision::ppdet::PPYOLOE& self, pybind11::array& data) { auto mat = PyArrayToCvMat(data); vision::DetectionResult res; - self.Predict(&mat, &res, conf_threshold, nms_iou_threshold); + self.Predict(&mat, &res); return res; }); } diff --git a/fastdeploy/vision/ppdet/ppyoloe.cc b/fastdeploy/vision/ppdet/ppyoloe.cc index ed8d1e46f..08c4073f2 100644 --- a/fastdeploy/vision/ppdet/ppyoloe.cc +++ b/fastdeploy/vision/ppdet/ppyoloe.cc @@ -1,6 +1,9 @@ #include "fastdeploy/vision/ppdet/ppyoloe.h" #include "fastdeploy/vision/utils/utils.h" #include "yaml-cpp/yaml.h" +#ifdef ENABLE_PADDLE_FRONTEND +#include "paddle2onnx/converter.h" +#endif namespace fastdeploy { namespace vision { @@ -21,15 +24,38 @@ PPYOLOE::PPYOLOE(const std::string& model_file, const std::string& params_file, } bool PPYOLOE::Initialize() { +#ifdef ENABLE_PADDLE_FRONTEND + // remove multiclass_nms3 now + // this is a trick operation for ppyoloe while inference on trt + if (runtime_option.model_format == Frontend::PADDLE) { + std::string contents; + if (!ReadBinaryFromFile(runtime_option.model_file, &contents)) { + return false; + } + auto reader = paddle2onnx::PaddleReader(contents.c_str(), contents.size()); + if (reader.has_nms) { + has_nms_ = true; + } + } + runtime_option.remove_multiclass_nms_ = true; + runtime_option.custom_op_info_["multiclass_nms3"] = "MultiClassNMS"; +#endif if (!BuildPreprocessPipelineFromConfig()) { FDERROR << "Failed to build preprocess pipeline from configuration file." - << std::endl; + << std::endl; return false; } if (!InitRuntime()) { FDERROR << "Failed to initialize fastdeploy backend." << std::endl; return false; } + + if (has_nms_ && runtime_option.backend == Backend::TRT) { + FDINFO << "Detected operator multiclass_nms3 in your model, will replace " + "it with fastdeploy::backend::MultiClassNMS replace it." + << std::endl; + has_nms_ = false; + } return true; } @@ -40,14 +66,14 @@ bool PPYOLOE::BuildPreprocessPipelineFromConfig() { cfg = YAML::LoadFile(config_file_); } catch (YAML::BadFile& e) { FDERROR << "Failed to load yaml file " << config_file_ - << ", maybe you should check this file." << std::endl; + << ", maybe you should check this file." << std::endl; return false; } if (cfg["arch"].as() != "YOLO") { FDERROR << "Require the arch of model is YOLO, but arch defined in " - "config file is " - << cfg["arch"].as() << "." << std::endl; + "config file is " + << cfg["arch"].as() << "." << std::endl; return false; } processors_.push_back(std::make_shared()); @@ -77,7 +103,7 @@ bool PPYOLOE::BuildPreprocessPipelineFromConfig() { processors_.push_back(std::make_shared()); } else { FDERROR << "Unexcepted preprocess operator: " << op_name << "." - << std::endl; + << std::endl; return false; } } @@ -90,7 +116,7 @@ bool PPYOLOE::Preprocess(Mat* mat, std::vector* outputs) { for (size_t i = 0; i < processors_.size(); ++i) { if (!(*(processors_[i].get()))(mat)) { FDERROR << "Failed to process image data in " << processors_[i]->Name() - << "." << std::endl; + << "." << std::endl; return false; } } @@ -110,32 +136,70 @@ bool PPYOLOE::Preprocess(Mat* mat, std::vector* outputs) { } bool PPYOLOE::Postprocess(std::vector& infer_result, - DetectionResult* result, float conf_threshold, - float nms_threshold) { + DetectionResult* result) { FDASSERT(infer_result[1].shape[0] == 1, "Only support batch = 1 in FastDeploy now."); - int box_num = 0; - if (infer_result[1].dtype == FDDataType::INT32) { - box_num = *(static_cast(infer_result[1].Data())); - } else if (infer_result[1].dtype == FDDataType::INT64) { - box_num = *(static_cast(infer_result[1].Data())); - } else { - FDASSERT( - false, - "The output box_num of PPYOLOE model should be type of int32/int64."); - } - result->Reserve(box_num); - float* box_data = static_cast(infer_result[0].Data()); - for (size_t i = 0; i < box_num; ++i) { - if (box_data[i * 6 + 1] < conf_threshold) { - continue; + + if (!has_nms_) { + int boxes_index = 0; + int scores_index = 1; + if (infer_result[0].shape[1] == infer_result[1].shape[2]) { + boxes_index = 0; + scores_index = 1; + } else if (infer_result[0].shape[2] == infer_result[1].shape[1]) { + boxes_index = 1; + scores_index = 0; + } else { + FDERROR << "The shape of boxes and scores should be [batch, boxes_num, " + "4], [batch, classes_num, boxes_num]" + << std::endl; + return false; + } + + backend::MultiClassNMS nms; + nms.background_label = background_label; + nms.keep_top_k = keep_top_k; + nms.nms_eta = nms_eta; + nms.nms_threshold = nms_threshold; + nms.score_threshold = score_threshold; + nms.nms_top_k = nms_top_k; + nms.normalized = normalized; + nms.Compute(static_cast(infer_result[boxes_index].Data()), + static_cast(infer_result[scores_index].Data()), + infer_result[boxes_index].shape, + infer_result[scores_index].shape); + if (nms.out_num_rois_data[0] > 0) { + result->Reserve(nms.out_num_rois_data[0]); + } + for (size_t i = 0; i < nms.out_num_rois_data[0]; ++i) { + result->label_ids.push_back(nms.out_box_data[i * 6]); + result->scores.push_back(nms.out_box_data[i * 6 + 1]); + result->boxes.emplace_back(std::array{ + nms.out_box_data[i * 6 + 2], nms.out_box_data[i * 6 + 3], + nms.out_box_data[i * 6 + 4] - nms.out_box_data[i * 6 + 2], + nms.out_box_data[i * 6 + 5] - nms.out_box_data[i * 6 + 3]}); + } + } else { + int box_num = 0; + if (infer_result[1].dtype == FDDataType::INT32) { + box_num = *(static_cast(infer_result[1].Data())); + } else if (infer_result[1].dtype == FDDataType::INT64) { + box_num = *(static_cast(infer_result[1].Data())); + } else { + FDASSERT( + false, + "The output box_num of PPYOLOE model should be type of int32/int64."); + } + result->Reserve(box_num); + float* box_data = static_cast(infer_result[0].Data()); + for (size_t i = 0; i < box_num; ++i) { + result->label_ids.push_back(box_data[i * 6]); + result->scores.push_back(box_data[i * 6 + 1]); + result->boxes.emplace_back( + std::array{box_data[i * 6 + 2], box_data[i * 6 + 3], + box_data[i * 6 + 4] - box_data[i * 6 + 2], + box_data[i * 6 + 5] - box_data[i * 6 + 3]}); } - result->label_ids.push_back(box_data[i * 6]); - result->scores.push_back(box_data[i * 6 + 1]); - result->boxes.emplace_back( - std::array{box_data[i * 6 + 2], box_data[i * 6 + 3], - box_data[i * 6 + 4] - box_data[i * 6 + 2], - box_data[i * 6 + 5] - box_data[i * 6 + 3]}); } return true; } @@ -157,7 +221,7 @@ bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result, return false; } - if (!Postprocess(infer_result, result, conf_threshold, iou_threshold)) { + if (!Postprocess(infer_result, result)) { FDERROR << "Failed to postprocess while using model:" << ModelName() << "." << std::endl; return false; diff --git a/fastdeploy/vision/ppdet/ppyoloe.h b/fastdeploy/vision/ppdet/ppyoloe.h index a3db268ca..ec22aa2ce 100644 --- a/fastdeploy/vision/ppdet/ppyoloe.h +++ b/fastdeploy/vision/ppdet/ppyoloe.h @@ -25,8 +25,7 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel { virtual bool Preprocess(Mat* mat, std::vector* outputs); virtual bool Postprocess(std::vector& infer_result, - DetectionResult* result, float conf_threshold, - float nms_threshold); + DetectionResult* result); virtual bool Predict(cv::Mat* im, DetectionResult* result, float conf_threshold = 0.5, float nms_threshold = 0.7); @@ -34,10 +33,15 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel { private: std::vector> processors_; std::string config_file_; - // PaddleDetection can export model without nms - // This flag will help us to handle the different - // situation - bool has_nms_; + // configuration for nms + int64_t background_label = -1; + int64_t keep_top_k = 300; + float nms_eta = 1.0; + float nms_threshold = 0.7; + float score_threshold = 0.01; + int64_t nms_top_k = 10000; + bool normalized = true; + bool has_nms_ = false; }; } // namespace ppdet } // namespace vision