mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Model] Add DecodeProcess For PPDet (#1127)
* 更新ppdet * 更新ppdet * 更新ppdet * 更新ppdet * 更新ppdet * 新增ppdet_decode * 更新多batch支持 * 更新多batch支持 * 更新多batch支持 * 更新注释内容 * 尝试解决pybind问题 * 尝试解决pybind的问题 * 尝试解决pybind的问题 * 重构代码 * 重构代码 * 重构代码 * 按照要求修改 * 修复部分bug 加入pybind * 修复pybind * 修复pybind错误的问题
This commit is contained in:
@@ -12,13 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "fastdeploy/runtime.h"
|
||||
#include "fastdeploy/fastdeploy_model.h"
|
||||
#include "fastdeploy/runtime.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
#include "wav.h"
|
||||
#include "./wav.h"
|
||||
|
||||
class Vad : public fastdeploy::FastDeployModel {
|
||||
public:
|
||||
@@ -29,8 +27,7 @@ class Vad:public fastdeploy::FastDeployModel{
|
||||
fastdeploy::RuntimeOption()) {
|
||||
valid_cpu_backends = {fastdeploy::Backend::ORT,
|
||||
fastdeploy::Backend::OPENVINO};
|
||||
valid_gpu_backends = {fastdeploy::Backend::ORT,
|
||||
fastdeploy::Backend::TRT};
|
||||
valid_gpu_backends = {fastdeploy::Backend::ORT, fastdeploy::Backend::TRT};
|
||||
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = fastdeploy::ModelFormat::ONNX;
|
||||
@@ -38,22 +35,18 @@ class Vad:public fastdeploy::FastDeployModel{
|
||||
runtime_option.params_file = "";
|
||||
}
|
||||
|
||||
void init() {
|
||||
initialized = Initialize();
|
||||
}
|
||||
void init() { initialized = Initialize(); }
|
||||
|
||||
void setAudioCofig(
|
||||
int sr, int frame_ms, float threshold,
|
||||
void setAudioCofig(int sr, int frame_ms, float threshold,
|
||||
int min_silence_duration_ms, int speech_pad_ms);
|
||||
|
||||
void loadAudio(const std::string& wavPath);
|
||||
|
||||
bool Predict();
|
||||
|
||||
std::vector<std::map<std::string, float>> getResult(
|
||||
float removeThreshold = 1.6,
|
||||
float expandHeadThreshold = 0.32, float expandTailThreshold = 0,
|
||||
float mergeThreshold = 0.3);
|
||||
std::vector<std::map<std::string, float>>
|
||||
getResult(float removeThreshold = 1.6, float expandHeadThreshold = 0.32,
|
||||
float expandTailThreshold = 0, float mergeThreshold = 0.3);
|
||||
|
||||
private:
|
||||
bool Initialize();
|
||||
|
@@ -12,13 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace wav {
|
||||
@@ -54,8 +52,7 @@ class WavReader {
|
||||
WavHeader header;
|
||||
fread(&header, 1, sizeof(header), fp);
|
||||
if (header.fmt_size < 16) {
|
||||
fprintf(stderr,
|
||||
"WaveData: expect PCM format data "
|
||||
fprintf(stderr, "WaveData: expect PCM format data "
|
||||
"to have fmt chunk of at least size 16.\n");
|
||||
return false;
|
||||
} else if (header.fmt_size > 16) {
|
||||
@@ -131,11 +128,8 @@ class WavWriter {
|
||||
public:
|
||||
WavWriter(const float* data, int num_samples, int num_channel,
|
||||
int sample_rate, int bits_per_sample)
|
||||
: data_(data),
|
||||
num_samples_(num_samples),
|
||||
num_channel_(num_channel),
|
||||
sample_rate_(sample_rate),
|
||||
bits_per_sample_(bits_per_sample) {}
|
||||
: data_(data), num_samples_(num_samples), num_channel_(num_channel),
|
||||
sample_rate_(sample_rate), bits_per_sample_(bits_per_sample) {}
|
||||
|
||||
void Write(const std::string& filename) {
|
||||
FILE* fp = fopen(filename.c_str(), "w");
|
||||
|
@@ -21,10 +21,18 @@ def parse_arguments():
|
||||
import ast
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_file", required=True, help="Path of rknn model.")
|
||||
parser.add_argument("--config_file", required=True, help="Path of config.")
|
||||
"--model_file",
|
||||
default="./picodet_s_416_coco_lcnet_non_postprocess/picodet_xs_416_coco_lcnet.onnx",
|
||||
help="Path of rknn model.")
|
||||
parser.add_argument(
|
||||
"--image", type=str, required=True, help="Path of test image file.")
|
||||
"--config_file",
|
||||
default="./picodet_s_416_coco_lcnet_non_postprocess/infer_cfg.yml",
|
||||
help="Path of config.")
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
type=str,
|
||||
default="./000000014439.jpg",
|
||||
help="Path of test image file.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -37,14 +45,14 @@ if __name__ == "__main__":
|
||||
|
||||
# 配置runtime,加载模型
|
||||
runtime_option = fd.RuntimeOption()
|
||||
runtime_option.use_rknpu2()
|
||||
runtime_option.use_cpu()
|
||||
|
||||
model = fd.vision.detection.PicoDet(
|
||||
model = fd.vision.detection.PPYOLOE(
|
||||
model_file,
|
||||
params_file,
|
||||
config_file,
|
||||
runtime_option=runtime_option,
|
||||
model_format=fd.ModelFormat.RKNN)
|
||||
model_format=fd.ModelFormat.ONNX)
|
||||
|
||||
model.postprocessor.apply_decode_and_nms()
|
||||
|
||||
|
13
fastdeploy/vision/detection/ppdet/base.cc
Executable file → Normal file
13
fastdeploy/vision/detection/ppdet/base.cc
Executable file → Normal file
@@ -1,7 +1,8 @@
|
||||
#include "fastdeploy/vision/detection/ppdet/base.h"
|
||||
|
||||
#include "fastdeploy/utils/unique_ptr.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
#include "yaml-cpp/yaml.h"
|
||||
#include "fastdeploy/utils/unique_ptr.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
@@ -12,7 +13,7 @@ PPDetBase::PPDetBase(const std::string& model_file,
|
||||
const std::string& config_file,
|
||||
const RuntimeOption& custom_option,
|
||||
const ModelFormat& model_format)
|
||||
: preprocessor_(config_file) {
|
||||
: preprocessor_(config_file), postprocessor_(config_file) {
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = model_format;
|
||||
runtime_option.model_file = model_file;
|
||||
@@ -20,7 +21,8 @@ PPDetBase::PPDetBase(const std::string& model_file,
|
||||
}
|
||||
|
||||
std::unique_ptr<PPDetBase> PPDetBase::Clone() const {
|
||||
std::unique_ptr<PPDetBase> clone_model = fastdeploy::utils::make_unique<PPDetBase>(PPDetBase(*this));
|
||||
std::unique_ptr<PPDetBase> clone_model =
|
||||
fastdeploy::utils::make_unique<PPDetBase>(PPDetBase(*this));
|
||||
clone_model->SetRuntime(clone_model->CloneRuntime());
|
||||
return clone_model;
|
||||
}
|
||||
@@ -57,8 +59,9 @@ bool PPDetBase::BatchPredict(const std::vector<cv::Mat>& imgs,
|
||||
reused_input_tensors_[1].name = "scale_factor";
|
||||
reused_input_tensors_[2].name = "im_shape";
|
||||
|
||||
if(postprocessor_.DecodeAndNMSApplied()){
|
||||
postprocessor_.SetScaleFactor(static_cast<float*>(reused_input_tensors_[1].Data()));
|
||||
if (NumInputsOfRuntime() == 1) {
|
||||
auto scale_factor = static_cast<float*>(reused_input_tensors_[1].Data());
|
||||
postprocessor_.SetScaleFactor({scale_factor[0], scale_factor[1]});
|
||||
}
|
||||
|
||||
// Some models don't need scale_factor and im_shape as input
|
||||
|
@@ -14,6 +14,7 @@
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/detection/ppdet/base.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
@@ -35,8 +36,8 @@ class FASTDEPLOY_DECL PicoDet : public PPDetBase {
|
||||
const ModelFormat& model_format = ModelFormat::PADDLE)
|
||||
: PPDetBase(model_file, params_file, config_file, custom_option,
|
||||
model_format) {
|
||||
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT,
|
||||
Backend::PDINFER, Backend::LITE};
|
||||
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER,
|
||||
Backend::LITE};
|
||||
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
|
||||
valid_rknpu_backends = {Backend::RKNPU2};
|
||||
valid_kunlunxin_backends = {Backend::LITE};
|
||||
@@ -64,8 +65,8 @@ class FASTDEPLOY_DECL PPYOLOE : public PPDetBase {
|
||||
const ModelFormat& model_format = ModelFormat::PADDLE)
|
||||
: PPDetBase(model_file, params_file, config_file, custom_option,
|
||||
model_format) {
|
||||
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT,
|
||||
Backend::PDINFER, Backend::LITE};
|
||||
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER,
|
||||
Backend::LITE};
|
||||
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
|
||||
valid_timvx_backends = {Backend::LITE};
|
||||
valid_kunlunxin_backends = {Backend::LITE};
|
||||
@@ -253,7 +254,8 @@ class FASTDEPLOY_DECL PaddleYOLOv8 : public PPDetBase {
|
||||
const ModelFormat& model_format = ModelFormat::PADDLE)
|
||||
: PPDetBase(model_file, params_file, config_file, custom_option,
|
||||
model_format) {
|
||||
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER, Backend::LITE};
|
||||
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER,
|
||||
Backend::LITE};
|
||||
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
|
||||
valid_kunlunxin_backends = {Backend::LITE};
|
||||
valid_ascend_backends = {Backend::LITE};
|
||||
@@ -294,7 +296,8 @@ class FASTDEPLOY_DECL CascadeRCNN : public PPDetBase {
|
||||
}
|
||||
|
||||
virtual std::string ModelName() const {
|
||||
return "PaddleDetection/CascadeRCNN"; }
|
||||
return "PaddleDetection/CascadeRCNN";
|
||||
}
|
||||
};
|
||||
|
||||
class FASTDEPLOY_DECL PSSDet : public PPDetBase {
|
||||
|
@@ -20,6 +20,26 @@
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
/** \brief Config for PaddleMultiClassNMS
|
||||
* \param[in] background_label the value of background label
|
||||
* \param[in] keep_top_k the value of keep_top_k
|
||||
* \param[in] nms_eta the value of nms_eta
|
||||
* \param[in] nms_threshold a dict that contains the arguments of nms operations
|
||||
* \param[in] nms_top_k if there are more than max_num bboxes after NMS, only top max_num will be kept.
|
||||
* \param[in] normalized Determine whether normalized is required
|
||||
* \param[in] score_threshold bbox threshold, bboxes with scores lower than it will not be considered.
|
||||
*/
|
||||
struct NMSOption{
|
||||
NMSOption() = default;
|
||||
int64_t background_label = -1;
|
||||
int64_t keep_top_k = 100;
|
||||
float nms_eta = 1.0;
|
||||
float nms_threshold = 0.5;
|
||||
int64_t nms_top_k = 1000;
|
||||
bool normalized = true;
|
||||
float score_threshold = 0.3;
|
||||
};
|
||||
|
||||
struct PaddleMultiClassNMS {
|
||||
int64_t background_label = -1;
|
||||
int64_t keep_top_k = -1;
|
||||
@@ -40,6 +60,16 @@ struct PaddleMultiClassNMS {
|
||||
void Compute(const float* boxes, const float* scores,
|
||||
const std::vector<int64_t>& boxes_dim,
|
||||
const std::vector<int64_t>& scores_dim);
|
||||
|
||||
void SetNMSOption(const struct NMSOption &nms_option){
|
||||
background_label = nms_option.background_label;
|
||||
keep_top_k = nms_option.keep_top_k;
|
||||
nms_eta = nms_option.nms_eta;
|
||||
nms_threshold = nms_option.nms_threshold;
|
||||
nms_top_k = nms_option.nms_top_k;
|
||||
normalized = nms_option.normalized;
|
||||
score_threshold = nms_option.score_threshold;
|
||||
}
|
||||
};
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
|
88
fastdeploy/vision/detection/ppdet/postprocessor.cc
Executable file → Normal file
88
fastdeploy/vision/detection/ppdet/postprocessor.cc
Executable file → Normal file
@@ -14,7 +14,6 @@
|
||||
|
||||
#include "fastdeploy/vision/detection/ppdet/postprocessor.h"
|
||||
|
||||
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
@@ -32,7 +31,7 @@ bool PaddleDetPostprocessor::ProcessMask(
|
||||
int64_t out_mask_h = shape[1];
|
||||
int64_t out_mask_w = shape[2];
|
||||
int64_t out_mask_numel = shape[1] * shape[2];
|
||||
const uint8_t* data = reinterpret_cast<const uint8_t*>(tensor.CpuData());
|
||||
const auto* data = reinterpret_cast<const uint8_t*>(tensor.CpuData());
|
||||
int index = 0;
|
||||
|
||||
for (int i = 0; i < results->size(); ++i) {
|
||||
@@ -50,7 +49,7 @@ bool PaddleDetPostprocessor::ProcessMask(
|
||||
(*results)[i].masks[j].shape = {keep_mask_h, keep_mask_w};
|
||||
const uint8_t* current_ptr = data + index * out_mask_numel;
|
||||
|
||||
uint8_t* keep_mask_ptr =
|
||||
auto* keep_mask_ptr =
|
||||
reinterpret_cast<uint8_t*>((*results)[i].masks[j].Data());
|
||||
for (int row = y1; row < y2; ++row) {
|
||||
size_t keep_nbytes_in_col = keep_mask_w * sizeof(uint8_t);
|
||||
@@ -67,16 +66,9 @@ bool PaddleDetPostprocessor::ProcessMask(
|
||||
bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors,
|
||||
std::vector<DetectionResult>* results) {
|
||||
if (DecodeAndNMSApplied()) {
|
||||
FDASSERT(tensors.size() == 2,
|
||||
"While postprocessing with ApplyDecodeAndNMS, "
|
||||
"there should be 2 outputs for this model, but now it's %zu.",
|
||||
tensors.size());
|
||||
FDASSERT(tensors[0].shape.size() == 3,
|
||||
"While postprocessing with ApplyDecodeAndNMS, "
|
||||
"the rank of the first outputs should be 3, but now it's %zu",
|
||||
tensors[0].shape.size());
|
||||
return ProcessUnDecodeResults(tensors, results);
|
||||
}
|
||||
|
||||
// Get number of boxes for each input image
|
||||
std::vector<int> num_boxes(tensors[1].shape[0]);
|
||||
int total_num_boxes = 0;
|
||||
@@ -152,77 +144,27 @@ bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors,
|
||||
return ProcessMask(tensors[2], results);
|
||||
}
|
||||
|
||||
void PaddleDetPostprocessor::ApplyDecodeAndNMS() {
|
||||
apply_decode_and_nms_ = true;
|
||||
}
|
||||
|
||||
bool PaddleDetPostprocessor::ProcessUnDecodeResults(
|
||||
const std::vector<FDTensor>& tensors,
|
||||
std::vector<DetectionResult>* results) {
|
||||
if (tensors.size() != 2) {
|
||||
return false;
|
||||
}
|
||||
results->resize(tensors[0].Shape()[0]);
|
||||
|
||||
int boxes_index = 0;
|
||||
int scores_index = 1;
|
||||
if (tensors[0].shape[1] == tensors[1].shape[2]) {
|
||||
boxes_index = 0;
|
||||
scores_index = 1;
|
||||
} else if (tensors[0].shape[2] == tensors[1].shape[1]) {
|
||||
boxes_index = 1;
|
||||
scores_index = 0;
|
||||
} else {
|
||||
FDERROR << "The shape of boxes and scores should be [batch, boxes_num, "
|
||||
"4], [batch, classes_num, boxes_num]"
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
// do decode and nms
|
||||
ppdet_decoder_.DecodeAndNMS(tensors, results);
|
||||
|
||||
PaddleMultiClassNMS nms;
|
||||
nms.background_label = -1;
|
||||
nms.keep_top_k = 100;
|
||||
nms.nms_eta = 1.0;
|
||||
nms.nms_threshold = 0.5;
|
||||
nms.score_threshold = 0.3;
|
||||
nms.nms_top_k = 1000;
|
||||
nms.normalized = true;
|
||||
nms.Compute(static_cast<const float*>(tensors[boxes_index].Data()),
|
||||
static_cast<const float*>(tensors[scores_index].Data()),
|
||||
tensors[boxes_index].shape, tensors[scores_index].shape);
|
||||
|
||||
auto num_boxes = nms.out_num_rois_data;
|
||||
auto box_data = static_cast<const float*>(nms.out_box_data.data());
|
||||
// Get boxes for each input image
|
||||
results->resize(num_boxes.size());
|
||||
int offset = 0;
|
||||
for (size_t i = 0; i < num_boxes.size(); ++i) {
|
||||
const float* ptr = box_data + offset;
|
||||
(*results)[i].Reserve(num_boxes[i]);
|
||||
for (size_t j = 0; j < num_boxes[i]; ++j) {
|
||||
(*results)[i].label_ids.push_back(
|
||||
static_cast<int32_t>(round(ptr[j * 6])));
|
||||
(*results)[i].scores.push_back(ptr[j * 6 + 1]);
|
||||
(*results)[i].boxes.emplace_back(std::array<float, 4>(
|
||||
{ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]}));
|
||||
// do scale
|
||||
if (GetScaleFactor()[0] != 0) {
|
||||
for (auto& result : *results) {
|
||||
for (auto& box : result.boxes) {
|
||||
box[0] /= GetScaleFactor()[1];
|
||||
box[1] /= GetScaleFactor()[0];
|
||||
box[2] /= GetScaleFactor()[1];
|
||||
box[3] /= GetScaleFactor()[0];
|
||||
}
|
||||
}
|
||||
offset += (num_boxes[i] * 6);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<float> PaddleDetPostprocessor::GetScaleFactor() {
|
||||
return scale_factor_;
|
||||
}
|
||||
|
||||
void PaddleDetPostprocessor::SetScaleFactor(float* scale_factor_value) {
|
||||
for (int i = 0; i < scale_factor_.size(); ++i) {
|
||||
scale_factor_[i] = scale_factor_value[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool PaddleDetPostprocessor::DecodeAndNMSApplied() {
|
||||
return apply_decode_and_nms_;
|
||||
}
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
||||
|
@@ -15,16 +15,25 @@
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/ppdet_decode.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
namespace detection {
|
||||
/*! @brief Postprocessor object for PaddleDet serials model.
|
||||
*/
|
||||
class FASTDEPLOY_DECL PaddleDetPostprocessor {
|
||||
public:
|
||||
PaddleDetPostprocessor() = default;
|
||||
|
||||
/** \brief Create a preprocessor instance for PaddleDet serials model
|
||||
*
|
||||
* \param[in] config_file Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
|
||||
*/
|
||||
explicit PaddleDetPostprocessor(const std::string& config_file)
|
||||
: ppdet_decoder_(config_file) {}
|
||||
|
||||
/** \brief Process the result of runtime and fill to ClassifyResult structure
|
||||
*
|
||||
* \param[in] tensors The inference result from runtime
|
||||
@@ -36,24 +45,29 @@ class FASTDEPLOY_DECL PaddleDetPostprocessor {
|
||||
|
||||
/// Apply box decoding and nms step for the outputs for the model.This is
|
||||
/// only available for those model exported without box decoding and nms.
|
||||
void ApplyDecodeAndNMS();
|
||||
void ApplyDecodeAndNMS(const NMSOption& option = NMSOption()) {
|
||||
apply_decode_and_nms_ = true;
|
||||
ppdet_decoder_.SetNMSOption(option);
|
||||
}
|
||||
|
||||
bool DecodeAndNMSApplied();
|
||||
|
||||
/// Set scale_factor_ value.This is only available for those model exported
|
||||
/// without box decoding and nms.
|
||||
void SetScaleFactor(float* scale_factor_value);
|
||||
// Set scale_factor_ value.This is only available for those model exported
|
||||
// without box decoding and nms.
|
||||
void SetScaleFactor(const std::vector<float>& scale_factor_value) {
|
||||
scale_factor_ = scale_factor_value;
|
||||
}
|
||||
|
||||
private:
|
||||
// for model without decode and nms.
|
||||
bool apply_decode_and_nms_ = false;
|
||||
bool DecodeAndNMSApplied() const { return apply_decode_and_nms_; }
|
||||
bool ProcessUnDecodeResults(const std::vector<FDTensor>& tensors,
|
||||
std::vector<DetectionResult>* results);
|
||||
PPDetDecode ppdet_decoder_;
|
||||
std::vector<float> scale_factor_{0.0, 0.0};
|
||||
std::vector<float> GetScaleFactor() { return scale_factor_; }
|
||||
// Process mask tensor for MaskRCNN
|
||||
bool ProcessMask(const FDTensor& tensor,
|
||||
std::vector<DetectionResult>* results);
|
||||
|
||||
bool apply_decode_and_nms_ = false;
|
||||
std::vector<float> scale_factor_{1.0, 1.0};
|
||||
std::vector<float> GetScaleFactor();
|
||||
bool ProcessUnDecodeResults(const std::vector<FDTensor>& tensors,
|
||||
std::vector<DetectionResult>* results);
|
||||
};
|
||||
|
||||
} // namespace detection
|
||||
|
296
fastdeploy/vision/detection/ppdet/ppdet_decode.cc
Normal file
296
fastdeploy/vision/detection/ppdet/ppdet_decode.cc
Normal file
@@ -0,0 +1,296 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include "ppdet_decode.h"
|
||||
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
#include "yaml-cpp/yaml.h"
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
PPDetDecode::PPDetDecode(const std::string& config_file) {
|
||||
config_file_ = config_file;
|
||||
ReadPostprocessConfigFromYaml();
|
||||
}
|
||||
|
||||
/***************************************************************
|
||||
* @name ReadPostprocessConfigFromYaml
|
||||
* @brief Read decode config from yaml.
|
||||
* @note read arch
|
||||
* read fpn_stride
|
||||
* read nms_threshold on NMS
|
||||
* read score_threshold on NMS
|
||||
* read target_size
|
||||
***************************************************************/
|
||||
bool PPDetDecode::ReadPostprocessConfigFromYaml() {
|
||||
YAML::Node config;
|
||||
try {
|
||||
config = YAML::LoadFile(config_file_);
|
||||
} catch (YAML::BadFile& e) {
|
||||
FDERROR << "Failed to load yaml file " << config_file_
|
||||
<< ", maybe you should check this file." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (config["arch"].IsDefined()) {
|
||||
arch_ = config["arch"].as<std::string>();
|
||||
} else {
|
||||
FDERROR << "Please set model arch,"
|
||||
<< "support value : YOLO, SSD, RetinaNet, RCNN, Face." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (config["fpn_stride"].IsDefined()) {
|
||||
fpn_stride_ = config["fpn_stride"].as<std::vector<int>>();
|
||||
}
|
||||
|
||||
if (config["NMS"].IsDefined()) {
|
||||
for (const auto& op : config["NMS"]) {
|
||||
if (config["background_label"].IsDefined()) {
|
||||
multi_class_nms_.background_label =
|
||||
op["background_label"].as<int64_t>();
|
||||
}
|
||||
if (config["keep_top_k"].IsDefined()) {
|
||||
multi_class_nms_.keep_top_k = op["keep_top_k"].as<int64_t>();
|
||||
}
|
||||
if (config["nms_eta"].IsDefined()) {
|
||||
multi_class_nms_.nms_eta = op["nms_eta"].as<float>();
|
||||
}
|
||||
if (config["nms_threshold"].IsDefined()) {
|
||||
multi_class_nms_.nms_threshold = op["nms_threshold"].as<float>();
|
||||
}
|
||||
if (config["nms_top_k"].IsDefined()) {
|
||||
multi_class_nms_.nms_top_k = op["nms_top_k"].as<int64_t>();
|
||||
}
|
||||
if (config["normalized"].IsDefined()) {
|
||||
multi_class_nms_.normalized = op["normalized"].as<bool>();
|
||||
}
|
||||
if (config["score_threshold"].IsDefined()) {
|
||||
multi_class_nms_.score_threshold = op["score_threshold"].as<float>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (config["Preprocess"].IsDefined()) {
|
||||
for (const auto& op : config["Preprocess"]) {
|
||||
std::string op_name = op["type"].as<std::string>();
|
||||
if (op_name == "Resize") {
|
||||
im_shape_ = op["target_size"].as<std::vector<float>>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/***************************************************************
|
||||
* @name DecodeAndNMS
|
||||
* @brief Read batch and call different decode functions.
|
||||
* @param tensors: model output tensor
|
||||
* results: detection results
|
||||
* @note Only support arch is Picodet.
|
||||
***************************************************************/
|
||||
bool PPDetDecode::DecodeAndNMS(const std::vector<FDTensor>& tensors,
|
||||
std::vector<DetectionResult>* results) {
|
||||
if (tensors.size() == 2) {
|
||||
int boxes_index = 0;
|
||||
int scores_index = 1;
|
||||
if (tensors[0].shape[1] == tensors[1].shape[2]) {
|
||||
boxes_index = 0;
|
||||
scores_index = 1;
|
||||
} else if (tensors[0].shape[2] == tensors[1].shape[1]) {
|
||||
boxes_index = 1;
|
||||
scores_index = 0;
|
||||
} else {
|
||||
FDERROR << "The shape of boxes and scores should be [batch, boxes_num, "
|
||||
"4], [batch, classes_num, boxes_num]"
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
multi_class_nms_.Compute(
|
||||
static_cast<const float*>(tensors[boxes_index].Data()),
|
||||
static_cast<const float*>(tensors[scores_index].Data()),
|
||||
tensors[boxes_index].shape, tensors[scores_index].shape);
|
||||
auto num_boxes = multi_class_nms_.out_num_rois_data;
|
||||
auto box_data =
|
||||
static_cast<const float*>(multi_class_nms_.out_box_data.data());
|
||||
// Get boxes for each input image
|
||||
results->resize(num_boxes.size());
|
||||
int offset = 0;
|
||||
for (size_t i = 0; i < num_boxes.size(); ++i) {
|
||||
const float* ptr = box_data + offset;
|
||||
(*results)[i].Reserve(num_boxes[i]);
|
||||
for (size_t j = 0; j < num_boxes[i]; ++j) {
|
||||
(*results)[i].label_ids.push_back(
|
||||
static_cast<int32_t>(round(ptr[j * 6])));
|
||||
(*results)[i].scores.push_back(ptr[j * 6 + 1]);
|
||||
(*results)[i].boxes.emplace_back(std::array<float, 4>(
|
||||
{ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]}));
|
||||
}
|
||||
offset += (num_boxes[i] * 6);
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
FDASSERT(tensors.size() == fpn_stride_.size() * 2,
|
||||
"The size of output must be fpn_stride * 2.")
|
||||
batchs_ = static_cast<int>(tensors[0].shape[0]);
|
||||
if (arch_ == "PicoDet") {
|
||||
int num_class, reg_max;
|
||||
for (int i = 0; i < tensors.size(); i++) {
|
||||
if (i == 0) {
|
||||
num_class = static_cast<int>(tensors[i].Shape()[2]);
|
||||
}
|
||||
if (i == fpn_stride_.size()) {
|
||||
reg_max = static_cast<int>(tensors[i].Shape()[2] / 4);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < results->size(); ++i) {
|
||||
PicoDetPostProcess(tensors, results, reg_max, num_class);
|
||||
}
|
||||
} else {
|
||||
FDERROR << "ProcessUnDecodeResults only supported when arch is PicoDet."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/***************************************************************
|
||||
* @name PicoDetPostProcess
|
||||
* @brief Do decode and NMS for Picodet.
|
||||
* @param outs: model output tensor
|
||||
* results: detection results
|
||||
* @note Only support PPYOLOE and Picodet.
|
||||
***************************************************************/
|
||||
bool PPDetDecode::PicoDetPostProcess(const std::vector<FDTensor>& outs,
|
||||
std::vector<DetectionResult>* results,
|
||||
int reg_max, int num_class) {
|
||||
for (int batch = 0; batch < batchs_; ++batch) {
|
||||
auto& result = (*results)[batch];
|
||||
result.Clear();
|
||||
for (int i = batch * batchs_ * fpn_stride_.size();
|
||||
i < fpn_stride_.size() * (batch + 1); ++i) {
|
||||
int feature_h =
|
||||
std::ceil(im_shape_[0] / static_cast<float>(fpn_stride_[i]));
|
||||
int feature_w =
|
||||
std::ceil(im_shape_[1] / static_cast<float>(fpn_stride_[i]));
|
||||
for (int idx = 0; idx < feature_h * feature_w; idx++) {
|
||||
const auto* scores =
|
||||
static_cast<const float*>(outs[i].Data()) + (idx * num_class);
|
||||
int row = idx / feature_w;
|
||||
int col = idx % feature_w;
|
||||
float score = 0;
|
||||
int cur_label = 0;
|
||||
for (int label = 0; label < num_class; label++) {
|
||||
if (scores[label] > score) {
|
||||
score = scores[label];
|
||||
cur_label = label;
|
||||
}
|
||||
}
|
||||
if (score > multi_class_nms_.score_threshold) {
|
||||
const auto* bbox_pred =
|
||||
static_cast<const float*>(outs[i + fpn_stride_.size()].Data()) +
|
||||
(idx * 4 * (reg_max));
|
||||
DisPred2Bbox(bbox_pred, cur_label, score, col, row, fpn_stride_[i],
|
||||
&result, reg_max, num_class);
|
||||
}
|
||||
}
|
||||
}
|
||||
fastdeploy::vision::utils::NMS(&result, multi_class_nms_.nms_threshold);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
/***************************************************************
|
||||
* @name FastExp
|
||||
* @brief Do exp op
|
||||
* @param x: input data
|
||||
* @return float
|
||||
***************************************************************/
|
||||
float FastExp(float x) {
|
||||
union {
|
||||
uint32_t i;
|
||||
float f;
|
||||
} v{};
|
||||
v.i = (1 << 23) * (1.4426950409 * x + 126.93490512f);
|
||||
return v.f;
|
||||
}
|
||||
|
||||
/***************************************************************
|
||||
* @name ActivationFunctionSoftmax
|
||||
* @brief Do Softmax with reg_max.
|
||||
* @param src: input data
|
||||
* dst: output data
|
||||
* @return float
|
||||
***************************************************************/
|
||||
int PPDetDecode::ActivationFunctionSoftmax(const float* src, float* dst,
|
||||
int reg_max) {
|
||||
const float alpha = *std::max_element(src, src + reg_max);
|
||||
float denominator{0};
|
||||
|
||||
for (int i = 0; i < reg_max; ++i) {
|
||||
dst[i] = FastExp(src[i] - alpha);
|
||||
denominator += dst[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < reg_max; ++i) {
|
||||
dst[i] /= denominator;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/***************************************************************
|
||||
* @name DisPred2Bbox
|
||||
* @brief Do Decode.
|
||||
* @param dfl_det: detection data
|
||||
* label: label id
|
||||
* score: confidence
|
||||
* x: col
|
||||
* y: row
|
||||
* stride: stride
|
||||
* results: detection results
|
||||
***************************************************************/
|
||||
void PPDetDecode::DisPred2Bbox(const float*& dfl_det, int label, float score,
|
||||
int x, int y, int stride,
|
||||
fastdeploy::vision::DetectionResult* results,
|
||||
int reg_max, int num_class) {
|
||||
float ct_x = static_cast<float>(x + 0.5) * static_cast<float>(stride);
|
||||
float ct_y = static_cast<float>(y + 0.5) * static_cast<float>(stride);
|
||||
std::vector<float> dis_pred{0, 0, 0, 0};
|
||||
for (int i = 0; i < 4; i++) {
|
||||
float dis = 0;
|
||||
auto* dis_after_sm = new float[reg_max];
|
||||
ActivationFunctionSoftmax(dfl_det + i * (reg_max), dis_after_sm, reg_max);
|
||||
for (int j = 0; j < reg_max; j++) {
|
||||
dis += static_cast<float>(j) * dis_after_sm[j];
|
||||
}
|
||||
dis *= static_cast<float>(stride);
|
||||
dis_pred[i] = dis;
|
||||
delete[] dis_after_sm;
|
||||
}
|
||||
float xmin = (float)(std::max)(ct_x - dis_pred[0], .0f);
|
||||
float ymin = (float)(std::max)(ct_y - dis_pred[1], .0f);
|
||||
float xmax = (float)(std::min)(ct_x + dis_pred[2], (float)im_shape_[0]);
|
||||
float ymax = (float)(std::min)(ct_y + dis_pred[3], (float)im_shape_[1]);
|
||||
|
||||
results->boxes.emplace_back(std::array<float, 4>{xmin, ymin, xmax, ymax});
|
||||
results->label_ids.emplace_back(label);
|
||||
results->scores.emplace_back(score);
|
||||
}
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
50
fastdeploy/vision/detection/ppdet/ppdet_decode.h
Normal file
50
fastdeploy/vision/detection/ppdet/ppdet_decode.h
Normal file
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
class FASTDEPLOY_DECL PPDetDecode {
|
||||
public:
|
||||
PPDetDecode() = default;
|
||||
explicit PPDetDecode(const std::string& config_file);
|
||||
bool DecodeAndNMS(const std::vector<FDTensor>& tensors,
|
||||
std::vector<DetectionResult>* results);
|
||||
void SetNMSOption(const NMSOption& option = NMSOption()) {
|
||||
multi_class_nms_.SetNMSOption(option);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string config_file_;
|
||||
std::string arch_;
|
||||
std::vector<int> fpn_stride_{8, 16, 32, 64};
|
||||
std::vector<float> im_shape_{416, 416};
|
||||
int batchs_ = 1;
|
||||
bool ReadPostprocessConfigFromYaml();
|
||||
void DisPred2Bbox(const float*& dfl_det, int label, float score, int x, int y,
|
||||
int stride, fastdeploy::vision::DetectionResult* results,
|
||||
int reg_max, int num_class);
|
||||
bool PicoDetPostProcess(const std::vector<FDTensor>& outs,
|
||||
std::vector<DetectionResult>* results, int reg_max,
|
||||
int num_class);
|
||||
int ActivationFunctionSoftmax(const float* src, float* dst, int reg_max);
|
||||
PaddleMultiClassNMS multi_class_nms_;
|
||||
};
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -18,52 +18,84 @@ void BindPPDet(pybind11::module& m) {
|
||||
pybind11::class_<vision::detection::PaddleDetPreprocessor>(
|
||||
m, "PaddleDetPreprocessor")
|
||||
.def(pybind11::init<std::string>())
|
||||
.def("run", [](vision::detection::PaddleDetPreprocessor& self, std::vector<pybind11::array>& im_list) {
|
||||
.def("run",
|
||||
[](vision::detection::PaddleDetPreprocessor& self,
|
||||
std::vector<pybind11::array>& im_list) {
|
||||
std::vector<vision::FDMat> images;
|
||||
for (size_t i = 0; i < im_list.size(); ++i) {
|
||||
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
|
||||
}
|
||||
std::vector<FDTensor> outputs;
|
||||
if (!self.Run(&images, &outputs)) {
|
||||
throw std::runtime_error("Failed to preprocess the input data in PaddleDetPreprocessor.");
|
||||
throw std::runtime_error(
|
||||
"Failed to preprocess the input data in "
|
||||
"PaddleDetPreprocessor.");
|
||||
}
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
outputs[i].StopSharing();
|
||||
}
|
||||
return outputs;
|
||||
})
|
||||
.def("disable_normalize", [](vision::detection::PaddleDetPreprocessor& self) {
|
||||
.def("disable_normalize",
|
||||
[](vision::detection::PaddleDetPreprocessor& self) {
|
||||
self.DisableNormalize();
|
||||
})
|
||||
.def("disable_permute", [](vision::detection::PaddleDetPreprocessor& self) {
|
||||
.def("disable_permute",
|
||||
[](vision::detection::PaddleDetPreprocessor& self) {
|
||||
self.DisablePermute();
|
||||
});;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::detection::NMSOption>(m, "NMSOption")
|
||||
.def(pybind11::init())
|
||||
.def_readwrite("background_label",
|
||||
&vision::detection::NMSOption::background_label)
|
||||
.def_readwrite("keep_top_k", &vision::detection::NMSOption::keep_top_k)
|
||||
.def_readwrite("nms_eta", &vision::detection::NMSOption::nms_eta)
|
||||
.def_readwrite("nms_threshold",
|
||||
&vision::detection::NMSOption::nms_threshold)
|
||||
.def_readwrite("nms_top_k", &vision::detection::NMSOption::nms_top_k)
|
||||
.def_readwrite("normalized", &vision::detection::NMSOption::normalized)
|
||||
.def_readwrite("score_threshold",
|
||||
&vision::detection::NMSOption::score_threshold);
|
||||
|
||||
pybind11::class_<vision::detection::PaddleDetPostprocessor>(
|
||||
m, "PaddleDetPostprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def("run", [](vision::detection::PaddleDetPostprocessor& self, std::vector<FDTensor>& inputs) {
|
||||
.def(pybind11::init<std::string>())
|
||||
.def("run",
|
||||
[](vision::detection::PaddleDetPostprocessor& self,
|
||||
std::vector<FDTensor>& inputs) {
|
||||
std::vector<vision::DetectionResult> results;
|
||||
if (!self.Run(inputs, &results)) {
|
||||
throw std::runtime_error("Failed to postprocess the runtime result in PaddleDetPostprocessor.");
|
||||
throw std::runtime_error(
|
||||
"Failed to postprocess the runtime result in "
|
||||
"PaddleDetPostprocessor.");
|
||||
}
|
||||
return results;
|
||||
})
|
||||
.def("apply_decode_and_nms",
|
||||
[](vision::detection::PaddleDetPostprocessor& self){
|
||||
self.ApplyDecodeAndNMS();
|
||||
})
|
||||
.def("run", [](vision::detection::PaddleDetPostprocessor& self, std::vector<pybind11::array>& input_array) {
|
||||
.def(
|
||||
"apply_decode_and_nms",
|
||||
[](vision::detection::PaddleDetPostprocessor& self,
|
||||
vision::detection::NMSOption option) {
|
||||
self.ApplyDecodeAndNMS(option);
|
||||
},
|
||||
"A function which adds two numbers",
|
||||
pybind11::arg("option") = vision::detection::NMSOption())
|
||||
.def("run", [](vision::detection::PaddleDetPostprocessor& self,
|
||||
std::vector<pybind11::array>& input_array) {
|
||||
std::vector<vision::DetectionResult> results;
|
||||
std::vector<FDTensor> inputs;
|
||||
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
|
||||
if (!self.Run(inputs, &results)) {
|
||||
throw std::runtime_error("Failed to postprocess the runtime result in PaddleDetPostprocessor.");
|
||||
throw std::runtime_error(
|
||||
"Failed to postprocess the runtime result in "
|
||||
"PaddleDetPostprocessor.");
|
||||
}
|
||||
return results;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::detection::PPDetBase, FastDeployModel>(m, "PPDetBase")
|
||||
pybind11::class_<vision::detection::PPDetBase, FastDeployModel>(m,
|
||||
"PPDetBase")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def("predict",
|
||||
@@ -74,7 +106,8 @@ void BindPPDet(pybind11::module& m) {
|
||||
return res;
|
||||
})
|
||||
.def("batch_predict",
|
||||
[](vision::detection::PPDetBase& self, std::vector<pybind11::array>& data) {
|
||||
[](vision::detection::PPDetBase& self,
|
||||
std::vector<pybind11::array>& data) {
|
||||
std::vector<cv::Mat> images;
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
images.push_back(PyArrayToCvMat(data[i]));
|
||||
@@ -83,94 +116,118 @@ void BindPPDet(pybind11::module& m) {
|
||||
self.BatchPredict(images, &results);
|
||||
return results;
|
||||
})
|
||||
.def("clone", [](vision::detection::PPDetBase& self) {
|
||||
return self.Clone();
|
||||
})
|
||||
.def_property_readonly("preprocessor", &vision::detection::PPDetBase::GetPreprocessor)
|
||||
.def_property_readonly("postprocessor", &vision::detection::PPDetBase::GetPostprocessor);
|
||||
.def("clone",
|
||||
[](vision::detection::PPDetBase& self) { return self.Clone(); })
|
||||
.def_property_readonly("preprocessor",
|
||||
&vision::detection::PPDetBase::GetPreprocessor)
|
||||
.def_property_readonly("postprocessor",
|
||||
&vision::detection::PPDetBase::GetPostprocessor);
|
||||
|
||||
pybind11::class_<vision::detection::PPDetDecode>(m, "PPDetDecode")
|
||||
.def(pybind11::init<std::string>());
|
||||
|
||||
pybind11::class_<vision::detection::PPYOLO, vision::detection::PPDetBase>(m, "PPYOLO")
|
||||
pybind11::class_<vision::detection::PPYOLO, vision::detection::PPDetBase>(
|
||||
m, "PPYOLO")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::PPYOLOE, vision::detection::PPDetBase>(m, "PPYOLOE")
|
||||
pybind11::class_<vision::detection::PPYOLOE, vision::detection::PPDetBase>(
|
||||
m, "PPYOLOE")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::PicoDet, vision::detection::PPDetBase>(m, "PicoDet")
|
||||
pybind11::class_<vision::detection::PicoDet, vision::detection::PPDetBase>(
|
||||
m, "PicoDet")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::PaddleYOLOX, vision::detection::PPDetBase>(m, "PaddleYOLOX")
|
||||
pybind11::class_<vision::detection::PaddleYOLOX,
|
||||
vision::detection::PPDetBase>(m, "PaddleYOLOX")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::FasterRCNN, vision::detection::PPDetBase>(m, "FasterRCNN")
|
||||
pybind11::class_<vision::detection::FasterRCNN, vision::detection::PPDetBase>(
|
||||
m, "FasterRCNN")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::YOLOv3, vision::detection::PPDetBase>(m, "YOLOv3")
|
||||
pybind11::class_<vision::detection::YOLOv3, vision::detection::PPDetBase>(
|
||||
m, "YOLOv3")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::MaskRCNN, vision::detection::PPDetBase>(m, "MaskRCNN")
|
||||
pybind11::class_<vision::detection::MaskRCNN, vision::detection::PPDetBase>(
|
||||
m, "MaskRCNN")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::SSD, vision::detection::PPDetBase>(m, "SSD")
|
||||
pybind11::class_<vision::detection::SSD, vision::detection::PPDetBase>(m,
|
||||
"SSD")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::PaddleYOLOv5, vision::detection::PPDetBase>(m, "PaddleYOLOv5")
|
||||
pybind11::class_<vision::detection::PaddleYOLOv5,
|
||||
vision::detection::PPDetBase>(m, "PaddleYOLOv5")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::PaddleYOLOv6, vision::detection::PPDetBase>(m, "PaddleYOLOv6")
|
||||
pybind11::class_<vision::detection::PaddleYOLOv6,
|
||||
vision::detection::PPDetBase>(m, "PaddleYOLOv6")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::PaddleYOLOv7, vision::detection::PPDetBase>(m, "PaddleYOLOv7")
|
||||
pybind11::class_<vision::detection::PaddleYOLOv7,
|
||||
vision::detection::PPDetBase>(m, "PaddleYOLOv7")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::PaddleYOLOv8, vision::detection::PPDetBase>(m, "PaddleYOLOv8")
|
||||
pybind11::class_<vision::detection::PaddleYOLOv8,
|
||||
vision::detection::PPDetBase>(m, "PaddleYOLOv8")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::RTMDet, vision::detection::PPDetBase>(m, "RTMDet")
|
||||
pybind11::class_<vision::detection::RTMDet, vision::detection::PPDetBase>(
|
||||
m, "RTMDet")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::CascadeRCNN, vision::detection::PPDetBase>(m, "CascadeRCNN")
|
||||
pybind11::class_<vision::detection::CascadeRCNN,
|
||||
vision::detection::PPDetBase>(m, "CascadeRCNN")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::PSSDet, vision::detection::PPDetBase>(m, "PSSDet")
|
||||
pybind11::class_<vision::detection::PSSDet, vision::detection::PPDetBase>(
|
||||
m, "PSSDet")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::RetinaNet, vision::detection::PPDetBase>(m, "RetinaNet")
|
||||
pybind11::class_<vision::detection::RetinaNet, vision::detection::PPDetBase>(
|
||||
m, "RetinaNet")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::PPYOLOESOD, vision::detection::PPDetBase>(m, "PPYOLOESOD")
|
||||
pybind11::class_<vision::detection::PPYOLOESOD, vision::detection::PPDetBase>(
|
||||
m, "PPYOLOESOD")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::FCOS, vision::detection::PPDetBase>(m, "FCOS")
|
||||
pybind11::class_<vision::detection::FCOS, vision::detection::PPDetBase>(
|
||||
m, "FCOS")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::TTFNet, vision::detection::PPDetBase>(m, "TTFNet")
|
||||
pybind11::class_<vision::detection::TTFNet, vision::detection::PPDetBase>(
|
||||
m, "TTFNet")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::TOOD, vision::detection::PPDetBase>(m, "TOOD")
|
||||
pybind11::class_<vision::detection::TOOD, vision::detection::PPDetBase>(
|
||||
m, "TOOD")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::GFL, vision::detection::PPDetBase>(m, "GFL")
|
||||
pybind11::class_<vision::detection::GFL, vision::detection::PPDetBase>(m,
|
||||
"GFL")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
}
|
||||
|
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/detection/ppdet/preprocessor.h"
|
||||
|
||||
#include "fastdeploy/function/concat.h"
|
||||
#include "fastdeploy/function/pad.h"
|
||||
#include "yaml-cpp/yaml.h"
|
||||
@@ -126,7 +127,7 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images,
|
||||
FDERROR << "The preprocessor is not initialized." << std::endl;
|
||||
return false;
|
||||
}
|
||||
if (images->size() == 0) {
|
||||
if (images->empty()) {
|
||||
FDERROR << "The size of input images should be greater than 0."
|
||||
<< std::endl;
|
||||
return false;
|
||||
@@ -146,9 +147,9 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images,
|
||||
// All the tensor will pad to the max size to compose a batched tensor
|
||||
std::vector<int> max_hw({-1, -1});
|
||||
|
||||
float* scale_factor_ptr =
|
||||
auto* scale_factor_ptr =
|
||||
reinterpret_cast<float*>((*outputs)[1].MutableData());
|
||||
float* im_shape_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
|
||||
auto* im_shape_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
|
||||
for (size_t i = 0; i < images->size(); ++i) {
|
||||
int origin_w = (*images)[i].Width();
|
||||
int origin_h = (*images)[i].Height();
|
||||
@@ -208,16 +209,20 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images,
|
||||
}
|
||||
void PaddleDetPreprocessor::DisableNormalize() {
|
||||
this->disable_normalize_ = true;
|
||||
// the DisableNormalize function will be invalid if the configuration file is loaded during preprocessing
|
||||
// the DisableNormalize function will be invalid if the configuration file is
|
||||
// loaded during preprocessing
|
||||
if (!BuildPreprocessPipelineFromConfig()) {
|
||||
FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl;
|
||||
FDERROR << "Failed to build preprocess pipeline from configuration file."
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
void PaddleDetPreprocessor::DisablePermute() {
|
||||
this->disable_permute_ = true;
|
||||
// the DisablePermute function will be invalid if the configuration file is loaded during preprocessing
|
||||
// the DisablePermute function will be invalid if the configuration file is
|
||||
// loaded during preprocessing
|
||||
if (!BuildPreprocessPipelineFromConfig()) {
|
||||
FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl;
|
||||
FDERROR << "Failed to build preprocess pipeline from configuration file."
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
} // namespace detection
|
||||
|
@@ -49,6 +49,15 @@ class PaddleDetPreprocessor:
|
||||
self._preprocessor.disable_permute()
|
||||
|
||||
|
||||
class NMSOption:
|
||||
def __init__(self):
|
||||
self.nms_option = C.vision.detection.NMSOption()
|
||||
|
||||
@property
|
||||
def background_label(self):
|
||||
return self.nms_option.background_label
|
||||
|
||||
|
||||
class PaddleDetPostprocessor:
|
||||
def __init__(self):
|
||||
"""Create a postprocessor for PaddleDetection Model
|
||||
@@ -64,10 +73,12 @@ class PaddleDetPostprocessor:
|
||||
"""
|
||||
return self._postprocessor.run(runtime_results)
|
||||
|
||||
def apply_decode_and_nms(self):
|
||||
def apply_decode_and_nms(self, nms_option=None):
|
||||
"""This function will enable decode and nms in postprocess step.
|
||||
"""
|
||||
return self._postprocessor.apply_decode_and_nms()
|
||||
if nms_option is None:
|
||||
nms_option = NMSOption()
|
||||
self._postprocessor.ApplyDecodeAndNMS(self, nms_option.nms_option)
|
||||
|
||||
|
||||
class PPYOLOE(FastDeployModel):
|
||||
@@ -734,7 +745,7 @@ class GFL(PPYOLOE):
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == ModelFormat.PADDLE, "GFL model only support model format of ModelFormat.Paddle now."
|
||||
self._model = C.vision.detection.GFL(
|
||||
model_file, params_file, config_file, self._runtime_option,
|
||||
self._model = C.vision.detection.GFL(model_file, params_file,
|
||||
config_file, self._runtime_option,
|
||||
model_format)
|
||||
assert self.initialized, "GFL model initialize failed."
|
Reference in New Issue
Block a user