Support removing multiclass_nms to enable PP-YOLOE on TensorRT (#40)

* Add custom operator for onnxruntime and fix paddle backend

* Polish cmake files and runtime apis

* Remove copy libraries

* Fix some issues

* Fix bugs

* Support removing multiclass_nms to enable PaddleDetection models to run on TensorRT

* Add common operator MultiClassNMS

* Fix compile problem

Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com>
Author: Jason
Date: 2022-07-26 11:16:01 +08:00
Committed by: GitHub
Parent: fc71d79e58
Commit: 17e4dc6b5e
18 changed files with 460 additions and 56 deletions

View File

@@ -27,7 +27,7 @@ set(ONNXRUNTIME_LIB_DIR
CACHE PATH "onnxruntime lib directory." FORCE)
set(CMAKE_BUILD_RPATH "${CMAKE_BUILD_RPATH}" "${ONNXRUNTIME_LIB_DIR}")
set(ONNXRUNTIME_VERSION "1.11.1")
set(ONNXRUNTIME_VERSION "1.12.0")
set(ONNXRUNTIME_URL_PREFIX "https://bj.bcebos.com/paddle2onnx/libs/")
if(WIN32)

View File

@@ -43,7 +43,7 @@ else()
endif(WIN32)
set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/paddle2onnx/libs/")
set(PADDLE2ONNX_VERSION "1.0.0rc1")
set(PADDLE2ONNX_VERSION "1.0.0rc2")
if(WIN32)
set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip")
elseif(APPLE)

View File

@@ -18,7 +18,7 @@
#include <memory>
#include <string>
#include <vector>
#include "fastdeploy/backends/common/multiclass_nms.h"
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
@@ -45,4 +45,4 @@ class BaseBackend {
std::vector<FDTensor>* outputs) = 0;
};
}  // namespace fastdeploy

View File

@@ -0,0 +1,224 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/backends/common/multiclass_nms.h"
#include <algorithm>
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/utils/utils.h"
namespace fastdeploy {
namespace backend {
template <class T>
bool SortScorePairDescend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
return pair1.first > pair2.first;
}
void GetMaxScoreIndex(const float* scores, const int& score_size,
const float& threshold, const int& top_k,
std::vector<std::pair<float, int>>* sorted_indices) {
for (size_t i = 0; i < score_size; ++i) {
if (scores[i] > threshold) {
sorted_indices->push_back(std::make_pair(scores[i], i));
}
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
SortScorePairDescend<int>);
// Keep top_k scores if needed.
if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
sorted_indices->resize(top_k);
}
}
float BBoxArea(const float* box, const bool& normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return 0.f;
} else {
const float w = box[2] - box[0];
const float h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
float JaccardOverlap(const float* box1, const float* box2,
const bool& normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return 0.f;
} else {
const float inter_xmin = std::max(box1[0], box2[0]);
const float inter_ymin = std::max(box1[1], box2[1]);
const float inter_xmax = std::min(box1[2], box2[2]);
const float inter_ymax = std::min(box1[3], box2[3]);
float norm = normalized ? 0.0f : 1.0f;
float inter_w = inter_xmax - inter_xmin + norm;
float inter_h = inter_ymax - inter_ymin + norm;
const float inter_area = inter_w * inter_h;
const float bbox1_area = BBoxArea(box1, normalized);
const float bbox2_area = BBoxArea(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
void MultiClassNMS::FastNMS(const float* boxes, const float* scores,
const int& num_boxes,
std::vector<int>* keep_indices) {
std::vector<std::pair<float, int>> sorted_indices;
GetMaxScoreIndex(scores, num_boxes, score_threshold, nms_top_k,
&sorted_indices);
float adaptive_threshold = nms_threshold;
while (sorted_indices.size() != 0) {
const int idx = sorted_indices.front().second;
bool keep = true;
for (size_t k = 0; k < keep_indices->size(); ++k) {
if (!keep) {
break;
}
const int kept_idx = (*keep_indices)[k];
float overlap =
JaccardOverlap(boxes + idx * 4, boxes + kept_idx * 4, normalized);
keep = overlap <= adaptive_threshold;
}
if (keep) {
keep_indices->push_back(idx);
}
sorted_indices.erase(sorted_indices.begin());
if (keep && nms_eta < 1.0 && adaptive_threshold > 0.5) {
adaptive_threshold *= nms_eta;
}
}
}
int MultiClassNMS::NMSForEachSample(
const float* boxes, const float* scores, int num_boxes, int num_classes,
std::map<int, std::vector<int>>* keep_indices) {
for (int i = 0; i < num_classes; ++i) {
if (i == background_label) {
continue;
}
const float* score_for_class_i = scores + i * num_boxes;
FastNMS(boxes, score_for_class_i, num_boxes, &((*keep_indices)[i]));
}
int num_det = 0;
for (auto iter = keep_indices->begin(); iter != keep_indices->end(); ++iter) {
num_det += iter->second.size();
}
if (keep_top_k > -1 && num_det > keep_top_k) {
std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
for (const auto& it : *keep_indices) {
int label = it.first;
const float* current_score = scores + label * num_boxes;
auto& label_indices = it.second;
for (size_t j = 0; j < label_indices.size(); ++j) {
int idx = label_indices[j];
score_index_pairs.push_back(
std::make_pair(current_score[idx], std::make_pair(label, idx)));
}
}
std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(),
SortScorePairDescend<std::pair<int, int>>);
score_index_pairs.resize(keep_top_k);
std::map<int, std::vector<int>> new_indices;
for (size_t j = 0; j < score_index_pairs.size(); ++j) {
int label = score_index_pairs[j].second.first;
int idx = score_index_pairs[j].second.second;
new_indices[label].push_back(idx);
}
new_indices.swap(*keep_indices);
num_det = keep_top_k;
}
return num_det;
}
void MultiClassNMS::Compute(const float* boxes_data, const float* scores_data,
const std::vector<int64_t>& boxes_dim,
const std::vector<int64_t>& scores_dim) {
int score_size = scores_dim.size();
int64_t batch_size = scores_dim[0];
int64_t box_dim = boxes_dim[2];
int64_t out_dim = box_dim + 2;
int num_nmsed_out = 0;
FDASSERT(score_size == 3, "Require rank of input scores to be 3, but now it's " +
std::to_string(score_size) + ".");
FDASSERT(boxes_dim[2] == 4,
"Require the 3rd dimension of input boxes to be 4, but now it's " +
std::to_string(boxes_dim[2]) + ".");
out_num_rois_data.resize(batch_size);
std::vector<std::map<int, std::vector<int>>> all_indices;
for (size_t i = 0; i < batch_size; ++i) {
std::map<int, std::vector<int>> indices; // indices kept for each class
const float* current_boxes_ptr =
boxes_data + i * boxes_dim[1] * boxes_dim[2];
const float* current_scores_ptr =
scores_data + i * scores_dim[1] * scores_dim[2];
int num = NMSForEachSample(current_boxes_ptr, current_scores_ptr,
boxes_dim[1], scores_dim[1], &indices);
num_nmsed_out += num;
out_num_rois_data[i] = num;
all_indices.emplace_back(indices);
}
std::vector<int64_t> out_box_dims = {num_nmsed_out, 6};
std::vector<int64_t> out_index_dims = {num_nmsed_out, 1};
if (num_nmsed_out == 0) {
for (size_t i = 0; i < batch_size; ++i) {
out_num_rois_data[i] = 0;
}
return;
}
out_box_data.resize(num_nmsed_out * 6);
out_index_data.resize(num_nmsed_out);
int count = 0;
for (size_t i = 0; i < batch_size; ++i) {
const float* current_boxes_ptr =
boxes_data + i * boxes_dim[1] * boxes_dim[2];
const float* current_scores_ptr =
scores_data + i * scores_dim[1] * scores_dim[2];
for (const auto& it : all_indices[i]) {
int label = it.first;
const auto& indices = it.second;
const float* current_scores_class_ptr =
current_scores_ptr + label * scores_dim[2];
for (size_t j = 0; j < indices.size(); ++j) {
int start = count * 6;
out_box_data[start] = label;
out_box_data[start + 1] = current_scores_class_ptr[indices[j]];
out_box_data[start + 2] = current_boxes_ptr[indices[j] * 4];
out_box_data[start + 3] = current_boxes_ptr[indices[j] * 4 + 1];
out_box_data[start + 4] = current_boxes_ptr[indices[j] * 4 + 2];
out_box_data[start + 5] = current_boxes_ptr[indices[j] * 4 + 3];
out_index_data[count] = i * boxes_dim[1] + indices[j];
count += 1;
}
}
}
}
} // namespace backend
} // namespace fastdeploy
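
For reference, a minimal self-contained sketch (not part of the diff) of how this new helper is driven. Shapes follow the asserts in Compute(): boxes are [batch, num_boxes, 4] and scores are [batch, num_classes, num_boxes]; the box and score values are made up, and the parameter values mirror the PPYOLOE defaults added later in this commit.

#include <iostream>
#include <vector>
#include "fastdeploy/backends/common/multiclass_nms.h"

int main() {
  // Batch of 1: two heavily overlapping candidate boxes, one class.
  std::vector<float> boxes = {10, 10, 50, 50, 12, 12, 52, 52};
  std::vector<float> scores = {0.9f, 0.8f};
  std::vector<int64_t> boxes_dim = {1, 2, 4};
  std::vector<int64_t> scores_dim = {1, 1, 2};

  fastdeploy::backend::MultiClassNMS nms;
  nms.background_label = -1;
  nms.keep_top_k = 300;
  nms.nms_eta = 1.0f;  // 1.0 disables adaptive threshold decay
  nms.nms_threshold = 0.7f;
  nms.nms_top_k = 10000;
  nms.normalized = true;
  nms.score_threshold = 0.01f;
  nms.Compute(boxes.data(), scores.data(), boxes_dim, scores_dim);

  // Each kept detection occupies 6 floats:
  // [label, score, xmin, ymin, xmax, ymax].
  for (int i = 0; i < nms.out_num_rois_data[0]; ++i) {
    const float* d = nms.out_box_data.data() + i * 6;
    std::cout << "label=" << d[0] << " score=" << d[1] << " box=(" << d[2]
              << ", " << d[3] << ", " << d[4] << ", " << d[5] << ")"
              << std::endl;
  }
  return 0;
}

With an IoU of roughly 0.82 between the two boxes, the second is suppressed at nms_threshold 0.7, so a single detection is printed.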

View File

@@ -0,0 +1,45 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
namespace fastdeploy {
namespace backend {
struct MultiClassNMS {
int64_t background_label = -1;
int64_t keep_top_k = -1;
float nms_eta;
float nms_threshold = 0.7;
int64_t nms_top_k;
bool normalized;
float score_threshold;
std::vector<int32_t> out_num_rois_data;
std::vector<int32_t> out_index_data;
std::vector<float> out_box_data;
void FastNMS(const float* boxes, const float* scores, const int& num_boxes,
std::vector<int>* keep_indices);
int NMSForEachSample(const float* boxes, const float* scores, int num_boxes,
int num_classes,
std::map<int, std::vector<int>>* keep_indices);
void Compute(const float* boxes, const float* scores,
const std::vector<int64_t>& boxes_dim,
const std::vector<int64_t>& scores_dim);
};
} // namespace backend
} // namespace fastdeploy

View File

@@ -253,8 +253,5 @@ void MultiClassNmsKernel::GetAttribute(const OrtKernelInfo* info) {
nms_top_k = ort_.KernelInfoGetAttribute<int64_t>(info, "nms_top_k");
normalized = ort_.KernelInfoGetAttribute<int64_t>(info, "normalized");
score_threshold = ort_.KernelInfoGetAttribute<float>(info, "score_threshold");
-  std::cout << background_label << " " << keep_top_k << " " << nms_eta << " "
-            << nms_threshold << " " << nms_top_k << " " << normalized << " "
-            << score_threshold << " " << std::endl;
}
} // namespace fastdeploy

View File

@@ -107,16 +107,26 @@ bool OrtBackend::InitFromPaddle(const std::string& model_file,
#ifdef ENABLE_PADDLE_FRONTEND
char* model_content_ptr;
int model_content_size = 0;
+  std::vector<paddle2onnx::CustomOp> custom_ops;
+  for (auto& item : option.custom_op_info_) {
+    paddle2onnx::CustomOp op;
+    strcpy(op.op_name, item.first.c_str());
+    strcpy(op.export_op_name, item.second.c_str());
+    custom_ops.emplace_back(op);
+  }
   if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(),
                            &model_content_ptr, &model_content_size, 11, true,
-                           verbose, true, true, true)) {
+                           verbose, true, true, true, custom_ops.data(),
+                           custom_ops.size())) {
FDERROR << "Error occured while export PaddlePaddle to ONNX format."
<< std::endl;
return false;
}
std::string onnx_model_proto(model_content_ptr,
model_content_ptr + model_content_size);
-  delete model_content_ptr;
+  delete[] model_content_ptr;
+  model_content_ptr = nullptr;
return InitFromOnnx(onnx_model_proto, option, true);
#else

View File

@@ -44,6 +44,10 @@ struct OrtBackendOption {
int execution_mode = -1;
bool use_gpu = false;
int gpu_id = 0;
+  // inside parameter, maybe remove next version
+  bool remove_multiclass_nms_ = false;
+  std::map<std::string, std::string> custom_op_info_;
};
class OrtBackend : public BaseBackend {

View File

@@ -162,18 +162,41 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
}
#ifdef ENABLE_PADDLE_FRONTEND
+  std::vector<paddle2onnx::CustomOp> custom_ops;
+  for (auto& item : option.custom_op_info_) {
+    paddle2onnx::CustomOp op;
+    std::strcpy(op.op_name, item.first.c_str());
+    std::strcpy(op.export_op_name, item.second.c_str());
+    custom_ops.emplace_back(op);
+  }
char* model_content_ptr;
int model_content_size = 0;
if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(),
&model_content_ptr, &model_content_size, 11, true,
-                           verbose, true, true, true)) {
+                           verbose, true, true, true, custom_ops.data(),
+                           custom_ops.size())) {
FDERROR << "Error occured while export PaddlePaddle to ONNX format."
<< std::endl;
return false;
}
+  if (option.remove_multiclass_nms_) {
+    char* new_model = nullptr;
+    int new_model_size = 0;
+    if (!paddle2onnx::RemoveMultiClassNMS(model_content_ptr, model_content_size,
+                                          &new_model, &new_model_size)) {
+      FDERROR << "Trying to remove MultiClassNMS failed." << std::endl;
+      return false;
+    }
+    delete[] model_content_ptr;
+    std::string onnx_model_proto(new_model, new_model + new_model_size);
+    delete[] new_model;
+    return InitFromOnnx(onnx_model_proto, option, true);
+  }
std::string onnx_model_proto(model_content_ptr,
model_content_ptr + model_content_size);
-  delete model_content_ptr;
+  delete[] model_content_ptr;
+  model_content_ptr = nullptr;
return InitFromOnnx(onnx_model_proto, option, true);
#else

View File

@@ -50,6 +50,10 @@ struct TrtBackendOption {
std::map<std::string, std::vector<int32_t>> min_shape;
std::map<std::string, std::vector<int32_t>> opt_shape;
std::string serialize_file = "";
+  // inside parameter, maybe remove next version
+  bool remove_multiclass_nms_ = false;
+  std::map<std::string, std::string> custom_op_info_;
};
std::vector<int> toVec(const nvinfer1::Dims& dim);

View File

@@ -202,9 +202,6 @@ void RuntimeOption::SetTrtInputShape(const std::string& input_name,
} else {
trt_max_shape[input_name].assign(max_shape.begin(), max_shape.end());
}
-  FDINFO << trt_min_shape[input_name].size() << " "
-         << trt_opt_shape[input_name].size() << " "
-         << trt_max_shape[input_name].size() << std::endl;
}
void RuntimeOption::EnableTrtFP16() { trt_enable_fp16 = true; }
@@ -295,6 +292,11 @@ void Runtime::CreateOrtBackend() {
ort_option.execution_mode = option.ort_execution_mode;
ort_option.use_gpu = (option.device == Device::GPU) ? true : false;
ort_option.gpu_id = option.device_id;
+  // TODO(jiangjiajun): inside usage, maybe remove this later
+  ort_option.remove_multiclass_nms_ = option.remove_multiclass_nms_;
+  ort_option.custom_op_info_ = option.custom_op_info_;
FDASSERT(option.model_format == Frontend::PADDLE ||
option.model_format == Frontend::ONNX,
"OrtBackend only support model format of Frontend::PADDLE / "
@@ -328,6 +330,11 @@ void Runtime::CreateTrtBackend() {
trt_option.min_shape = option.trt_min_shape;
trt_option.opt_shape = option.trt_opt_shape;
trt_option.serialize_file = option.trt_serialize_file;
+  // TODO(jiangjiajun): inside usage, maybe remove this later
+  trt_option.remove_multiclass_nms_ = option.remove_multiclass_nms_;
+  trt_option.custom_op_info_ = option.custom_op_info_;
FDASSERT(option.model_format == Frontend::PADDLE ||
option.model_format == Frontend::ONNX,
"TrtBackend only support model format of Frontend::PADDLE / "

View File

@@ -124,6 +124,12 @@ struct FASTDEPLOY_DECL RuntimeOption {
std::string model_file = ""; // Path of model file
std::string params_file = ""; // Path of parameters file, can be empty
Frontend model_format = Frontend::AUTOREC; // format of input model
+  // inside parameters, only for inside usage
+  // remove multiclass_nms in Paddle2ONNX
+  bool remove_multiclass_nms_ = false;
+  // for Paddle2ONNX to export custom operators
+  std::map<std::string, std::string> custom_op_info_;
};
struct FASTDEPLOY_DECL Runtime {
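
A sketch of how these inside parameters are meant to be used: PPYOLOE::Initialize() later in this commit sets them exactly like this before calling InitRuntime(), and user code normally never touches them. The model paths are placeholders and the header path is an assumption.

#include "fastdeploy/fastdeploy_runtime.h"  // header path assumed

int main() {
  fastdeploy::RuntimeOption option;
  option.model_file = "ppyoloe/model.pdmodel";
  option.params_file = "ppyoloe/model.pdiparams";
  option.model_format = fastdeploy::Frontend::PADDLE;
  // Ask Paddle2ONNX to strip multiclass_nms3 and export it as the
  // custom ONNX operator "MultiClassNMS" instead.
  option.remove_multiclass_nms_ = true;
  option.custom_op_info_["multiclass_nms3"] = "MultiClassNMS";
  // Runtime::CreateOrtBackend() / CreateTrtBackend() copy both fields into
  // the backend options, as the runtime diff above shows.
  return 0;
}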

View File

@@ -31,4 +31,19 @@ FDLogger& FDLogger::operator<<(std::ostream& (*os)(std::ostream&)) {
return *this;
}
+bool ReadBinaryFromFile(const std::string& file, std::string* contents) {
+  std::ifstream fin(file, std::ios::in | std::ios::binary);
+  if (!fin.is_open()) {
+    FDERROR << "Failed to open file: " << file << " to read." << std::endl;
+    return false;
+  }
+  fin.seekg(0, std::ios::end);
+  contents->clear();
+  contents->resize(fin.tellg());
+  fin.seekg(0, std::ios::beg);
+  fin.read(&(contents->at(0)), contents->size());
+  fin.close();
+  return true;
+}
} // namespace fastdeploy
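
A minimal usage sketch of the new helper (the file name is illustrative); PPYOLOE::Initialize() below uses it the same way to feed the model file to paddle2onnx::PaddleReader.

#include <iostream>
#include <string>
#include "fastdeploy/utils/utils.h"

int main() {
  std::string contents;
  if (!fastdeploy::ReadBinaryFromFile("model.pdmodel", &contents)) {
    return -1;  // the helper already logged the failure via FDERROR
  }
  std::cout << "Read " << contents.size() << " bytes." << std::endl;
  return 0;
}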

View File

@@ -65,6 +65,9 @@ class FASTDEPLOY_DECL FDLogger {
bool verbose_ = true;
};
+FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file,
+                                        std::string* contents);
#ifndef __REL_FILE__
#define __REL_FILE__ __FILE__
#endif

View File

@@ -33,7 +33,6 @@ class PPYOLOE(FastDeployModel):
model_format)
assert self.initialized, "PPYOLOE model initialize failed."
-    def predict(self, input_image, conf_threshold=0.5, nms_iou_threshold=0.7):
+    def predict(self, input_image):
         assert input_image is not None, "The input image data is None."
-        return self._model.predict(input_image, conf_threshold,
-                                   nms_iou_threshold)
+        return self._model.predict(input_image)

View File

@@ -21,11 +21,10 @@ void BindPPDet(pybind11::module& m) {
"PPYOLOE")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
Frontend>())
.def("predict", [](vision::ppdet::PPYOLOE& self, pybind11::array& data,
float conf_threshold, float nms_iou_threshold) {
.def("predict", [](vision::ppdet::PPYOLOE& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
self.Predict(&mat, &res);
return res;
});
}

View File

@@ -1,6 +1,9 @@
#include "fastdeploy/vision/ppdet/ppyoloe.h"
#include "fastdeploy/vision/utils/utils.h"
#include "yaml-cpp/yaml.h"
+#ifdef ENABLE_PADDLE_FRONTEND
+#include "paddle2onnx/converter.h"
+#endif
namespace fastdeploy {
namespace vision {
@@ -21,15 +24,38 @@ PPYOLOE::PPYOLOE(const std::string& model_file, const std::string& params_file,
}
bool PPYOLOE::Initialize() {
+#ifdef ENABLE_PADDLE_FRONTEND
+  // Remove multiclass_nms3 for now;
+  // this is a trick to let PP-YOLOE run on TensorRT.
+  if (runtime_option.model_format == Frontend::PADDLE) {
+    std::string contents;
+    if (!ReadBinaryFromFile(runtime_option.model_file, &contents)) {
+      return false;
+    }
+    auto reader = paddle2onnx::PaddleReader(contents.c_str(), contents.size());
+    if (reader.has_nms) {
+      has_nms_ = true;
+    }
+  }
+  runtime_option.remove_multiclass_nms_ = true;
+  runtime_option.custom_op_info_["multiclass_nms3"] = "MultiClassNMS";
+#endif
if (!BuildPreprocessPipelineFromConfig()) {
FDERROR << "Failed to build preprocess pipeline from configuration file."
            << std::endl;
return false;
}
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
+  if (has_nms_ && runtime_option.backend == Backend::TRT) {
+    FDINFO << "Detected operator multiclass_nms3 in your model, will replace "
+              "it with fastdeploy::backend::MultiClassNMS."
+           << std::endl;
+    has_nms_ = false;
+  }
return true;
}
@@ -40,14 +66,14 @@ bool PPYOLOE::BuildPreprocessPipelineFromConfig() {
cfg = YAML::LoadFile(config_file_);
} catch (YAML::BadFile& e) {
FDERROR << "Failed to load yaml file " << config_file_
<< ", maybe you should check this file." << std::endl;
<< ", maybe you should check this file." << std::endl;
return false;
}
if (cfg["arch"].as<std::string>() != "YOLO") {
FDERROR << "Require the arch of model is YOLO, but arch defined in "
"config file is "
<< cfg["arch"].as<std::string>() << "." << std::endl;
"config file is "
<< cfg["arch"].as<std::string>() << "." << std::endl;
return false;
}
processors_.push_back(std::make_shared<BGR2RGB>());
@@ -77,7 +103,7 @@ bool PPYOLOE::BuildPreprocessPipelineFromConfig() {
processors_.push_back(std::make_shared<HWC2CHW>());
} else {
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
<< std::endl;
<< std::endl;
return false;
}
}
@@ -90,7 +116,7 @@ bool PPYOLOE::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
for (size_t i = 0; i < processors_.size(); ++i) {
if (!(*(processors_[i].get()))(mat)) {
FDERROR << "Failed to process image data in " << processors_[i]->Name()
<< "." << std::endl;
<< "." << std::endl;
return false;
}
}
@@ -110,32 +136,70 @@ bool PPYOLOE::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
}
bool PPYOLOE::Postprocess(std::vector<FDTensor>& infer_result,
-                          DetectionResult* result, float conf_threshold,
-                          float nms_threshold) {
+                          DetectionResult* result) {
FDASSERT(infer_result[1].shape[0] == 1,
"Only support batch = 1 in FastDeploy now.");
-  int box_num = 0;
-  if (infer_result[1].dtype == FDDataType::INT32) {
-    box_num = *(static_cast<int32_t*>(infer_result[1].Data()));
-  } else if (infer_result[1].dtype == FDDataType::INT64) {
-    box_num = *(static_cast<int64_t*>(infer_result[1].Data()));
-  } else {
-    FDASSERT(
-        false,
-        "The output box_num of PPYOLOE model should be type of int32/int64.");
-  }
-  result->Reserve(box_num);
-  float* box_data = static_cast<float*>(infer_result[0].Data());
-  for (size_t i = 0; i < box_num; ++i) {
-    if (box_data[i * 6 + 1] < conf_threshold) {
-      continue;
+  if (!has_nms_) {
+    int boxes_index = 0;
+    int scores_index = 1;
+    if (infer_result[0].shape[1] == infer_result[1].shape[2]) {
+      boxes_index = 0;
+      scores_index = 1;
+    } else if (infer_result[0].shape[2] == infer_result[1].shape[1]) {
+      boxes_index = 1;
+      scores_index = 0;
+    } else {
+      FDERROR << "The shape of boxes and scores should be [batch, boxes_num, "
+                 "4], [batch, classes_num, boxes_num]"
+              << std::endl;
+      return false;
+    }
+    backend::MultiClassNMS nms;
+    nms.background_label = background_label;
+    nms.keep_top_k = keep_top_k;
+    nms.nms_eta = nms_eta;
+    nms.nms_threshold = nms_threshold;
+    nms.score_threshold = score_threshold;
+    nms.nms_top_k = nms_top_k;
+    nms.normalized = normalized;
+    nms.Compute(static_cast<float*>(infer_result[boxes_index].Data()),
+                static_cast<float*>(infer_result[scores_index].Data()),
+                infer_result[boxes_index].shape,
+                infer_result[scores_index].shape);
+    if (nms.out_num_rois_data[0] > 0) {
+      result->Reserve(nms.out_num_rois_data[0]);
+    }
+    for (size_t i = 0; i < nms.out_num_rois_data[0]; ++i) {
+      result->label_ids.push_back(nms.out_box_data[i * 6]);
+      result->scores.push_back(nms.out_box_data[i * 6 + 1]);
+      result->boxes.emplace_back(std::array<float, 4>{
+          nms.out_box_data[i * 6 + 2], nms.out_box_data[i * 6 + 3],
+          nms.out_box_data[i * 6 + 4] - nms.out_box_data[i * 6 + 2],
+          nms.out_box_data[i * 6 + 5] - nms.out_box_data[i * 6 + 3]});
+    }
+  } else {
+    int box_num = 0;
+    if (infer_result[1].dtype == FDDataType::INT32) {
+      box_num = *(static_cast<int32_t*>(infer_result[1].Data()));
+    } else if (infer_result[1].dtype == FDDataType::INT64) {
+      box_num = *(static_cast<int64_t*>(infer_result[1].Data()));
+    } else {
+      FDASSERT(
+          false,
+          "The output box_num of PPYOLOE model should be type of int32/int64.");
+    }
+    result->Reserve(box_num);
+    float* box_data = static_cast<float*>(infer_result[0].Data());
+    for (size_t i = 0; i < box_num; ++i) {
+      result->label_ids.push_back(box_data[i * 6]);
+      result->scores.push_back(box_data[i * 6 + 1]);
+      result->boxes.emplace_back(
+          std::array<float, 4>{box_data[i * 6 + 2], box_data[i * 6 + 3],
+                               box_data[i * 6 + 4] - box_data[i * 6 + 2],
+                               box_data[i * 6 + 5] - box_data[i * 6 + 3]});
+    }
-    result->label_ids.push_back(box_data[i * 6]);
-    result->scores.push_back(box_data[i * 6 + 1]);
-    result->boxes.emplace_back(
-        std::array<float, 4>{box_data[i * 6 + 2], box_data[i * 6 + 3],
-                             box_data[i * 6 + 4] - box_data[i * 6 + 2],
-                             box_data[i * 6 + 5] - box_data[i * 6 + 3]});
}
return true;
}
@@ -157,7 +221,7 @@ bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result,
return false;
}
-  if (!Postprocess(infer_result, result, conf_threshold, iou_threshold)) {
+  if (!Postprocess(infer_result, result)) {
FDERROR << "Failed to postprocess while using model:" << ModelName() << "."
<< std::endl;
return false;
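
To make the two post-processing branches above concrete: both write detections as six floats per box, [label, score, xmin, ymin, xmax, ymax], and convert the corners into the {xmin, ymin, width, height} layout stored in DetectionResult in this commit. A standalone illustration with made-up values:

#include <array>
#include <cstdio>

int main() {
  // One detection record as produced by backend::MultiClassNMS::Compute().
  const float det[6] = {7.0f /*label*/, 0.93f /*score*/, 12.f, 20.f, 52.f, 80.f};
  const std::array<float, 4> box = {det[2], det[3],
                                    det[4] - det[2],   // width  = xmax - xmin
                                    det[5] - det[3]};  // height = ymax - ymin
  std::printf("label=%d score=%.2f box=[%.0f, %.0f, %.0f, %.0f]\n",
              static_cast<int>(det[0]), det[1], box[0], box[1], box[2], box[3]);
  return 0;
}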

View File

@@ -25,8 +25,7 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel {
virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
virtual bool Postprocess(std::vector<FDTensor>& infer_result,
-                           DetectionResult* result, float conf_threshold,
-                           float nms_threshold);
+                           DetectionResult* result);
virtual bool Predict(cv::Mat* im, DetectionResult* result,
float conf_threshold = 0.5, float nms_threshold = 0.7);
@@ -34,10 +33,15 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel {
private:
std::vector<std::shared_ptr<Processor>> processors_;
std::string config_file_;
-  // PaddleDetection can export model without nms
-  // This flag will help us to handle the different
-  // situation
-  bool has_nms_;
+  // configuration for nms
+  int64_t background_label = -1;
+  int64_t keep_top_k = 300;
+  float nms_eta = 1.0;
+  float nms_threshold = 0.7;
+  float score_threshold = 0.01;
+  int64_t nms_top_k = 10000;
+  bool normalized = true;
+  bool has_nms_ = false;
};
} // namespace ppdet
} // namespace vision
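
Putting the commit together, an end-to-end sketch of the updated C++ API. The paths are placeholders, the constructor arguments mirror the pybind11 binding above, and the trailing RuntimeOption/Frontend parameters are assumed to have sensible defaults.

#include <opencv2/opencv.hpp>
#include "fastdeploy/vision/ppdet/ppyoloe.h"

int main() {
  fastdeploy::vision::ppdet::PPYOLOE model(
      "ppyoloe/model.pdmodel", "ppyoloe/model.pdiparams",
      "ppyoloe/infer_cfg.yml", fastdeploy::RuntimeOption());
  cv::Mat im = cv::imread("test.jpg");
  fastdeploy::vision::DetectionResult res;
  if (!model.Predict(&im, &res)) {
    return -1;
  }
  // NMS behavior is now controlled by the members above (score_threshold,
  // nms_threshold, ...) rather than by Predict() arguments, which this
  // commit leaves in place but no longer forwards to Postprocess().
  return 0;
}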