[Model] Refactor PaddleDetection module (#575)

* Add namespace for functions

* Refactor PaddleDetection module

* finish all the single image test

* Update preprocessor.cc

* fix some litte detail

* add python api

* Update postprocessor.cc
This commit is contained in:
Jason
2022-11-15 10:43:23 +08:00
committed by GitHub
parent aa21272eaa
commit beaa0fd190
39 changed files with 1282 additions and 1438 deletions

View File

@@ -66,4 +66,3 @@ print(result)
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5) vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("visualized_result.jpg", vis_im) cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg") print("Visualized result save in ./visualized_result.jpg")
print(runtime_option)

View File

@@ -17,7 +17,7 @@
namespace fastdeploy { namespace fastdeploy {
namespace pipeline { namespace pipeline {
PPTinyPose::PPTinyPose( PPTinyPose::PPTinyPose(
fastdeploy::vision::detection::PPYOLOE* det_model, fastdeploy::vision::detection::PicoDet* det_model,
fastdeploy::vision::keypointdetection::PPTinyPose* pptinypose_model) fastdeploy::vision::keypointdetection::PPTinyPose* pptinypose_model)
: detector_(det_model), pptinypose_model_(pptinypose_model) {} : detector_(det_model), pptinypose_model_(pptinypose_model) {}

View File

@@ -35,7 +35,7 @@ class FASTDEPLOY_DECL PPTinyPose {
* \param[in] pptinypose_model Initialized pptinypose model object * \param[in] pptinypose_model Initialized pptinypose model object
*/ */
PPTinyPose( PPTinyPose(
fastdeploy::vision::detection::PPYOLOE* det_model, fastdeploy::vision::detection::PicoDet* det_model,
fastdeploy::vision::keypointdetection::PPTinyPose* pptinypose_model); fastdeploy::vision::keypointdetection::PPTinyPose* pptinypose_model);
/** \brief Predict the keypoint detection result for an input image /** \brief Predict the keypoint detection result for an input image
@@ -52,7 +52,7 @@ class FASTDEPLOY_DECL PPTinyPose {
float detection_model_score_threshold = 0; float detection_model_score_threshold = 0;
protected: protected:
fastdeploy::vision::detection::PPYOLOE* detector_ = nullptr; fastdeploy::vision::detection::PicoDet* detector_ = nullptr;
fastdeploy::vision::keypointdetection::PPTinyPose* pptinypose_model_ = fastdeploy::vision::keypointdetection::PPTinyPose* pptinypose_model_ =
nullptr; nullptr;

View File

@@ -18,31 +18,8 @@ namespace fastdeploy {
void BindPPTinyPosePipeline(pybind11::module& m) { void BindPPTinyPosePipeline(pybind11::module& m) {
pybind11::class_<pipeline::PPTinyPose>(m, "PPTinyPose") pybind11::class_<pipeline::PPTinyPose>(m, "PPTinyPose")
// explicitly pybind more kinds of detection models here
.def(pybind11::init<fastdeploy::vision::detection::PPYOLOE*,
fastdeploy::vision::keypointdetection::PPTinyPose*>())
.def(pybind11::init<fastdeploy::vision::detection::PicoDet*, .def(pybind11::init<fastdeploy::vision::detection::PicoDet*,
fastdeploy::vision::keypointdetection::PPTinyPose*>()) fastdeploy::vision::keypointdetection::PPTinyPose*>())
.def(pybind11::init<fastdeploy::vision::detection::PPYOLO*,
fastdeploy::vision::keypointdetection::PPTinyPose*>())
.def(pybind11::init<fastdeploy::vision::detection::PPYOLOv2*,
fastdeploy::vision::keypointdetection::PPTinyPose*>())
.def(pybind11::init<fastdeploy::vision::detection::PaddleYOLOX*,
fastdeploy::vision::keypointdetection::PPTinyPose*>())
.def(pybind11::init<fastdeploy::vision::detection::FasterRCNN*,
fastdeploy::vision::keypointdetection::PPTinyPose*>())
.def(pybind11::init<fastdeploy::vision::detection::YOLOv3*,
fastdeploy::vision::keypointdetection::PPTinyPose*>())
.def(pybind11::init<fastdeploy::vision::detection::MaskRCNN*,
fastdeploy::vision::keypointdetection::PPTinyPose*>())
.def("predict", [](pipeline::PPTinyPose& self, .def("predict", [](pipeline::PPTinyPose& self,
pybind11::array& data) { pybind11::array& data) {
auto mat = PyArrayToCvMat(data); auto mat = PyArrayToCvMat(data);

View File

@@ -29,7 +29,6 @@
#include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h" #include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
#include "fastdeploy/vision/detection/contrib/yolox.h" #include "fastdeploy/vision/detection/contrib/yolox.h"
#include "fastdeploy/vision/detection/ppdet/model.h" #include "fastdeploy/vision/detection/ppdet/model.h"
#include "fastdeploy/vision/detection/contrib/rknpu2/model.h"
#include "fastdeploy/vision/facedet/contrib/retinaface.h" #include "fastdeploy/vision/facedet/contrib/retinaface.h"
#include "fastdeploy/vision/facedet/contrib/scrfd.h" #include "fastdeploy/vision/facedet/contrib/scrfd.h"
#include "fastdeploy/vision/facedet/contrib/ultraface.h" #include "fastdeploy/vision/facedet/contrib/ultraface.h"

View File

@@ -1,16 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h"

View File

@@ -1,29 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindRKDet(pybind11::module& m) {
pybind11::class_<vision::detection::RKPicoDet, FastDeployModel>(m, "RKPicoDet")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>())
.def("predict",
[](vision::detection::RKPicoDet& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res);
return res;
});
}
} // namespace fastdeploy

View File

@@ -1,201 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h"
#include "yaml-cpp/yaml.h"
namespace fastdeploy {
namespace vision {
namespace detection {
RKPicoDet::RKPicoDet(const std::string& model_file,
const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) {
config_file_ = config_file;
valid_cpu_backends = {Backend::ORT};
valid_rknpu_backends = {Backend::RKNPU2};
if ((model_format == ModelFormat::RKNN) ||
(model_format == ModelFormat::ONNX)) {
has_nms_ = false;
}
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
// NMS parameters come from RKPicoDet_s_nms
background_label = -1;
keep_top_k = 100;
nms_eta = 1;
nms_threshold = 0.5;
nms_top_k = 1000;
normalized = true;
score_threshold = 0.3;
initialized = Initialize();
}
bool RKPicoDet::Initialize() {
if (!BuildPreprocessPipelineFromConfig()) {
FDERROR << "Failed to build preprocess pipeline from configuration file."
<< std::endl;
return false;
}
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
bool RKPicoDet::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
int origin_w = mat->Width();
int origin_h = mat->Height();
for (size_t i = 0; i < processors_.size(); ++i) {
if (!(*(processors_[i].get()))(mat)) {
FDERROR << "Failed to process image data in " << processors_[i]->Name()
<< "." << std::endl;
return false;
}
}
Cast::Run(mat, "float");
scale_factor.resize(2);
scale_factor[0] = mat->Height() * 1.0 / origin_h;
scale_factor[1] = mat->Width() * 1.0 / origin_w;
outputs->resize(1);
(*outputs)[0].name = InputInfoOfRuntime(0).name;
mat->ShareWithTensor(&((*outputs)[0]));
// reshape to [1, c, h, w]
(*outputs)[0].shape.insert((*outputs)[0].shape.begin(), 1);
return true;
}
bool RKPicoDet::BuildPreprocessPipelineFromConfig() {
processors_.clear();
YAML::Node cfg;
try {
cfg = YAML::LoadFile(config_file_);
} catch (YAML::BadFile& e) {
FDERROR << "Failed to load yaml file " << config_file_
<< ", maybe you should check this file." << std::endl;
return false;
}
processors_.push_back(std::make_shared<BGR2RGB>());
for (const auto& op : cfg["Preprocess"]) {
std::string op_name = op["type"].as<std::string>();
if (op_name == "NormalizeImage") {
continue;
} else if (op_name == "Resize") {
bool keep_ratio = op["keep_ratio"].as<bool>();
auto target_size = op["target_size"].as<std::vector<int>>();
int interp = op["interp"].as<int>();
FDASSERT(target_size.size() == 2,
"Require size of target_size be 2, but now it's %lu.",
target_size.size());
if (!keep_ratio) {
int width = target_size[1];
int height = target_size[0];
processors_.push_back(
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
} else {
int min_target_size = std::min(target_size[0], target_size[1]);
int max_target_size = std::max(target_size[0], target_size[1]);
std::vector<int> max_size;
if (max_target_size > 0) {
max_size.push_back(max_target_size);
max_size.push_back(max_target_size);
}
processors_.push_back(std::make_shared<ResizeByShort>(
min_target_size, interp, true, max_size));
}
} else if (op_name == "Permute") {
continue;
} else if (op_name == "Pad") {
auto size = op["size"].as<std::vector<int>>();
auto value = op["fill_value"].as<std::vector<float>>();
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(
std::make_shared<PadToSize>(size[1], size[0], value));
} else if (op_name == "PadStride") {
auto stride = op["stride"].as<int>();
processors_.push_back(
std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
} else {
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
<< std::endl;
return false;
}
}
return true;
}
bool RKPicoDet::Postprocess(std::vector<FDTensor>& infer_result,
DetectionResult* result) {
FDASSERT(infer_result[1].shape[0] == 1,
"Only support batch = 1 in FastDeploy now.");
if (!has_nms_) {
int boxes_index = 0;
int scores_index = 1;
if (infer_result[0].shape[1] == infer_result[1].shape[2]) {
boxes_index = 0;
scores_index = 1;
} else if (infer_result[0].shape[2] == infer_result[1].shape[1]) {
boxes_index = 1;
scores_index = 0;
} else {
FDERROR << "The shape of boxes and scores should be [batch, boxes_num, "
"4], [batch, classes_num, boxes_num]"
<< std::endl;
return false;
}
backend::MultiClassNMS nms;
nms.background_label = background_label;
nms.keep_top_k = keep_top_k;
nms.nms_eta = nms_eta;
nms.nms_threshold = nms_threshold;
nms.score_threshold = score_threshold;
nms.nms_top_k = nms_top_k;
nms.normalized = normalized;
nms.Compute(static_cast<float*>(infer_result[boxes_index].Data()),
static_cast<float*>(infer_result[scores_index].Data()),
infer_result[boxes_index].shape,
infer_result[scores_index].shape);
if (nms.out_num_rois_data[0] > 0) {
result->Reserve(nms.out_num_rois_data[0]);
}
for (size_t i = 0; i < nms.out_num_rois_data[0]; ++i) {
result->label_ids.push_back(nms.out_box_data[i * 6]);
result->scores.push_back(nms.out_box_data[i * 6 + 1]);
result->boxes.emplace_back(
std::array<float, 4>{nms.out_box_data[i * 6 + 2] / scale_factor[1],
nms.out_box_data[i * 6 + 3] / scale_factor[0],
nms.out_box_data[i * 6 + 4] / scale_factor[1],
nms.out_box_data[i * 6 + 5] / scale_factor[0]});
}
} else {
FDERROR << "Picodet in Backend::RKNPU2 don't support NMS" << std::endl;
}
return true;
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -1,46 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
namespace fastdeploy {
namespace vision {
namespace detection {
class FASTDEPLOY_DECL RKPicoDet : public PPYOLOE {
public:
RKPicoDet(const std::string& model_file,
const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::RKNN);
virtual std::string ModelName() const { return "RKPicoDet"; }
protected:
/// Build the preprocess pipeline from the loaded model
virtual bool BuildPreprocessPipelineFromConfig();
/// Preprocess an input image, and set the preprocessed results to `outputs`
virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
/// Postprocess the inferenced results, and set the final result to `result`
virtual bool Postprocess(std::vector<FDTensor>& infer_result,
DetectionResult* result);
virtual bool Initialize();
private:
std::vector<float> scale_factor{1.0, 1.0};
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -27,8 +27,6 @@ void BindNanoDetPlus(pybind11::module& m);
void BindPPDet(pybind11::module& m); void BindPPDet(pybind11::module& m);
void BindYOLOv7End2EndTRT(pybind11::module& m); void BindYOLOv7End2EndTRT(pybind11::module& m);
void BindYOLOv7End2EndORT(pybind11::module& m); void BindYOLOv7End2EndORT(pybind11::module& m);
void BindRKDet(pybind11::module& m);
void BindDetection(pybind11::module& m) { void BindDetection(pybind11::module& m) {
auto detection_module = auto detection_module =
@@ -44,6 +42,5 @@ void BindDetection(pybind11::module& m) {
BindNanoDetPlus(detection_module); BindNanoDetPlus(detection_module);
BindYOLOv7End2EndTRT(detection_module); BindYOLOv7End2EndTRT(detection_module);
BindYOLOv7End2EndORT(detection_module); BindYOLOv7End2EndORT(detection_module);
BindRKDet(detection_module);
} }
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -0,0 +1,68 @@
#include "fastdeploy/vision/detection/ppdet/base.h"
#include "fastdeploy/vision/utils/utils.h"
#include "yaml-cpp/yaml.h"
namespace fastdeploy {
namespace vision {
namespace detection {
PPDetBase::PPDetBase(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) : preprocessor_(config_file) {
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
}
bool PPDetBase::Initialize() {
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
bool PPDetBase::Predict(cv::Mat* im, DetectionResult* result) {
return Predict(*im, result);
}
bool PPDetBase::Predict(const cv::Mat& im, DetectionResult* result) {
std::vector<DetectionResult> results;
if (!BatchPredict({im}, &results)) {
return false;
}
*result = std::move(results[0]);
return true;
}
bool PPDetBase::BatchPredict(const std::vector<cv::Mat>& imgs, std::vector<DetectionResult>* results) {
std::vector<FDMat> fd_images = WrapMat(imgs);
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
FDERROR << "Failed to preprocess the input image." << std::endl;
return false;
}
reused_input_tensors_[0].name = "image";
reused_input_tensors_[1].name = "scale_factor";
reused_input_tensors_[2].name = "im_shape";
// Some models don't need im_shape as input
if (NumInputsOfRuntime() == 2) {
reused_input_tensors_.pop_back();
}
if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
FDERROR << "Failed to inference by runtime." << std::endl;
return false;
}
if (!postprocessor_.Run(reused_output_tensors_, results)) {
FDERROR << "Failed to postprocess the inference results by runtime." << std::endl;
return false;
}
return true;
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "fastdeploy/fastdeploy_model.h" #include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/detection/ppdet/preprocessor.h"
#include "fastdeploy/vision/detection/ppdet/postprocessor.h"
#include "fastdeploy/vision/common/processors/transform.h" #include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h" #include "fastdeploy/vision/common/result.h"
@@ -26,9 +28,9 @@ namespace vision {
*/ */
namespace detection { namespace detection {
/*! @brief PPYOLOE model object used when to load a PPYOLOE model exported by PaddleDetection /*! @brief Base model object used when to load a model exported by PaddleDetection
*/ */
class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel { class FASTDEPLOY_DECL PPDetBase : public FastDeployModel {
public: public:
/** \brief Set path of model file and configuration file, and the configuration of runtime /** \brief Set path of model file and configuration file, and the configuration of runtime
* *
@@ -38,49 +40,49 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel {
* \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends` * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
* \param[in] model_format Model format of the loaded model, default is Paddle format * \param[in] model_format Model format of the loaded model, default is Paddle format
*/ */
PPYOLOE(const std::string& model_file, const std::string& params_file, PPDetBase(const std::string& model_file, const std::string& params_file,
const std::string& config_file, const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(), const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE); const ModelFormat& model_format = ModelFormat::PADDLE);
/// Get model's name /// Get model's name
virtual std::string ModelName() const { return "PaddleDetection/PPYOLOE"; } virtual std::string ModelName() const { return "PaddleDetection/BaseModel"; }
/** \brief Predict the detection result for an input image /** \brief DEPRECATED Predict the detection result for an input image
* *
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
* \param[in] result The output detection result will be writen to this structure * \param[in] result The output detection result
* \return true if the prediction successed, otherwise false * \return true if the prediction successed, otherwise false
*/ */
virtual bool Predict(cv::Mat* im, DetectionResult* result); virtual bool Predict(cv::Mat* im, DetectionResult* result);
/** \brief Predict the detection result for an input image
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
* \param[in] result The output detection result
* \return true if the prediction successed, otherwise false
*/
virtual bool Predict(const cv::Mat& im, DetectionResult* result);
/** \brief Predict the detection result for an input image list
* \param[in] im The input image list, all the elements come from cv::imread(), is a 3-D array with layout HWC, BGR format
* \param[in] results The output detection result list
* \return true if the prediction successed, otherwise false
*/
virtual bool BatchPredict(const std::vector<cv::Mat>& imgs,
std::vector<DetectionResult>* results);
PaddleDetPreprocessor& GetPreprocessor() {
return preprocessor_;
}
PaddleDetPostprocessor& GetPostprocessor() {
return postprocessor_;
}
protected: protected:
PPYOLOE() {}
virtual bool Initialize(); virtual bool Initialize();
/// Build the preprocess pipeline from the loaded model PaddleDetPreprocessor preprocessor_;
virtual bool BuildPreprocessPipelineFromConfig(); PaddleDetPostprocessor postprocessor_;
/// Preprocess an input image, and set the preprocessed results to `outputs`
virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
/// Postprocess the inferenced results, and set the final result to `result`
virtual bool Postprocess(std::vector<FDTensor>& infer_result,
DetectionResult* result);
std::vector<std::shared_ptr<Processor>> processors_;
std::string config_file_;
// configuration for nms
int64_t background_label = -1;
int64_t keep_top_k = 300;
float nms_eta = 1.0;
float nms_threshold = 0.7;
float score_threshold = 0.01;
int64_t nms_top_k = 10000;
bool normalized = true;
bool has_nms_ = true;
// This function will used to check if this model contains multiclass_nms
// and get parameters from the operator
void GetNmsInfo();
}; };
} // namespace detection } // namespace detection

View File

@@ -1,120 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/detection/ppdet/mask_rcnn.h"
namespace fastdeploy {
namespace vision {
namespace detection {
MaskRCNN::MaskRCNN(const std::string& model_file,
const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) {
config_file_ = config_file;
valid_cpu_backends = {Backend::PDINFER, Backend::LITE};
valid_gpu_backends = {Backend::PDINFER};
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
bool MaskRCNN::Postprocess(std::vector<FDTensor>& infer_result,
DetectionResult* result) {
// index 0: bbox_data [N, 6] float32
// index 1: bbox_num [B=1] int32
// index 2: mask_data [N, h, w] int32
FDASSERT(infer_result[1].shape[0] == 1,
"Only support batch = 1 in FastDeploy now.");
FDASSERT(infer_result.size() == 3,
"The infer_result must contains 3 otuput Tensors, but found %lu",
infer_result.size());
FDTensor& box_tensor = infer_result[0];
FDTensor& box_num_tensor = infer_result[1];
FDTensor& mask_tensor = infer_result[2];
int box_num = 0;
if (box_num_tensor.dtype == FDDataType::INT32) {
box_num = *(static_cast<int32_t*>(box_num_tensor.Data()));
} else if (box_num_tensor.dtype == FDDataType::INT64) {
box_num = *(static_cast<int64_t*>(box_num_tensor.Data()));
} else {
FDASSERT(false,
"The output box_num of PaddleDetection/MaskRCNN model should be "
"type of int32/int64.");
}
if (box_num <= 0) {
return true; // no object detected.
}
result->Resize(box_num);
float* box_data = static_cast<float*>(box_tensor.Data());
for (size_t i = 0; i < box_num; ++i) {
result->label_ids[i] = static_cast<int>(box_data[i * 6]);
result->scores[i] = box_data[i * 6 + 1];
result->boxes[i] =
std::array<float, 4>{box_data[i * 6 + 2], box_data[i * 6 + 3],
box_data[i * 6 + 4], box_data[i * 6 + 5]};
}
result->contain_masks = true;
// TODO(qiuyanjun): Cast int64/int8 to int32.
FDASSERT(mask_tensor.dtype == FDDataType::INT32,
"The dtype of mask Tensor must be int32 now!");
// In PaddleDetection/MaskRCNN, the mask_h and mask_w
// are already aligned with original input image. So,
// we need to crop it from output mask according to
// the detected bounding box.
// +-----------------------+
// | x1,y1 |
// | +---------------+ |
// | | | |
// | | Crop | |
// | | | |
// | | | |
// | +---------------+ |
// | x2,y2 |
// +-----------------------+
int64_t out_mask_h = mask_tensor.shape[1];
int64_t out_mask_w = mask_tensor.shape[2];
int64_t out_mask_numel = out_mask_h * out_mask_w;
int32_t* out_mask_data = static_cast<int32_t*>(mask_tensor.Data());
for (size_t i = 0; i < box_num; ++i) {
// crop instance mask according to box
int64_t x1 = static_cast<int64_t>(result->boxes[i][0]);
int64_t y1 = static_cast<int64_t>(result->boxes[i][1]);
int64_t x2 = static_cast<int64_t>(result->boxes[i][2]);
int64_t y2 = static_cast<int64_t>(result->boxes[i][3]);
int64_t keep_mask_h = y2 - y1;
int64_t keep_mask_w = x2 - x1;
int64_t keep_mask_numel = keep_mask_h * keep_mask_w;
result->masks[i].Resize(keep_mask_numel); // int32
result->masks[i].shape = {keep_mask_h, keep_mask_w};
int32_t* mask_start_ptr = out_mask_data + i * out_mask_numel;
int32_t* keep_mask_ptr = static_cast<int32_t*>(result->masks[i].Data());
for (size_t row = y1; row < y2; ++row) {
size_t keep_nbytes_in_col = keep_mask_w * sizeof(int32_t);
int32_t* out_row_start_ptr = mask_start_ptr + row * out_mask_w + x1;
int32_t* keep_row_start_ptr = keep_mask_ptr + (row - y1) * keep_mask_w;
std::memcpy(keep_row_start_ptr, out_row_start_ptr, keep_nbytes_in_col);
}
}
return true;
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -13,10 +13,152 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "fastdeploy/vision/detection/ppdet/mask_rcnn.h" #include "fastdeploy/vision/detection/ppdet/base.h"
#include "fastdeploy/vision/detection/ppdet/picodet.h"
#include "fastdeploy/vision/detection/ppdet/ppyolo.h" namespace fastdeploy {
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h" namespace vision {
#include "fastdeploy/vision/detection/ppdet/rcnn.h" namespace detection {
#include "fastdeploy/vision/detection/ppdet/yolov3.h"
#include "fastdeploy/vision/detection/ppdet/yolox.h" class FASTDEPLOY_DECL PicoDet : public PPDetBase {
public:
/** \brief Set path of model file and configuration file, and the configuration of runtime
*
* \param[in] model_file Path of model file, e.g picodet/model.pdmodel
* \param[in] params_file Path of parameter file, e.g picodet/model.pdiparams, if the model format is ONNX, this parameter will be ignored
* \param[in] config_file Path of configuration file for deployment, e.g picodet/infer_cfg.yml
* \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
* \param[in] model_format Model format of the loaded model, default is Paddle format
*/
PicoDet(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE)
: PPDetBase(model_file, params_file, config_file, custom_option,
model_format) {
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT,
Backend::PDINFER, Backend::LITE};
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
initialized = Initialize();
}
virtual std::string ModelName() const { return "PicoDet"; }
};
class FASTDEPLOY_DECL PPYOLOE : public PPDetBase {
public:
/** \brief Set path of model file and configuration file, and the configuration of runtime
*
* \param[in] model_file Path of model file, e.g ppyoloe/model.pdmodel
* \param[in] params_file Path of parameter file, e.g picodet/model.pdiparams, if the model format is ONNX, this parameter will be ignored
* \param[in] config_file Path of configuration file for deployment, e.g picodet/infer_cfg.yml
* \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
* \param[in] model_format Model format of the loaded model, default is Paddle format
*/
PPYOLOE(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE)
: PPDetBase(model_file, params_file, config_file, custom_option,
model_format) {
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT,
Backend::PDINFER, Backend::LITE};
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
initialized = Initialize();
}
virtual std::string ModelName() const { return "PPYOLOE"; }
};
class FASTDEPLOY_DECL PPYOLO : public PPDetBase {
public:
/** \brief Set path of model file and configuration file, and the configuration of runtime
*
* \param[in] model_file Path of model file, e.g ppyolo/model.pdmodel
* \param[in] params_file Path of parameter file, e.g ppyolo/model.pdiparams, if the model format is ONNX, this parameter will be ignored
* \param[in] config_file Path of configuration file for deployment, e.g picodet/infer_cfg.yml
* \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
* \param[in] model_format Model format of the loaded model, default is Paddle format
*/
PPYOLO(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE)
: PPDetBase(model_file, params_file, config_file, custom_option,
model_format) {
valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::LITE};
valid_gpu_backends = {Backend::PDINFER};
initialized = Initialize();
}
virtual std::string ModelName() const { return "PaddleDetection/PP-YOLO"; }
};
class FASTDEPLOY_DECL YOLOv3 : public PPDetBase {
public:
YOLOv3(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE)
: PPDetBase(model_file, params_file, config_file, custom_option,
model_format) {
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER,
Backend::LITE};
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
initialized = Initialize();
}
virtual std::string ModelName() const { return "PaddleDetection/YOLOv3"; }
};
class FASTDEPLOY_DECL PaddleYOLOX : public PPDetBase {
public:
PaddleYOLOX(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE)
: PPDetBase(model_file, params_file, config_file, custom_option,
model_format) {
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER,
Backend::LITE};
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
initialized = Initialize();
}
virtual std::string ModelName() const { return "PaddleDetection/YOLOX"; }
};
class FASTDEPLOY_DECL FasterRCNN : public PPDetBase {
public:
FasterRCNN(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE)
: PPDetBase(model_file, params_file, config_file, custom_option,
model_format) {
valid_cpu_backends = {Backend::PDINFER, Backend::LITE};
valid_gpu_backends = {Backend::PDINFER};
initialized = Initialize();
}
virtual std::string ModelName() const { return "PaddleDetection/FasterRCNN"; }
};
class FASTDEPLOY_DECL MaskRCNN : public PPDetBase {
public:
MaskRCNN(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE)
: PPDetBase(model_file, params_file, config_file, custom_option,
model_format) {
valid_cpu_backends = {Backend::PDINFER, Backend::LITE};
valid_gpu_backends = {Backend::PDINFER};
initialized = Initialize();
}
virtual std::string ModelName() const { return "PaddleDetection/MaskRCNN"; }
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -1,66 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/detection/ppdet/picodet.h"
#include "yaml-cpp/yaml.h"
namespace fastdeploy {
namespace vision {
namespace detection {
PicoDet::PicoDet(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) {
config_file_ = config_file;
valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::ORT, Backend::LITE};
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
background_label = -1;
keep_top_k = 100;
nms_eta = 1;
nms_threshold = 0.6;
nms_top_k = 1000;
normalized = true;
score_threshold = 0.025;
CheckIfContainDecodeAndNMS();
initialized = Initialize();
}
bool PicoDet::CheckIfContainDecodeAndNMS() {
YAML::Node cfg;
try {
cfg = YAML::LoadFile(config_file_);
} catch (YAML::BadFile& e) {
FDERROR << "Failed to load yaml file " << config_file_
<< ", maybe you should check this file." << std::endl;
return false;
}
if (cfg["arch"].as<std::string>() == "PicoDet") {
FDERROR << "The arch in config file is PicoDet, which means this model "
"doesn contain box decode and nms, please export model with "
"decode and nms."
<< std::endl;
return false;
}
return true;
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -1,36 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
namespace fastdeploy {
namespace vision {
namespace detection {
class FASTDEPLOY_DECL PicoDet : public PPYOLOE {
public:
PicoDet(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE);
// Only support picodet contains decode and nms
bool CheckIfContainDecodeAndNMS();
virtual std::string ModelName() const { return "PicoDet"; }
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,132 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/detection/ppdet/postprocessor.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace detection {
bool PaddleDetPostprocessor::ProcessMask(const FDTensor& tensor, std::vector<DetectionResult>* results) {
auto shape = tensor.Shape();
if (tensor.Dtype() != FDDataType::INT32) {
FDERROR << "The data type of out mask tensor should be INT32, but now it's " << tensor.Dtype() << std::endl;
return false;
}
int64_t out_mask_h = shape[1];
int64_t out_mask_w = shape[2];
int64_t out_mask_numel = shape[1] * shape[2];
const int32_t* data = reinterpret_cast<const int32_t*>(tensor.CpuData());
int index = 0;
for (int i = 0; i < results->size(); ++i) {
(*results)[i].contain_masks = true;
(*results)[i].masks.resize((*results)[i].boxes.size());
for (int j = 0; j < (*results)[i].boxes.size(); ++j) {
int x1 = static_cast<int>((*results)[i].boxes[j][0]);
int y1 = static_cast<int>((*results)[i].boxes[j][1]);
int x2 = static_cast<int>((*results)[i].boxes[j][2]);
int y2 = static_cast<int>((*results)[i].boxes[j][3]);
int keep_mask_h = y2 - y1;
int keep_mask_w = x2 - x1;
int keep_mask_numel = keep_mask_h * keep_mask_w;
(*results)[i].masks[j].Resize(keep_mask_numel);
(*results)[i].masks[j].shape = {keep_mask_h, keep_mask_w};
const int32_t* current_ptr = data + index * out_mask_numel;
int32_t* keep_mask_ptr = reinterpret_cast<int32_t*>((*results)[i].masks[j].Data());
for (int row = y1; row < y2; ++row) {
size_t keep_nbytes_in_col = keep_mask_w * sizeof(int32_t);
const int32_t* out_row_start_ptr = current_ptr + row * out_mask_w + x1;
int32_t* keep_row_start_ptr = keep_mask_ptr + (row - y1) * keep_mask_w;
std::memcpy(keep_row_start_ptr, out_row_start_ptr, keep_nbytes_in_col);
}
index += 1;
}
}
return true;
}
bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors, std::vector<DetectionResult>* results) {
if (tensors[0].shape[0] == 0) {
// No detected boxes
return true;
}
// Get number of boxes for each input image
std::vector<int> num_boxes(tensors[1].shape[0]);
int total_num_boxes = 0;
if (tensors[1].dtype == FDDataType::INT32) {
const int32_t* data = static_cast<const int32_t*>(tensors[1].CpuData());
for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
num_boxes[i] = static_cast<int>(data[i]);
total_num_boxes += num_boxes[i];
}
} else if (tensors[1].dtype == FDDataType::INT64) {
const int64_t* data = static_cast<const int64_t*>(tensors[1].CpuData());
for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
num_boxes[i] = static_cast<int>(data[i]);
}
}
// Special case for TensorRT, it has fixed output shape of NMS
// So there's invalid boxes in its' output boxes
int num_output_boxes = tensors[0].Shape()[0];
bool contain_invalid_boxes = false;
if (total_num_boxes != num_output_boxes) {
if (num_output_boxes % num_boxes.size() == 0) {
contain_invalid_boxes = true;
} else {
FDERROR << "Cannot handle the output data for this model, unexpected situation." << std::endl;
return false;
}
}
// Get boxes for each input image
results->resize(num_boxes.size());
const float* box_data = static_cast<const float*>(tensors[0].CpuData());
int offset = 0;
for (size_t i = 0; i < num_boxes.size(); ++i) {
const float* ptr = box_data + offset;
(*results)[i].Reserve(num_boxes[i]);
for (size_t j = 0; j < num_boxes[i]; ++j) {
(*results)[i].label_ids.push_back(static_cast<int32_t>(round(ptr[j * 6])));
(*results)[i].scores.push_back(ptr[j * 6 + 1]);
(*results)[i].boxes.emplace_back(std::array<float, 4>({ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]}));
}
if (contain_invalid_boxes) {
offset += (num_output_boxes * 6 / num_boxes.size());
} else {
offset += (num_boxes[i] * 6);
}
}
// Only detection
if (tensors.size() <= 2) {
return true;
}
if (tensors[2].Shape()[0] != num_output_boxes) {
FDERROR << "The first dimension of output mask tensor:" << tensors[2].Shape()[0] << " is not equal to the first dimension of output boxes tensor:" << num_output_boxes << "." << std::endl;
return false;
}
// process for maskrcnn
return ProcessMask(tensors[2], results);
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -13,26 +13,30 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "fastdeploy/vision/detection/ppdet/rcnn.h" #include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
namespace detection { namespace detection {
/*! @brief Postprocessor object for PaddleDet serials model.
class FASTDEPLOY_DECL MaskRCNN : public FasterRCNN { */
class FASTDEPLOY_DECL PaddleDetPostprocessor {
public: public:
MaskRCNN(const std::string& model_file, const std::string& params_file, PaddleDetPostprocessor() = default;
const std::string& config_file, /** \brief Process the result of runtime and fill to ClassifyResult structure
const RuntimeOption& custom_option = RuntimeOption(), *
const ModelFormat& model_format = ModelFormat::PADDLE); * \param[in] tensors The inference result from runtime
* \param[in] result The output result of detection
virtual std::string ModelName() const { return "PaddleDetection/MaskRCNN"; } * \return true if the postprocess successed, otherwise false
*/
virtual bool Postprocess(std::vector<FDTensor>& infer_result, bool Run(const std::vector<FDTensor>& tensors,
DetectionResult* result); std::vector<DetectionResult>* result);
private:
protected: // Process mask tensor for MaskRCNN
MaskRCNN() {} bool ProcessMask(const FDTensor& tensor,
std::vector<DetectionResult>* results);
}; };
} // namespace detection } // namespace detection

View File

@@ -15,94 +15,94 @@
namespace fastdeploy { namespace fastdeploy {
void BindPPDet(pybind11::module& m) { void BindPPDet(pybind11::module& m) {
pybind11::class_<vision::detection::PPYOLOE, FastDeployModel>(m, "PPYOLOE") pybind11::class_<vision::detection::PaddleDetPreprocessor>(
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption, m, "PaddleDetPreprocessor")
ModelFormat>()) .def(pybind11::init<std::string>())
.def("predict", .def("run", [](vision::detection::PaddleDetPreprocessor& self, std::vector<pybind11::array>& im_list) {
[](vision::detection::PPYOLOE& self, pybind11::array& data) { std::vector<vision::FDMat> images;
auto mat = PyArrayToCvMat(data); for (size_t i = 0; i < im_list.size(); ++i) {
vision::DetectionResult res; images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
self.Predict(&mat, &res); }
return res; std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in PaddleDetPreprocessor.')");
}
for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].StopSharing();
}
return outputs;
}); });
pybind11::class_<vision::detection::PPYOLO, FastDeployModel>(m, "PPYOLO") pybind11::class_<vision::detection::PaddleDetPostprocessor>(
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption, m, "PaddleDetPostprocessor")
ModelFormat>()) .def(pybind11::init<>())
.def("predict", .def("run", [](vision::detection::PaddleDetPostprocessor& self, std::vector<FDTensor>& inputs) {
[](vision::detection::PPYOLO& self, pybind11::array& data) { std::vector<vision::DetectionResult> results;
auto mat = PyArrayToCvMat(data); if (!self.Run(inputs, &results)) {
vision::DetectionResult res; pybind11::eval("raise Exception('Failed to postprocess the runtime result in PaddleDetPostprocessor.')");
self.Predict(&mat, &res); }
return res; return results;
})
.def("run", [](vision::detection::PaddleDetPostprocessor& self, std::vector<pybind11::array>& input_array) {
std::vector<vision::DetectionResult> results;
std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
if (!self.Run(inputs, &results)) {
pybind11::eval("raise Exception('Failed to postprocess the runtime result in PaddleDetPostprocessor.')");
}
return results;
}); });
pybind11::class_<vision::detection::PPYOLOv2, FastDeployModel>(m, "PPYOLOv2") pybind11::class_<vision::detection::PPDetBase, FastDeployModel>(m, "PPDetBase")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption, .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>()) ModelFormat>())
.def("predict", .def("predict",
[](vision::detection::PPYOLOv2& self, pybind11::array& data) { [](vision::detection::PPDetBase& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data); auto mat = PyArrayToCvMat(data);
vision::DetectionResult res; vision::DetectionResult res;
self.Predict(&mat, &res); self.Predict(&mat, &res);
return res; return res;
}); })
.def("batch_predict",
[](vision::detection::PPDetBase& self, std::vector<pybind11::array>& data) {
std::vector<cv::Mat> images;
for (size_t i = 0; i < data.size(); ++i) {
images.push_back(PyArrayToCvMat(data[i]));
}
std::vector<vision::DetectionResult> results;
self.BatchPredict(images, &results);
return results;
})
.def_property_readonly("preprocessor", &vision::detection::PPDetBase::GetPreprocessor)
.def_property_readonly("postprocessor", &vision::detection::PPDetBase::GetPostprocessor);
pybind11::class_<vision::detection::PicoDet, FastDeployModel>(m, "PicoDet")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>())
.def("predict",
[](vision::detection::PicoDet& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res);
return res;
});
pybind11::class_<vision::detection::PaddleYOLOX, FastDeployModel>( pybind11::class_<vision::detection::PPYOLO, vision::detection::PPDetBase>(m, "PPYOLO")
m, "PaddleYOLOX")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption, .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>()) ModelFormat>());
.def("predict",
[](vision::detection::PaddleYOLOX& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res);
return res;
});
pybind11::class_<vision::detection::FasterRCNN, FastDeployModel>(m, pybind11::class_<vision::detection::PPYOLOE, vision::detection::PPDetBase>(m, "PPYOLOE")
"FasterRCNN")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption, .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>()) ModelFormat>());
.def("predict",
[](vision::detection::FasterRCNN& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res);
return res;
});
pybind11::class_<vision::detection::YOLOv3, FastDeployModel>(m, "YOLOv3") pybind11::class_<vision::detection::PicoDet, vision::detection::PPDetBase>(m, "PicoDet")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption, .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>()) ModelFormat>());
.def("predict",
[](vision::detection::YOLOv3& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res);
return res;
});
pybind11::class_<vision::detection::MaskRCNN, FastDeployModel>(m, "MaskRCNN") pybind11::class_<vision::detection::PaddleYOLOX, vision::detection::PPDetBase>(m, "PaddleYOLOX")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption, .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>()) ModelFormat>());
.def("predict",
[](vision::detection::MaskRCNN& self, pybind11::array& data) { pybind11::class_<vision::detection::FasterRCNN, vision::detection::PPDetBase>(m, "FasterRCNN")
auto mat = PyArrayToCvMat(data); .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
vision::DetectionResult res; ModelFormat>());
self.Predict(&mat, &res);
return res; pybind11::class_<vision::detection::YOLOv3, vision::detection::PPDetBase>(m, "YOLOv3")
}); .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>());
pybind11::class_<vision::detection::MaskRCNN, vision::detection::PPDetBase>(m, "MaskRCNN")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>());
} }
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -1,78 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/detection/ppdet/ppyolo.h"
namespace fastdeploy {
namespace vision {
namespace detection {
PPYOLO::PPYOLO(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) {
config_file_ = config_file;
valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::LITE};
valid_gpu_backends = {Backend::PDINFER};
has_nms_ = true;
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
bool PPYOLO::Initialize() {
if (!BuildPreprocessPipelineFromConfig()) {
FDERROR << "Failed to build preprocess pipeline from configuration file."
<< std::endl;
return false;
}
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
bool PPYOLO::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
int origin_w = mat->Width();
int origin_h = mat->Height();
for (size_t i = 0; i < processors_.size(); ++i) {
if (!(*(processors_[i].get()))(mat)) {
FDERROR << "Failed to process image data in " << processors_[i]->Name()
<< "." << std::endl;
return false;
}
}
outputs->resize(3);
(*outputs)[0].Allocate({1, 2}, FDDataType::FP32, "im_shape");
(*outputs)[2].Allocate({1, 2}, FDDataType::FP32, "scale_factor");
float* ptr0 = static_cast<float*>((*outputs)[0].MutableData());
ptr0[0] = mat->Height();
ptr0[1] = mat->Width();
float* ptr2 = static_cast<float*>((*outputs)[2].MutableData());
ptr2[0] = mat->Height() * 1.0 / origin_h;
ptr2[1] = mat->Width() * 1.0 / origin_w;
(*outputs)[1].name = "image";
mat->ShareWithTensor(&((*outputs)[1]));
// reshape to [1, c, h, w]
(*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1);
return true;
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -1,52 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
namespace fastdeploy {
namespace vision {
namespace detection {
class FASTDEPLOY_DECL PPYOLO : public PPYOLOE {
public:
PPYOLO(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE);
virtual std::string ModelName() const { return "PaddleDetection/PPYOLO"; }
virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
virtual bool Initialize();
protected:
PPYOLO() {}
};
class FASTDEPLOY_DECL PPYOLOv2 : public PPYOLO {
public:
PPYOLOv2(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE)
: PPYOLO(model_file, params_file, config_file, custom_option,
model_format) {}
virtual std::string ModelName() const { return "PaddleDetection/PPYOLOv2"; }
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -1,274 +0,0 @@
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
#include "fastdeploy/vision/utils/utils.h"
#include "yaml-cpp/yaml.h"
#ifdef ENABLE_PADDLE_FRONTEND
#include "paddle2onnx/converter.h"
#endif
namespace fastdeploy {
namespace vision {
namespace detection {
PPYOLOE::PPYOLOE(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) {
config_file_ = config_file;
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER, Backend::LITE};
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
void PPYOLOE::GetNmsInfo() {
#ifdef ENABLE_PADDLE_FRONTEND
if (runtime_option.model_format == ModelFormat::PADDLE) {
std::string contents;
if (!ReadBinaryFromFile(runtime_option.model_file, &contents)) {
return;
}
auto reader = paddle2onnx::PaddleReader(contents.c_str(), contents.size());
if (reader.has_nms) {
has_nms_ = true;
background_label = reader.nms_params.background_label;
keep_top_k = reader.nms_params.keep_top_k;
nms_eta = reader.nms_params.nms_eta;
nms_threshold = reader.nms_params.nms_threshold;
score_threshold = reader.nms_params.score_threshold;
nms_top_k = reader.nms_params.nms_top_k;
normalized = reader.nms_params.normalized;
}
}
#endif
}
bool PPYOLOE::Initialize() {
if (!BuildPreprocessPipelineFromConfig()) {
FDERROR << "Failed to build preprocess pipeline from configuration file."
<< std::endl;
return false;
}
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
reused_input_tensors_.resize(2);
return true;
}
bool PPYOLOE::BuildPreprocessPipelineFromConfig() {
processors_.clear();
YAML::Node cfg;
try {
cfg = YAML::LoadFile(config_file_);
} catch (YAML::BadFile& e) {
FDERROR << "Failed to load yaml file " << config_file_
<< ", maybe you should check this file." << std::endl;
return false;
}
processors_.push_back(std::make_shared<BGR2RGB>());
bool has_permute = false;
for (const auto& op : cfg["Preprocess"]) {
std::string op_name = op["type"].as<std::string>();
if (op_name == "NormalizeImage") {
auto mean = op["mean"].as<std::vector<float>>();
auto std = op["std"].as<std::vector<float>>();
bool is_scale = true;
if (op["is_scale"]) {
is_scale = op["is_scale"].as<bool>();
}
std::string norm_type = "mean_std";
if (op["norm_type"]) {
norm_type = op["norm_type"].as<std::string>();
}
if (norm_type != "mean_std") {
std::fill(mean.begin(), mean.end(), 0.0);
std::fill(std.begin(), std.end(), 1.0);
}
processors_.push_back(std::make_shared<Normalize>(mean, std, is_scale));
} else if (op_name == "Resize") {
bool keep_ratio = op["keep_ratio"].as<bool>();
auto target_size = op["target_size"].as<std::vector<int>>();
int interp = op["interp"].as<int>();
FDASSERT(target_size.size() == 2,
"Require size of target_size be 2, but now it's %lu.",
target_size.size());
if (!keep_ratio) {
int width = target_size[1];
int height = target_size[0];
processors_.push_back(
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
} else {
int min_target_size = std::min(target_size[0], target_size[1]);
int max_target_size = std::max(target_size[0], target_size[1]);
std::vector<int> max_size;
if (max_target_size > 0) {
max_size.push_back(max_target_size);
max_size.push_back(max_target_size);
}
processors_.push_back(std::make_shared<ResizeByShort>(
min_target_size, interp, true, max_size));
}
} else if (op_name == "Permute") {
// Do nothing, do permute as the last operation
has_permute = true;
continue;
// processors_.push_back(std::make_shared<HWC2CHW>());
} else if (op_name == "Pad") {
auto size = op["size"].as<std::vector<int>>();
auto value = op["fill_value"].as<std::vector<float>>();
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(
std::make_shared<PadToSize>(size[1], size[0], value));
} else if (op_name == "PadStride") {
auto stride = op["stride"].as<int>();
processors_.push_back(
std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
} else {
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
<< std::endl;
return false;
}
}
if (has_permute) {
// permute = cast<float> + HWC2CHW
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(std::make_shared<HWC2CHW>());
} else {
processors_.push_back(std::make_shared<HWC2CHW>());
}
// Fusion will improve performance
FuseTransforms(&processors_);
return true;
}
bool PPYOLOE::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
int origin_w = mat->Width();
int origin_h = mat->Height();
for (size_t i = 0; i < processors_.size(); ++i) {
if (!(*(processors_[i].get()))(mat)) {
FDERROR << "Failed to process image data in " << processors_[i]->Name()
<< "." << std::endl;
return false;
}
}
outputs->resize(2);
(*outputs)[0].name = InputInfoOfRuntime(0).name;
mat->ShareWithTensor(&((*outputs)[0]));
// reshape to [1, c, h, w]
(*outputs)[0].shape.insert((*outputs)[0].shape.begin(), 1);
(*outputs)[1].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(1).name);
float* ptr = static_cast<float*>((*outputs)[1].MutableData());
ptr[0] = mat->Height() * 1.0 / origin_h;
ptr[1] = mat->Width() * 1.0 / origin_w;
return true;
}
bool PPYOLOE::Postprocess(std::vector<FDTensor>& infer_result,
DetectionResult* result) {
FDASSERT(infer_result[1].shape[0] == 1,
"Only support batch = 1 in FastDeploy now.");
has_nms_ = true;
if (!has_nms_) {
int boxes_index = 0;
int scores_index = 1;
if (infer_result[0].shape[1] == infer_result[1].shape[2]) {
boxes_index = 0;
scores_index = 1;
} else if (infer_result[0].shape[2] == infer_result[1].shape[1]) {
boxes_index = 1;
scores_index = 0;
} else {
FDERROR << "The shape of boxes and scores should be [batch, boxes_num, "
"4], [batch, classes_num, boxes_num]"
<< std::endl;
return false;
}
backend::MultiClassNMS nms;
nms.background_label = background_label;
nms.keep_top_k = keep_top_k;
nms.nms_eta = nms_eta;
nms.nms_threshold = nms_threshold;
nms.score_threshold = score_threshold;
nms.nms_top_k = nms_top_k;
nms.normalized = normalized;
nms.Compute(static_cast<float*>(infer_result[boxes_index].Data()),
static_cast<float*>(infer_result[scores_index].Data()),
infer_result[boxes_index].shape,
infer_result[scores_index].shape);
if (nms.out_num_rois_data[0] > 0) {
result->Reserve(nms.out_num_rois_data[0]);
}
for (size_t i = 0; i < nms.out_num_rois_data[0]; ++i) {
result->label_ids.push_back(nms.out_box_data[i * 6]);
result->scores.push_back(nms.out_box_data[i * 6 + 1]);
result->boxes.emplace_back(std::array<float, 4>{
nms.out_box_data[i * 6 + 2], nms.out_box_data[i * 6 + 3],
nms.out_box_data[i * 6 + 4], nms.out_box_data[i * 6 + 5]});
}
} else {
std::vector<int> num_boxes(infer_result[1].shape[0]);
if (infer_result[1].dtype == FDDataType::INT32) {
int32_t* data = static_cast<int32_t*>(infer_result[1].Data());
for (size_t i = 0; i < infer_result[1].shape[0]; ++i) {
num_boxes[i] = static_cast<int>(data[i]);
}
} else if (infer_result[1].dtype == FDDataType::INT64) {
int64_t* data = static_cast<int64_t*>(infer_result[1].Data());
for (size_t i = 0; i < infer_result[1].shape[0]; ++i) {
num_boxes[i] = static_cast<int>(data[i]);
}
}
// Only support batch = 1 now
result->Reserve(num_boxes[0]);
float* box_data = static_cast<float*>(infer_result[0].Data());
for (size_t i = 0; i < num_boxes[0]; ++i) {
result->label_ids.push_back(box_data[i * 6]);
result->scores.push_back(box_data[i * 6 + 1]);
result->boxes.emplace_back(
std::array<float, 4>{box_data[i * 6 + 2], box_data[i * 6 + 3],
box_data[i * 6 + 4], box_data[i * 6 + 5]});
}
}
return true;
}
bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result) {
Mat mat(*im);
if (!Preprocess(&mat, &reused_input_tensors_)) {
FDERROR << "Failed to preprocess input data while using model:"
<< ModelName() << "." << std::endl;
return false;
}
if (!Infer()) {
FDERROR << "Failed to inference while using model:" << ModelName() << "."
<< std::endl;
return false;
}
if (!Postprocess(reused_output_tensors_, result)) {
FDERROR << "Failed to postprocess while using model:" << ModelName() << "."
<< std::endl;
return false;
}
return true;
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,201 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/detection/ppdet/preprocessor.h"
#include "fastdeploy/function/concat.h"
#include "fastdeploy/function/pad.h"
#include "yaml-cpp/yaml.h"
namespace fastdeploy {
namespace vision {
namespace detection {
PaddleDetPreprocessor::PaddleDetPreprocessor(const std::string& config_file) {
FDASSERT(BuildPreprocessPipelineFromConfig(config_file), "Failed to create PaddleDetPreprocessor.");
initialized_ = true;
}
bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig(const std::string& config_file) {
processors_.clear();
YAML::Node cfg;
try {
cfg = YAML::LoadFile(config_file);
} catch (YAML::BadFile& e) {
FDERROR << "Failed to load yaml file " << config_file
<< ", maybe you should check this file." << std::endl;
return false;
}
processors_.push_back(std::make_shared<BGR2RGB>());
bool has_permute = false;
for (const auto& op : cfg["Preprocess"]) {
std::string op_name = op["type"].as<std::string>();
if (op_name == "NormalizeImage") {
auto mean = op["mean"].as<std::vector<float>>();
auto std = op["std"].as<std::vector<float>>();
bool is_scale = true;
if (op["is_scale"]) {
is_scale = op["is_scale"].as<bool>();
}
std::string norm_type = "mean_std";
if (op["norm_type"]) {
norm_type = op["norm_type"].as<std::string>();
}
if (norm_type != "mean_std") {
std::fill(mean.begin(), mean.end(), 0.0);
std::fill(std.begin(), std.end(), 1.0);
}
processors_.push_back(std::make_shared<Normalize>(mean, std, is_scale));
} else if (op_name == "Resize") {
bool keep_ratio = op["keep_ratio"].as<bool>();
auto target_size = op["target_size"].as<std::vector<int>>();
int interp = op["interp"].as<int>();
FDASSERT(target_size.size() == 2,
"Require size of target_size be 2, but now it's %lu.",
target_size.size());
if (!keep_ratio) {
int width = target_size[1];
int height = target_size[0];
processors_.push_back(
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
} else {
int min_target_size = std::min(target_size[0], target_size[1]);
int max_target_size = std::max(target_size[0], target_size[1]);
std::vector<int> max_size;
if (max_target_size > 0) {
max_size.push_back(max_target_size);
max_size.push_back(max_target_size);
}
processors_.push_back(std::make_shared<ResizeByShort>(
min_target_size, interp, true, max_size));
}
} else if (op_name == "Permute") {
// Do nothing, do permute as the last operation
has_permute = true;
continue;
// processors_.push_back(std::make_shared<HWC2CHW>());
} else if (op_name == "Pad") {
auto size = op["size"].as<std::vector<int>>();
auto value = op["fill_value"].as<std::vector<float>>();
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(
std::make_shared<PadToSize>(size[1], size[0], value));
} else if (op_name == "PadStride") {
auto stride = op["stride"].as<int>();
processors_.push_back(
std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
} else {
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
<< std::endl;
return false;
}
}
if (has_permute) {
// permute = cast<float> + HWC2CHW
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(std::make_shared<HWC2CHW>());
} else {
processors_.push_back(std::make_shared<HWC2CHW>());
}
// Fusion will improve performance
FuseTransforms(&processors_);
return true;
}
bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
return false;
}
if (images->size() == 0) {
FDERROR << "The size of input images should be greater than 0." << std::endl;
return false;
}
// There are 3 outputs, image, scale_factor, im_shape
// But im_shape is not used for all the PaddleDetection models
// So preprocessor will output the 3 FDTensors, and how to use `im_shape`
// is decided by the model itself
outputs->resize(3);
int batch = static_cast<int>(images->size());
// Allocate memory for scale_factor
(*outputs)[1].Resize({batch, 2}, FDDataType::FP32);
// Allocate memory for im_shape
(*outputs)[2].Resize({batch, 2}, FDDataType::FP32);
// Record the max size for a batch of input image
// All the tensor will pad to the max size to compose a batched tensor
std::vector<int> max_hw({-1, -1});
float* scale_factor_ptr = reinterpret_cast<float*>((*outputs)[1].MutableData());
float* im_shape_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
for (size_t i = 0; i < images->size(); ++i) {
int origin_w = (*images)[i].Width();
int origin_h = (*images)[i].Height();
scale_factor_ptr[2 * i] = 1.0;
scale_factor_ptr[2 * i + 1] = 1.0;
for (size_t j = 0; j < processors_.size(); ++j) {
if (!(*(processors_[j].get()))(&((*images)[i]))) {
FDERROR << "Failed to processs image:" << i << " in " << processors_[i]->Name() << "." << std::endl;
return false;
}
if (processors_[j]->Name().find("Resize") != std::string::npos) {
scale_factor_ptr[2 * i] = (*images)[i].Height() * 1.0 / origin_h;
scale_factor_ptr[2 * i + 1] = (*images)[i].Width() * 1.0 / origin_w;
}
}
if ((*images)[i].Height() > max_hw[0]) {
max_hw[0] = (*images)[i].Height();
}
if ((*images)[i].Width() > max_hw[1]) {
max_hw[1] = (*images)[i].Width();
}
im_shape_ptr[2 * i] = max_hw[0];
im_shape_ptr[2 * i + 1] = max_hw[1];
}
// Concat all the preprocessed data to a batch tensor
std::vector<FDTensor> im_tensors(images->size());
for (size_t i = 0; i < images->size(); ++i) {
if ((*images)[i].Height() < max_hw[0] || (*images)[i].Width() < max_hw[1]) {
// if the size of image less than max_hw, pad to max_hw
FDTensor tensor;
(*images)[i].ShareWithTensor(&tensor);
function::Pad(tensor, &(im_tensors[i]), {0, 0, max_hw[0] - (*images)[i].Height(), max_hw[1] - (*images)[i].Width()}, 0);
} else {
// No need pad
(*images)[i].ShareWithTensor(&(im_tensors[i]));
}
// Reshape to 1xCxHxW
im_tensors[i].ExpandDim(0);
}
if (im_tensors.size() == 1) {
// If there's only 1 input, no need to concat
// skip memory copy
(*outputs)[0] = std::move(im_tensors[0]);
} else {
// Else concat the im tensor for each input image
// compose a batched input tensor
function::Concat(im_tensors, &((*outputs)[0]), 0);
}
return true;
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,50 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace detection {
/*! @brief Preprocessor object for PaddleDet serials model.
*/
class FASTDEPLOY_DECL PaddleDetPreprocessor {
public:
PaddleDetPreprocessor() = default;
/** \brief Create a preprocessor instance for PaddleDet serials model
*
* \param[in] config_file Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
*/
explicit PaddleDetPreprocessor(const std::string& config_file);
/** \brief Process the input image and prepare input tensors for runtime
*
* \param[in] images The input image data list, all the elements are returned by cv::imread()
* \param[in] outputs The output tensors which will feed in runtime, include image, scale_factor, im_shape
* \return true if the preprocess successed, otherwise false
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
private:
bool BuildPreprocessPipelineFromConfig(const std::string& config_file);
std::vector<std::shared_ptr<Processor>> processors_;
bool initialized_ = false;
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -1,84 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/detection/ppdet/rcnn.h"
namespace fastdeploy {
namespace vision {
namespace detection {
FasterRCNN::FasterRCNN(const std::string& model_file,
const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) {
config_file_ = config_file;
valid_cpu_backends = {Backend::PDINFER, Backend::LITE};
valid_gpu_backends = {Backend::PDINFER};
has_nms_ = true;
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
bool FasterRCNN::Initialize() {
if (!BuildPreprocessPipelineFromConfig()) {
FDERROR << "Failed to build preprocess pipeline from configuration file."
<< std::endl;
return false;
}
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
bool FasterRCNN::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
int origin_w = mat->Width();
int origin_h = mat->Height();
float scale[2] = {1.0, 1.0};
for (size_t i = 0; i < processors_.size(); ++i) {
if (!(*(processors_[i].get()))(mat)) {
FDERROR << "Failed to process image data in " << processors_[i]->Name()
<< "." << std::endl;
return false;
}
if (processors_[i]->Name().find("Resize") != std::string::npos) {
scale[0] = mat->Height() * 1.0 / origin_h;
scale[1] = mat->Width() * 1.0 / origin_w;
}
}
outputs->resize(3);
(*outputs)[0].Allocate({1, 2}, FDDataType::FP32, "im_shape");
(*outputs)[2].Allocate({1, 2}, FDDataType::FP32, "scale_factor");
float* ptr0 = static_cast<float*>((*outputs)[0].MutableData());
ptr0[0] = mat->Height();
ptr0[1] = mat->Width();
float* ptr2 = static_cast<float*>((*outputs)[2].MutableData());
ptr2[0] = scale[0];
ptr2[1] = scale[1];
(*outputs)[1].name = "image";
mat->ShareWithTensor(&((*outputs)[1]));
// reshape to [1, c, h, w]
(*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1);
return true;
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -1,39 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
namespace fastdeploy {
namespace vision {
namespace detection {
class FASTDEPLOY_DECL FasterRCNN : public PPYOLOE {
public:
FasterRCNN(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE);
virtual std::string ModelName() const { return "PaddleDetection/FasterRCNN"; }
virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
virtual bool Initialize();
protected:
FasterRCNN() {}
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -1,64 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/detection/ppdet/yolov3.h"
namespace fastdeploy {
namespace vision {
namespace detection {
YOLOv3::YOLOv3(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) {
config_file_ = config_file;
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER, Backend::LITE};
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
bool YOLOv3::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
int origin_w = mat->Width();
int origin_h = mat->Height();
for (size_t i = 0; i < processors_.size(); ++i) {
if (!(*(processors_[i].get()))(mat)) {
FDERROR << "Failed to process image data in " << processors_[i]->Name()
<< "." << std::endl;
return false;
}
}
outputs->resize(3);
(*outputs)[0].Allocate({1, 2}, FDDataType::FP32, "im_shape");
(*outputs)[2].Allocate({1, 2}, FDDataType::FP32, "scale_factor");
float* ptr0 = static_cast<float*>((*outputs)[0].MutableData());
ptr0[0] = mat->Height();
ptr0[1] = mat->Width();
float* ptr2 = static_cast<float*>((*outputs)[2].MutableData());
ptr2[0] = mat->Height() * 1.0 / origin_h;
ptr2[1] = mat->Width() * 1.0 / origin_w;
(*outputs)[1].name = "image";
mat->ShareWithTensor(&((*outputs)[1]));
// reshape to [1, c, h, w]
(*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1);
return true;
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -1,35 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
namespace fastdeploy {
namespace vision {
namespace detection {
class FASTDEPLOY_DECL YOLOv3 : public PPYOLOE {
public:
YOLOv3(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE);
virtual std::string ModelName() const { return "PaddleDetection/YOLOv3"; }
virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -1,74 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/detection/ppdet/yolox.h"
namespace fastdeploy {
namespace vision {
namespace detection {
PaddleYOLOX::PaddleYOLOX(const std::string& model_file,
const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) {
config_file_ = config_file;
valid_cpu_backends = {Backend::ORT, Backend::PDINFER, Backend::LITE};
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
background_label = -1;
keep_top_k = 1000;
nms_eta = 1;
nms_threshold = 0.65;
nms_top_k = 10000;
normalized = true;
score_threshold = 0.001;
initialized = Initialize();
}
bool PaddleYOLOX::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
int origin_w = mat->Width();
int origin_h = mat->Height();
float scale[2] = {1.0, 1.0};
for (size_t i = 0; i < processors_.size(); ++i) {
if (!(*(processors_[i].get()))(mat)) {
FDERROR << "Failed to process image data in " << processors_[i]->Name()
<< "." << std::endl;
return false;
}
if (processors_[i]->Name().find("Resize") != std::string::npos) {
scale[0] = mat->Height() * 1.0 / origin_h;
scale[1] = mat->Width() * 1.0 / origin_w;
}
}
outputs->resize(2);
(*outputs)[0].name = InputInfoOfRuntime(0).name;
mat->ShareWithTensor(&((*outputs)[0]));
// reshape to [1, c, h, w]
(*outputs)[0].shape.insert((*outputs)[0].shape.begin(), 1);
(*outputs)[1].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(1).name);
float* ptr = static_cast<float*>((*outputs)[1].MutableData());
ptr[0] = scale[0];
ptr[1] = scale[1];
return true;
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -1,35 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
namespace fastdeploy {
namespace vision {
namespace detection {
class FASTDEPLOY_DECL PaddleYOLOX : public PPYOLOE {
public:
PaddleYOLOX(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE);
virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
virtual std::string ModelName() const { return "PaddleDetection/YOLOX"; }
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -23,5 +23,4 @@ from .contrib.yolov5lite import YOLOv5Lite
from .contrib.yolov6 import YOLOv6 from .contrib.yolov6 import YOLOv6
from .contrib.yolov7end2end_trt import YOLOv7End2EndTRT from .contrib.yolov7end2end_trt import YOLOv7End2EndTRT
from .contrib.yolov7end2end_ort import YOLOv7End2EndORT from .contrib.yolov7end2end_ort import YOLOv7End2EndORT
from .ppdet import PPYOLOE, PPYOLO, PPYOLOv2, PaddleYOLOX, PicoDet, FasterRCNN, YOLOv3, MaskRCNN from .ppdet import *
from .rknpu2 import RKPicoDet

View File

@@ -19,6 +19,40 @@ from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C from .... import c_lib_wrap as C
class PaddleDetPreprocessor:
def __init__(self, config_file):
"""Create a preprocessor for PaddleDetection Model from configuration file
:param config_file: (str)Path of configuration file, e.g ppyoloe/infer_cfg.yml
"""
self._preprocessor = C.vision.detection.PaddleDetPreprocessor(
config_file)
def run(self, input_ims):
"""Preprocess input images for PaddleDetection Model
:param: input_ims: (list of numpy.ndarray)The input image
:return: list of FDTensor, include image, scale_factor, im_shape
"""
return self._preprocessor.run(input_ims)
class PaddleDetPostprocessor:
def __init__(self):
"""Create a postprocessor for PaddleDetection Model
"""
self._postprocessor = C.vision.detection.PaddleDetPostprocessor()
def run(self, runtime_results):
"""Postprocess the runtime results for PaddleDetection Model
:param: runtime_results: (list of FDTensor)The output FDTensor results from runtime
:return: list of ClassifyResult(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results)
class PPYOLOE(FastDeployModel): class PPYOLOE(FastDeployModel):
def __init__(self, def __init__(self,
model_file, model_file,
@@ -52,6 +86,31 @@ class PPYOLOE(FastDeployModel):
assert im is not None, "The input image data is None." assert im is not None, "The input image data is None."
return self._model.predict(im) return self._model.predict(im)
def batch_predict(self, images):
"""Detect a batch of input image list
:param im: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
:return list of DetectionResult
"""
return self._model.batch_predict(images)
@property
def preprocessor(self):
"""Get PaddleDetPreprocessor object of the loaded model
:return PaddleDetPreprocessor
"""
return self._model.preprocessor
@property
def postprocessor(self):
"""Get PaddleDetPostprocessor object of the loaded model
:return PaddleDetPostprocessor
"""
return self._model.postprocessor
class PPYOLO(PPYOLOE): class PPYOLO(PPYOLOE):
def __init__(self, def __init__(self,
@@ -77,31 +136,6 @@ class PPYOLO(PPYOLOE):
assert self.initialized, "PPYOLO model initialize failed." assert self.initialized, "PPYOLO model initialize failed."
class PPYOLOv2(PPYOLOE):
def __init__(self,
model_file,
params_file,
config_file,
runtime_option=None,
model_format=ModelFormat.PADDLE):
"""Load a PPYOLOv2 model exported by PaddleDetection.
:param model_file: (str)Path of model file, e.g ppyolov2/model.pdmodel
:param params_file: (str)Path of parameters file, e.g ppyolov2/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
:param config_file: (str)Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model
"""
super(PPYOLOE, self).__init__(runtime_option)
assert model_format == ModelFormat.PADDLE, "PPYOLOv2 model only support model format of ModelFormat.Paddle now."
self._model = C.vision.detection.PPYOLOv2(
model_file, params_file, config_file, self._runtime_option,
model_format)
assert self.initialized, "PPYOLOv2 model initialize failed."
class PaddleYOLOX(PPYOLOE): class PaddleYOLOX(PPYOLOE):
def __init__(self, def __init__(self,
model_file, model_file,
@@ -202,7 +236,7 @@ class YOLOv3(PPYOLOE):
assert self.initialized, "YOLOv3 model initialize failed." assert self.initialized, "YOLOv3 model initialize failed."
class MaskRCNN(FastDeployModel): class MaskRCNN(PPYOLOE):
def __init__(self, def __init__(self,
model_file, model_file,
params_file, params_file,
@@ -211,14 +245,14 @@ class MaskRCNN(FastDeployModel):
model_format=ModelFormat.PADDLE): model_format=ModelFormat.PADDLE):
"""Load a MaskRCNN model exported by PaddleDetection. """Load a MaskRCNN model exported by PaddleDetection.
:param model_file: (str)Path of model file, e.g maskrcnn/model.pdmodel :param model_file: (str)Path of model file, e.g fasterrcnn/model.pdmodel
:param params_file: (str)Path of parameters file, e.g maskrcnn/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string :param params_file: (str)Path of parameters file, e.g fasterrcnn/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
:param config_file: (str)Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml :param config_file: (str)Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model :param model_format: (fastdeploy.ModelForamt)Model format of the loaded model
""" """
super(MaskRCNN, self).__init__(runtime_option) super(PPYOLOE, self).__init__(runtime_option)
assert model_format == ModelFormat.PADDLE, "MaskRCNN model only support model format of ModelFormat.Paddle now." assert model_format == ModelFormat.PADDLE, "MaskRCNN model only support model format of ModelFormat.Paddle now."
self._model = C.vision.detection.MaskRCNN( self._model = C.vision.detection.MaskRCNN(
@@ -226,6 +260,12 @@ class MaskRCNN(FastDeployModel):
model_format) model_format)
assert self.initialized, "MaskRCNN model initialize failed." assert self.initialized, "MaskRCNN model initialize failed."
def predict(self, input_image): def batch_predict(self, images):
assert input_image is not None, "The input image data is None." """Detect a batch of input image list, batch_predict is not supported for maskrcnn now.
return self._model.predict(input_image)
:param im: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
:return list of DetectionResult
"""
raise Exception(
"batch_predict is not supported for MaskRCNN model now.")

View File

@@ -0,0 +1,70 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
print(fd.__path__)
import cv2
import os
import pickle
import numpy as np
import runtime_config as rc
def test_detection_faster_rcnn():
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/faster_rcnn_r50_vd_fpn_2x_coco.tgz"
input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
result_url = "https://bj.bcebos.com/fastdeploy/tests/data/faster_rcnn_baseline.pkl"
fd.download_and_decompress(model_url, "resources")
fd.download(input_url1, "resources")
fd.download(result_url, "resources")
model_path = "resources/faster_rcnn_r50_vd_fpn_2x_coco"
model_file = os.path.join(model_path, "model.pdmodel")
params_file = os.path.join(model_path, "model.pdiparams")
config_file = os.path.join(model_path, "infer_cfg.yml")
rc.test_option.use_paddle_backend()
model = fd.vision.detection.FasterRCNN(
model_file, params_file, config_file, runtime_option=rc.test_option)
# compare diff
im1 = cv2.imread("./resources/000000014439.jpg")
for i in range(2):
result = model.predict(im1)
with open("resources/faster_rcnn_baseline.pkl", "rb") as f:
boxes, scores, label_ids = pickle.load(f)
pred_boxes = np.array(result.boxes)
pred_scores = np.array(result.scores)
pred_label_ids = np.array(result.label_ids)
diff_boxes = np.fabs(boxes - pred_boxes)
diff_scores = np.fabs(scores - pred_scores)
diff_label_ids = np.fabs(label_ids - pred_label_ids)
print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())
score_threshold = 0.0
assert diff_boxes[scores > score_threshold].max(
) < 1e-04, "There's diff in boxes."
assert diff_scores[scores > score_threshold].max(
) < 1e-04, "There's diff in scores."
assert diff_label_ids[scores > score_threshold].max(
) < 1e-04, "There's diff in label_ids."
# result = model.predict(im1)
# with open("faster_rcnn_baseline.pkl", "wb") as f:
# pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)
if __name__ == "__main__":
test_detection_faster_rcnn()

68
tests/models/test_mask_rcnn.py Executable file
View File

@@ -0,0 +1,68 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
import pickle
import numpy as np
import runtime_config as rc
def test_detection_mask_rcnn():
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/mask_rcnn_r50_1x_coco.tgz"
input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
result_url = "https://bj.bcebos.com/fastdeploy/tests/data/mask_rcnn_baseline.pkl"
fd.download_and_decompress(model_url, "resources")
fd.download(input_url1, "resources")
fd.download(result_url, "resources")
model_path = "resources/mask_rcnn_r50_1x_coco"
model_file = os.path.join(model_path, "model.pdmodel")
params_file = os.path.join(model_path, "model.pdiparams")
config_file = os.path.join(model_path, "infer_cfg.yml")
model = fd.vision.detection.MaskRCNN(
model_file, params_file, config_file, runtime_option=rc.test_option)
# compare diff
im1 = cv2.imread("./resources/000000014439.jpg")
for i in range(2):
result = model.predict(im1)
with open("resources/mask_rcnn_baseline.pkl", "rb") as f:
boxes, scores, label_ids = pickle.load(f)
pred_boxes = np.array(result.boxes)
pred_scores = np.array(result.scores)
pred_label_ids = np.array(result.label_ids)
diff_boxes = np.fabs(boxes - pred_boxes)
diff_scores = np.fabs(scores - pred_scores)
diff_label_ids = np.fabs(label_ids - pred_label_ids)
print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())
score_threshold = 0.0
assert diff_boxes[scores > score_threshold].max(
) < 1e-04, "There's diff in boxes."
assert diff_scores[scores > score_threshold].max(
) < 1e-04, "There's diff in scores."
assert diff_label_ids[scores > score_threshold].max(
) < 1e-04, "There's diff in label_ids."
# result = model.predict(im1)
# with open("mask_rcnn_baseline.pkl", "wb") as f:
# pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)
if __name__ == "__main__":
test_detection_mask_rcnn()

69
tests/models/test_picodet.py Executable file
View File

@@ -0,0 +1,69 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
import pickle
import numpy as np
import runtime_config as rc
def test_detection_picodet():
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/picodet_l_320_coco_lcnet.tgz"
input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
result_url = "https://bj.bcebos.com/fastdeploy/tests/data/picodet_baseline.pkl"
fd.download_and_decompress(model_url, "resources")
fd.download(input_url1, "resources")
fd.download(result_url, "resources")
model_path = "resources/picodet_l_320_coco_lcnet"
model_file = os.path.join(model_path, "model.pdmodel")
params_file = os.path.join(model_path, "model.pdiparams")
config_file = os.path.join(model_path, "infer_cfg.yml")
rc.test_option.use_paddle_backend()
model = fd.vision.detection.PicoDet(
model_file, params_file, config_file, runtime_option=rc.test_option)
# compare diff
im1 = cv2.imread("./resources/000000014439.jpg")
for i in range(2):
result = model.predict(im1)
with open("resources/picodet_baseline.pkl", "rb") as f:
boxes, scores, label_ids = pickle.load(f)
pred_boxes = np.array(result.boxes)
pred_scores = np.array(result.scores)
pred_label_ids = np.array(result.label_ids)
diff_boxes = np.fabs(boxes - pred_boxes)
diff_scores = np.fabs(scores - pred_scores)
diff_label_ids = np.fabs(label_ids - pred_label_ids)
print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())
score_threshold = 0.0
assert diff_boxes[scores > score_threshold].max(
) < 1e-02, "There's diff in boxes."
assert diff_scores[scores > score_threshold].max(
) < 1e-04, "There's diff in scores."
assert diff_label_ids[scores > score_threshold].max(
) < 1e-04, "There's diff in label_ids."
# result = model.predict(im1)
# with open("picodet_baseline.pkl", "wb") as f:
# pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)
if __name__ == "__main__":
test_detection_picodet()

70
tests/models/test_pp_yolox.py Executable file
View File

@@ -0,0 +1,70 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
print(fd.__path__)
import cv2
import os
import pickle
import numpy as np
import runtime_config as rc
def test_detection_yolox():
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/yolox_s_300e_coco.tgz"
input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
result_url = "https://bj.bcebos.com/fastdeploy/tests/data/ppyolox_baseline.pkl"
fd.download_and_decompress(model_url, "resources")
fd.download(input_url1, "resources")
fd.download(result_url, "resources")
model_path = "resources/yolox_s_300e_coco"
model_file = os.path.join(model_path, "model.pdmodel")
params_file = os.path.join(model_path, "model.pdiparams")
config_file = os.path.join(model_path, "infer_cfg.yml")
rc.test_option.use_ort_backend()
model = fd.vision.detection.PaddleYOLOX(
model_file, params_file, config_file, runtime_option=rc.test_option)
# compare diff
im1 = cv2.imread("./resources/000000014439.jpg")
for i in range(2):
result = model.predict(im1)
with open("resources/ppyolox_baseline.pkl", "rb") as f:
boxes, scores, label_ids = pickle.load(f)
pred_boxes = np.array(result.boxes)
pred_scores = np.array(result.scores)
pred_label_ids = np.array(result.label_ids)
diff_boxes = np.fabs(boxes - pred_boxes)
diff_scores = np.fabs(scores - pred_scores)
diff_label_ids = np.fabs(label_ids - pred_label_ids)
print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())
score_threshold = 0.1
assert diff_boxes[scores > score_threshold].max(
) < 1e-04, "There's diff in boxes."
assert diff_scores[scores > score_threshold].max(
) < 1e-04, "There's diff in scores."
assert diff_label_ids[scores > score_threshold].max(
) < 1e-04, "There's diff in label_ids."
# result = model.predict(im1)
# with open("ppyolox_baseline.pkl", "wb") as f:
# pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)
if __name__ == "__main__":
test_detection_yolox()

69
tests/models/test_ppyolo.py Executable file
View File

@@ -0,0 +1,69 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
import pickle
import numpy as np
import runtime_config as rc
def test_detection_ppyolo():
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/ppyolov2_r101vd_dcn_365e_coco.tgz"
input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
result_url = "https://bj.bcebos.com/fastdeploy/tests/data/ppyolo_baseline.pkl"
fd.download_and_decompress(model_url, "resources")
fd.download(input_url1, "resources")
fd.download(result_url, "resources")
model_path = "resources/ppyolov2_r101vd_dcn_365e_coco"
model_file = os.path.join(model_path, "model.pdmodel")
params_file = os.path.join(model_path, "model.pdiparams")
config_file = os.path.join(model_path, "infer_cfg.yml")
rc.test_option.use_paddle_backend()
model = fd.vision.detection.PPYOLO(
model_file, params_file, config_file, runtime_option=rc.test_option)
# compare diff
im1 = cv2.imread("./resources/000000014439.jpg")
for i in range(2):
result = model.predict(im1)
with open("resources/ppyolo_baseline.pkl", "rb") as f:
boxes, scores, label_ids = pickle.load(f)
pred_boxes = np.array(result.boxes)
pred_scores = np.array(result.scores)
pred_label_ids = np.array(result.label_ids)
diff_boxes = np.fabs(boxes - pred_boxes)
diff_scores = np.fabs(scores - pred_scores)
diff_label_ids = np.fabs(label_ids - pred_label_ids)
print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())
score_threshold = 0.0
assert diff_boxes[scores > score_threshold].max(
) < 1e-04, "There's diff in boxes."
assert diff_scores[scores > score_threshold].max(
) < 1e-04, "There's diff in scores."
assert diff_label_ids[scores > score_threshold].max(
) < 1e-04, "There's diff in label_ids."
# result = model.predict(im1)
# with open("ppyolo_baseline.pkl", "wb") as f:
# pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)
if __name__ == "__main__":
test_detection_ppyolo()

68
tests/models/test_ppyoloe.py Executable file
View File

@@ -0,0 +1,68 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
import pickle
import numpy as np
import runtime_config as rc
def test_detection_ppyoloe():
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/ppyoloe_crn_l_300e_coco.tgz"
input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
result_url = "https://bj.bcebos.com/fastdeploy/tests/data/ppyoloe_baseline.pkl"
fd.download_and_decompress(model_url, "resources")
fd.download(input_url1, "resources")
fd.download(result_url, "resources")
model_path = "resources/ppyoloe_crn_l_300e_coco"
model_file = os.path.join(model_path, "model.pdmodel")
params_file = os.path.join(model_path, "model.pdiparams")
config_file = os.path.join(model_path, "infer_cfg.yml")
rc.test_option.use_ort_backend()
model = fd.vision.detection.PPYOLOE(
model_file, params_file, config_file, runtime_option=rc.test_option)
# compare diff
im1 = cv2.imread("./resources/000000014439.jpg")
for i in range(2):
result = model.predict(im1)
with open("resources/ppyoloe_baseline.pkl", "rb") as f:
boxes, scores, label_ids = pickle.load(f)
pred_boxes = np.array(result.boxes)
pred_scores = np.array(result.scores)
pred_label_ids = np.array(result.label_ids)
diff_boxes = np.fabs(boxes - pred_boxes)
diff_scores = np.fabs(scores - pred_scores)
diff_label_ids = np.fabs(label_ids - pred_label_ids)
print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())
score_threshold = 0.0
assert diff_boxes[scores > score_threshold].max(
) < 1e-04, "There's diff in boxes."
assert diff_scores[scores > score_threshold].max(
) < 1e-04, "There's diff in scores."
assert diff_label_ids[scores > score_threshold].max(
) < 1e-04, "There's diff in label_ids."
# with open("ppyoloe_baseline.pkl", "wb") as f:
# pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)
if __name__ == "__main__":
test_detection_ppyoloe()

69
tests/models/test_yolov3.py Executable file
View File

@@ -0,0 +1,69 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
import pickle
import numpy as np
import runtime_config as rc
def test_detection_yolov3():
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/yolov3_darknet53_270e_coco.tgz"
input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
result_url = "https://bj.bcebos.com/fastdeploy/tests/data/yolov3_baseline.pkl"
fd.download_and_decompress(model_url, "resources")
fd.download(input_url1, "resources")
fd.download(result_url, "resources")
model_path = "resources/yolov3_darknet53_270e_coco"
model_file = os.path.join(model_path, "model.pdmodel")
params_file = os.path.join(model_path, "model.pdiparams")
config_file = os.path.join(model_path, "infer_cfg.yml")
rc.test_option.use_ort_backend()
model = fd.vision.detection.YOLOv3(
model_file, params_file, config_file, runtime_option=rc.test_option)
# compare diff
im1 = cv2.imread("./resources/000000014439.jpg")
for i in range(2):
result = model.predict(im1)
with open("resources/yolov3_baseline.pkl", "rb") as f:
boxes, scores, label_ids = pickle.load(f)
pred_boxes = np.array(result.boxes)
pred_scores = np.array(result.scores)
pred_label_ids = np.array(result.label_ids)
diff_boxes = np.fabs(boxes - pred_boxes)
diff_scores = np.fabs(scores - pred_scores)
diff_label_ids = np.fabs(label_ids - pred_label_ids)
print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())
score_threshold = 0.1
assert diff_boxes[scores > score_threshold].max(
) < 1e-04, "There's diff in boxes."
assert diff_scores[scores > score_threshold].max(
) < 1e-04, "There's diff in scores."
assert diff_label_ids[scores > score_threshold].max(
) < 1e-04, "There's diff in label_ids."
# result = model.predict(im1)
# with open("yolov3_baseline.pkl", "wb") as f:
# pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)
if __name__ == "__main__":
test_detection_yolov3()