diff --git a/examples/vision/detection/paddledetection/python/infer_mask_rcnn.py b/examples/vision/detection/paddledetection/python/infer_mask_rcnn.py
index 7d3c2e5e4..8b6c49aa7 100644
--- a/examples/vision/detection/paddledetection/python/infer_mask_rcnn.py
+++ b/examples/vision/detection/paddledetection/python/infer_mask_rcnn.py
@@ -66,4 +66,3 @@ print(result)
 vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
 cv2.imwrite("visualized_result.jpg", vis_im)
 print("Visualized result save in ./visualized_result.jpg")
-print(runtime_option)
diff --git a/fastdeploy/pipeline/pptinypose/pipeline.cc b/fastdeploy/pipeline/pptinypose/pipeline.cc
index 6a2f2d4ba..8c9d40a84 100644
--- a/fastdeploy/pipeline/pptinypose/pipeline.cc
+++ b/fastdeploy/pipeline/pptinypose/pipeline.cc
@@ -17,7 +17,7 @@ namespace fastdeploy {
 namespace pipeline {
 PPTinyPose::PPTinyPose(
-    fastdeploy::vision::detection::PPYOLOE* det_model,
+    fastdeploy::vision::detection::PicoDet* det_model,
     fastdeploy::vision::keypointdetection::PPTinyPose* pptinypose_model)
     : detector_(det_model), pptinypose_model_(pptinypose_model) {}
diff --git a/fastdeploy/pipeline/pptinypose/pipeline.h b/fastdeploy/pipeline/pptinypose/pipeline.h
index 90d6e21f0..0cdbee399 100644
--- a/fastdeploy/pipeline/pptinypose/pipeline.h
+++ b/fastdeploy/pipeline/pptinypose/pipeline.h
@@ -35,7 +35,7 @@ class FASTDEPLOY_DECL PPTinyPose {
   * \param[in] pptinypose_model Initialized pptinypose model object
   */
  PPTinyPose(
-      fastdeploy::vision::detection::PPYOLOE* det_model,
+      fastdeploy::vision::detection::PicoDet* det_model,
      fastdeploy::vision::keypointdetection::PPTinyPose* pptinypose_model);
  /** \brief Predict the keypoint detection result for an input image
@@ -52,7 +52,7 @@ class FASTDEPLOY_DECL PPTinyPose {
  float detection_model_score_threshold = 0;
 protected:
-  fastdeploy::vision::detection::PPYOLOE* detector_ = nullptr;
+  fastdeploy::vision::detection::PicoDet* detector_ = nullptr;
  fastdeploy::vision::keypointdetection::PPTinyPose* pptinypose_model_ =
      nullptr;
diff --git a/fastdeploy/pipeline/pptinypose/pptinyposepipeline_pybind.cc b/fastdeploy/pipeline/pptinypose/pptinyposepipeline_pybind.cc
index 8aee0474e..b020bb1b4 100644
--- a/fastdeploy/pipeline/pptinypose/pptinyposepipeline_pybind.cc
+++ b/fastdeploy/pipeline/pptinypose/pptinyposepipeline_pybind.cc
@@ -18,31 +18,8 @@ namespace fastdeploy {
 void BindPPTinyPosePipeline(pybind11::module& m) {
   pybind11::class_<pipeline::PPTinyPose>(m, "PPTinyPose")
-      // explicitly pybind more kinds of detection models here
-      .def(pybind11::init())
-      .def(pybind11::init())
-
-      .def(pybind11::init())
-
-      .def(pybind11::init())
-
-      .def(pybind11::init())
-
-      .def(pybind11::init())
-
-      .def(pybind11::init())
-
-      .def(pybind11::init())
-
       .def("predict",
            [](pipeline::PPTinyPose& self, pybind11::array& data) {
              auto mat = PyArrayToCvMat(data);
diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h
index 15cc1d009..2f8c70661 100755
--- a/fastdeploy/vision.h
+++ b/fastdeploy/vision.h
@@ -29,7 +29,6 @@
 #include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
 #include "fastdeploy/vision/detection/contrib/yolox.h"
 #include "fastdeploy/vision/detection/ppdet/model.h"
-#include "fastdeploy/vision/detection/contrib/rknpu2/model.h"
 #include "fastdeploy/vision/facedet/contrib/retinaface.h"
 #include "fastdeploy/vision/facedet/contrib/scrfd.h"
 #include "fastdeploy/vision/facedet/contrib/ultraface.h"
diff --git a/fastdeploy/vision/detection/contrib/rknpu2/model.h b/fastdeploy/vision/detection/contrib/rknpu2/model.h
deleted file mode 100644
index
f0f8616ee..000000000 --- a/fastdeploy/vision/detection/contrib/rknpu2/model.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h" diff --git a/fastdeploy/vision/detection/contrib/rknpu2/rkdet_pybind.cc b/fastdeploy/vision/detection/contrib/rknpu2/rkdet_pybind.cc deleted file mode 100644 index 6482ea675..000000000 --- a/fastdeploy/vision/detection/contrib/rknpu2/rkdet_pybind.cc +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "fastdeploy/pybind/main.h" - -namespace fastdeploy { -void BindRKDet(pybind11::module& m) { - pybind11::class_(m, "RKPicoDet") - .def(pybind11::init()) - .def("predict", - [](vision::detection::RKPicoDet& self, pybind11::array& data) { - auto mat = PyArrayToCvMat(data); - vision::DetectionResult res; - self.Predict(&mat, &res); - return res; - }); -} -} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.cc b/fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.cc deleted file mode 100644 index 926214d86..000000000 --- a/fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.cc +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h" -#include "yaml-cpp/yaml.h" -namespace fastdeploy { -namespace vision { -namespace detection { - -RKPicoDet::RKPicoDet(const std::string& model_file, - const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option, - const ModelFormat& model_format) { - config_file_ = config_file; - valid_cpu_backends = {Backend::ORT}; - valid_rknpu_backends = {Backend::RKNPU2}; - if ((model_format == ModelFormat::RKNN) || - (model_format == ModelFormat::ONNX)) { - has_nms_ = false; - } - runtime_option = custom_option; - runtime_option.model_format = model_format; - runtime_option.model_file = model_file; - runtime_option.params_file = params_file; - - // NMS parameters come from RKPicoDet_s_nms - background_label = -1; - keep_top_k = 100; - nms_eta = 1; - nms_threshold = 0.5; - nms_top_k = 1000; - normalized = true; - score_threshold = 0.3; - initialized = Initialize(); -} - -bool RKPicoDet::Initialize() { - if (!BuildPreprocessPipelineFromConfig()) { - FDERROR << "Failed to build preprocess pipeline from configuration file." - << std::endl; - return false; - } - if (!InitRuntime()) { - FDERROR << "Failed to initialize fastdeploy backend." << std::endl; - return false; - } - return true; -} - -bool RKPicoDet::Preprocess(Mat* mat, std::vector* outputs) { - int origin_w = mat->Width(); - int origin_h = mat->Height(); - for (size_t i = 0; i < processors_.size(); ++i) { - if (!(*(processors_[i].get()))(mat)) { - FDERROR << "Failed to process image data in " << processors_[i]->Name() - << "." << std::endl; - return false; - } - } - - Cast::Run(mat, "float"); - - scale_factor.resize(2); - scale_factor[0] = mat->Height() * 1.0 / origin_h; - scale_factor[1] = mat->Width() * 1.0 / origin_w; - - outputs->resize(1); - (*outputs)[0].name = InputInfoOfRuntime(0).name; - mat->ShareWithTensor(&((*outputs)[0])); - // reshape to [1, c, h, w] - (*outputs)[0].shape.insert((*outputs)[0].shape.begin(), 1); - return true; -} - -bool RKPicoDet::BuildPreprocessPipelineFromConfig() { - processors_.clear(); - YAML::Node cfg; - try { - cfg = YAML::LoadFile(config_file_); - } catch (YAML::BadFile& e) { - FDERROR << "Failed to load yaml file " << config_file_ - << ", maybe you should check this file." 
<< std::endl; - return false; - } - - processors_.push_back(std::make_shared()); - - for (const auto& op : cfg["Preprocess"]) { - std::string op_name = op["type"].as(); - if (op_name == "NormalizeImage") { - continue; - } else if (op_name == "Resize") { - bool keep_ratio = op["keep_ratio"].as(); - auto target_size = op["target_size"].as>(); - int interp = op["interp"].as(); - FDASSERT(target_size.size() == 2, - "Require size of target_size be 2, but now it's %lu.", - target_size.size()); - if (!keep_ratio) { - int width = target_size[1]; - int height = target_size[0]; - processors_.push_back( - std::make_shared(width, height, -1.0, -1.0, interp, false)); - } else { - int min_target_size = std::min(target_size[0], target_size[1]); - int max_target_size = std::max(target_size[0], target_size[1]); - std::vector max_size; - if (max_target_size > 0) { - max_size.push_back(max_target_size); - max_size.push_back(max_target_size); - } - processors_.push_back(std::make_shared( - min_target_size, interp, true, max_size)); - } - } else if (op_name == "Permute") { - continue; - } else if (op_name == "Pad") { - auto size = op["size"].as>(); - auto value = op["fill_value"].as>(); - processors_.push_back(std::make_shared("float")); - processors_.push_back( - std::make_shared(size[1], size[0], value)); - } else if (op_name == "PadStride") { - auto stride = op["stride"].as(); - processors_.push_back( - std::make_shared(stride, std::vector(3, 0))); - } else { - FDERROR << "Unexcepted preprocess operator: " << op_name << "." - << std::endl; - return false; - } - } - return true; -} - -bool RKPicoDet::Postprocess(std::vector& infer_result, - DetectionResult* result) { - FDASSERT(infer_result[1].shape[0] == 1, - "Only support batch = 1 in FastDeploy now."); - - if (!has_nms_) { - int boxes_index = 0; - int scores_index = 1; - if (infer_result[0].shape[1] == infer_result[1].shape[2]) { - boxes_index = 0; - scores_index = 1; - } else if (infer_result[0].shape[2] == infer_result[1].shape[1]) { - boxes_index = 1; - scores_index = 0; - } else { - FDERROR << "The shape of boxes and scores should be [batch, boxes_num, " - "4], [batch, classes_num, boxes_num]" - << std::endl; - return false; - } - - backend::MultiClassNMS nms; - nms.background_label = background_label; - nms.keep_top_k = keep_top_k; - nms.nms_eta = nms_eta; - nms.nms_threshold = nms_threshold; - nms.score_threshold = score_threshold; - nms.nms_top_k = nms_top_k; - nms.normalized = normalized; - nms.Compute(static_cast(infer_result[boxes_index].Data()), - static_cast(infer_result[scores_index].Data()), - infer_result[boxes_index].shape, - infer_result[scores_index].shape); - if (nms.out_num_rois_data[0] > 0) { - result->Reserve(nms.out_num_rois_data[0]); - } - for (size_t i = 0; i < nms.out_num_rois_data[0]; ++i) { - result->label_ids.push_back(nms.out_box_data[i * 6]); - result->scores.push_back(nms.out_box_data[i * 6 + 1]); - result->boxes.emplace_back( - std::array{nms.out_box_data[i * 6 + 2] / scale_factor[1], - nms.out_box_data[i * 6 + 3] / scale_factor[0], - nms.out_box_data[i * 6 + 4] / scale_factor[1], - nms.out_box_data[i * 6 + 5] / scale_factor[0]}); - } - } else { - FDERROR << "Picodet in Backend::RKNPU2 don't support NMS" << std::endl; - } - return true; -} - -} // namespace detection -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h b/fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h deleted file mode 100644 index dbb48c16d..000000000 --- 
a/fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "fastdeploy/vision/detection/ppdet/ppyoloe.h" - -namespace fastdeploy { -namespace vision { -namespace detection { -class FASTDEPLOY_DECL RKPicoDet : public PPYOLOE { - public: - RKPicoDet(const std::string& model_file, - const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option = RuntimeOption(), - const ModelFormat& model_format = ModelFormat::RKNN); - - virtual std::string ModelName() const { return "RKPicoDet"; } - - protected: - /// Build the preprocess pipeline from the loaded model - virtual bool BuildPreprocessPipelineFromConfig(); - /// Preprocess an input image, and set the preprocessed results to `outputs` - virtual bool Preprocess(Mat* mat, std::vector* outputs); - - /// Postprocess the inferenced results, and set the final result to `result` - virtual bool Postprocess(std::vector& infer_result, - DetectionResult* result); - virtual bool Initialize(); - private: - std::vector scale_factor{1.0, 1.0}; -}; -} // namespace detection -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/detection_pybind.cc b/fastdeploy/vision/detection/detection_pybind.cc index f55bf68bf..b3a7a6ad9 100644 --- a/fastdeploy/vision/detection/detection_pybind.cc +++ b/fastdeploy/vision/detection/detection_pybind.cc @@ -27,8 +27,6 @@ void BindNanoDetPlus(pybind11::module& m); void BindPPDet(pybind11::module& m); void BindYOLOv7End2EndTRT(pybind11::module& m); void BindYOLOv7End2EndORT(pybind11::module& m); -void BindRKDet(pybind11::module& m); - void BindDetection(pybind11::module& m) { auto detection_module = @@ -44,6 +42,5 @@ void BindDetection(pybind11::module& m) { BindNanoDetPlus(detection_module); BindYOLOv7End2EndTRT(detection_module); BindYOLOv7End2EndORT(detection_module); - BindRKDet(detection_module); } } // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/base.cc b/fastdeploy/vision/detection/ppdet/base.cc new file mode 100755 index 000000000..1db42f158 --- /dev/null +++ b/fastdeploy/vision/detection/ppdet/base.cc @@ -0,0 +1,68 @@ +#include "fastdeploy/vision/detection/ppdet/base.h" +#include "fastdeploy/vision/utils/utils.h" +#include "yaml-cpp/yaml.h" + +namespace fastdeploy { +namespace vision { +namespace detection { + +PPDetBase::PPDetBase(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option, + const ModelFormat& model_format) : preprocessor_(config_file) { + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; +} + +bool PPDetBase::Initialize() { + if (!InitRuntime()) { + FDERROR << "Failed to initialize fastdeploy backend." 
        << std::endl;
+    return false;
+  }
+  return true;
+}
+
+bool PPDetBase::Predict(cv::Mat* im, DetectionResult* result) {
+  return Predict(*im, result);
+}
+
+bool PPDetBase::Predict(const cv::Mat& im, DetectionResult* result) {
+  std::vector<DetectionResult> results;
+  if (!BatchPredict({im}, &results)) {
+    return false;
+  }
+  *result = std::move(results[0]);
+  return true;
+}
+
+bool PPDetBase::BatchPredict(const std::vector<cv::Mat>& imgs, std::vector<DetectionResult>* results) {
+  std::vector<FDMat> fd_images = WrapMat(imgs);
+  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
+    FDERROR << "Failed to preprocess the input image." << std::endl;
+    return false;
+  }
+  reused_input_tensors_[0].name = "image";
+  reused_input_tensors_[1].name = "scale_factor";
+  reused_input_tensors_[2].name = "im_shape";
+  // Some models don't need im_shape as input
+  if (NumInputsOfRuntime() == 2) {
+    reused_input_tensors_.pop_back();
+  }
+
+  if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
+    FDERROR << "Failed to inference by runtime." << std::endl;
+    return false;
+  }
+
+  if (!postprocessor_.Run(reused_output_tensors_, results)) {
+    FDERROR << "Failed to postprocess the inference results by runtime." << std::endl;
+    return false;
+  }
+  return true;
+}
+
+}  // namespace detection
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/ppdet/ppyoloe.h b/fastdeploy/vision/detection/ppdet/base.h
similarity index 61%
rename from fastdeploy/vision/detection/ppdet/ppyoloe.h
rename to fastdeploy/vision/detection/ppdet/base.h
index fd2a71cb1..bffc477a5 100644
--- a/fastdeploy/vision/detection/ppdet/ppyoloe.h
+++ b/fastdeploy/vision/detection/ppdet/base.h
@@ -14,6 +14,8 @@
 #pragma once
 #include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/detection/ppdet/preprocessor.h"
+#include "fastdeploy/vision/detection/ppdet/postprocessor.h"
 #include "fastdeploy/vision/common/processors/transform.h"
 #include "fastdeploy/vision/common/result.h"
@@ -26,9 +28,9 @@ namespace vision {
 */
 namespace detection {
-/*! @brief PPYOLOE model object used when to load a PPYOLOE model exported by PaddleDetection
+/*! @brief Base model object used to load a model exported by PaddleDetection
 */
-class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel {
+class FASTDEPLOY_DECL PPDetBase : public FastDeployModel {
 public:
  /** \brief Set path of model file and configuration file, and the configuration of runtime
  *
@@ -38,49 +40,49 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel {
  * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
  * \param[in] model_format Model format of the loaded model, default is Paddle format
  */
-  PPYOLOE(const std::string& model_file, const std::string& params_file,
+  PPDetBase(const std::string& model_file, const std::string& params_file,
          const std::string& config_file,
          const RuntimeOption& custom_option = RuntimeOption(),
          const ModelFormat& model_format = ModelFormat::PADDLE);
  /// Get model's name
-  virtual std::string ModelName() const { return "PaddleDetection/PPYOLOE"; }
+  virtual std::string ModelName() const { return "PaddleDetection/BaseModel"; }
-  /** \brief Predict the detection result for an input image
+  /** \brief DEPRECATED Predict the detection result for an input image
  *
  * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
-   * \param[in] result The output detection result will be writen to this structure
+   * \param[in] result The output detection result
  * \return true if the prediction successed, otherwise false
  */
  virtual bool Predict(cv::Mat* im, DetectionResult* result);
+  /** \brief Predict the detection result for an input image
+   * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
+   * \param[in] result The output detection result
+   * \return true if the prediction succeeded, otherwise false
+   */
+  virtual bool Predict(const cv::Mat& im, DetectionResult* result);
+
+  /** \brief Predict the detection result for an input image list
+   * \param[in] imgs The input image list, each element comes from cv::imread() and is a 3-D array with layout HWC, BGR format
+   * \param[in] results The output detection result list
+   * \return true if the prediction succeeded, otherwise false
+   */
+  virtual bool BatchPredict(const std::vector<cv::Mat>& imgs,
+                            std::vector<DetectionResult>* results);
+
+  PaddleDetPreprocessor& GetPreprocessor() {
+    return preprocessor_;
+  }
+
+  PaddleDetPostprocessor& GetPostprocessor() {
+    return postprocessor_;
+  }
+
 protected:
-  PPYOLOE() {}
  virtual bool Initialize();
-  /// Build the preprocess pipeline from the loaded model
-  virtual bool BuildPreprocessPipelineFromConfig();
-  /// Preprocess an input image, and set the preprocessed results to `outputs`
-  virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
-
-  /// Postprocess the inferenced results, and set the final result to `result`
-  virtual bool Postprocess(std::vector<FDTensor>& infer_result,
-                           DetectionResult* result);
-
-  std::vector<std::shared_ptr<Processor>> processors_;
-  std::string config_file_;
-  // configuration for nms
-  int64_t background_label = -1;
-  int64_t keep_top_k = 300;
-  float nms_eta = 1.0;
-  float nms_threshold = 0.7;
-  float score_threshold = 0.01;
-  int64_t nms_top_k = 10000;
-  bool normalized = true;
-  bool has_nms_ = true;
-
-  // This function will used to check if this model contains multiclass_nms
-  // and get parameters from the operator
-  void GetNmsInfo();
+  PaddleDetPreprocessor preprocessor_;
+  PaddleDetPostprocessor postprocessor_;
 };
 }  // namespace detection
diff --git a/fastdeploy/vision/detection/ppdet/mask_rcnn.cc
b/fastdeploy/vision/detection/ppdet/mask_rcnn.cc deleted file mode 100644 index 7c656c669..000000000 --- a/fastdeploy/vision/detection/ppdet/mask_rcnn.cc +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "fastdeploy/vision/detection/ppdet/mask_rcnn.h" - -namespace fastdeploy { -namespace vision { -namespace detection { - -MaskRCNN::MaskRCNN(const std::string& model_file, - const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option, - const ModelFormat& model_format) { - config_file_ = config_file; - valid_cpu_backends = {Backend::PDINFER, Backend::LITE}; - valid_gpu_backends = {Backend::PDINFER}; - runtime_option = custom_option; - runtime_option.model_format = model_format; - runtime_option.model_file = model_file; - runtime_option.params_file = params_file; - initialized = Initialize(); -} - -bool MaskRCNN::Postprocess(std::vector& infer_result, - DetectionResult* result) { - // index 0: bbox_data [N, 6] float32 - // index 1: bbox_num [B=1] int32 - // index 2: mask_data [N, h, w] int32 - FDASSERT(infer_result[1].shape[0] == 1, - "Only support batch = 1 in FastDeploy now."); - FDASSERT(infer_result.size() == 3, - "The infer_result must contains 3 otuput Tensors, but found %lu", - infer_result.size()); - - FDTensor& box_tensor = infer_result[0]; - FDTensor& box_num_tensor = infer_result[1]; - FDTensor& mask_tensor = infer_result[2]; - - int box_num = 0; - if (box_num_tensor.dtype == FDDataType::INT32) { - box_num = *(static_cast(box_num_tensor.Data())); - } else if (box_num_tensor.dtype == FDDataType::INT64) { - box_num = *(static_cast(box_num_tensor.Data())); - } else { - FDASSERT(false, - "The output box_num of PaddleDetection/MaskRCNN model should be " - "type of int32/int64."); - } - if (box_num <= 0) { - return true; // no object detected. - } - result->Resize(box_num); - float* box_data = static_cast(box_tensor.Data()); - for (size_t i = 0; i < box_num; ++i) { - result->label_ids[i] = static_cast(box_data[i * 6]); - result->scores[i] = box_data[i * 6 + 1]; - result->boxes[i] = - std::array{box_data[i * 6 + 2], box_data[i * 6 + 3], - box_data[i * 6 + 4], box_data[i * 6 + 5]}; - } - result->contain_masks = true; - // TODO(qiuyanjun): Cast int64/int8 to int32. - FDASSERT(mask_tensor.dtype == FDDataType::INT32, - "The dtype of mask Tensor must be int32 now!"); - // In PaddleDetection/MaskRCNN, the mask_h and mask_w - // are already aligned with original input image. So, - // we need to crop it from output mask according to - // the detected bounding box. 
- // +-----------------------+ - // | x1,y1 | - // | +---------------+ | - // | | | | - // | | Crop | | - // | | | | - // | | | | - // | +---------------+ | - // | x2,y2 | - // +-----------------------+ - int64_t out_mask_h = mask_tensor.shape[1]; - int64_t out_mask_w = mask_tensor.shape[2]; - int64_t out_mask_numel = out_mask_h * out_mask_w; - int32_t* out_mask_data = static_cast(mask_tensor.Data()); - for (size_t i = 0; i < box_num; ++i) { - // crop instance mask according to box - int64_t x1 = static_cast(result->boxes[i][0]); - int64_t y1 = static_cast(result->boxes[i][1]); - int64_t x2 = static_cast(result->boxes[i][2]); - int64_t y2 = static_cast(result->boxes[i][3]); - int64_t keep_mask_h = y2 - y1; - int64_t keep_mask_w = x2 - x1; - int64_t keep_mask_numel = keep_mask_h * keep_mask_w; - result->masks[i].Resize(keep_mask_numel); // int32 - result->masks[i].shape = {keep_mask_h, keep_mask_w}; - int32_t* mask_start_ptr = out_mask_data + i * out_mask_numel; - int32_t* keep_mask_ptr = static_cast(result->masks[i].Data()); - for (size_t row = y1; row < y2; ++row) { - size_t keep_nbytes_in_col = keep_mask_w * sizeof(int32_t); - int32_t* out_row_start_ptr = mask_start_ptr + row * out_mask_w + x1; - int32_t* keep_row_start_ptr = keep_mask_ptr + (row - y1) * keep_mask_w; - std::memcpy(keep_row_start_ptr, out_row_start_ptr, keep_nbytes_in_col); - } - } - return true; -} - -} // namespace detection -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/model.h b/fastdeploy/vision/detection/ppdet/model.h index f8a40d64f..90a92f893 100644 --- a/fastdeploy/vision/detection/ppdet/model.h +++ b/fastdeploy/vision/detection/ppdet/model.h @@ -13,10 +13,152 @@ // limitations under the License. #pragma once -#include "fastdeploy/vision/detection/ppdet/mask_rcnn.h" -#include "fastdeploy/vision/detection/ppdet/picodet.h" -#include "fastdeploy/vision/detection/ppdet/ppyolo.h" -#include "fastdeploy/vision/detection/ppdet/ppyoloe.h" -#include "fastdeploy/vision/detection/ppdet/rcnn.h" -#include "fastdeploy/vision/detection/ppdet/yolov3.h" -#include "fastdeploy/vision/detection/ppdet/yolox.h" +#include "fastdeploy/vision/detection/ppdet/base.h" + +namespace fastdeploy { +namespace vision { +namespace detection { + +class FASTDEPLOY_DECL PicoDet : public PPDetBase { + public: + /** \brief Set path of model file and configuration file, and the configuration of runtime + * + * \param[in] model_file Path of model file, e.g picodet/model.pdmodel + * \param[in] params_file Path of parameter file, e.g picodet/model.pdiparams, if the model format is ONNX, this parameter will be ignored + * \param[in] config_file Path of configuration file for deployment, e.g picodet/infer_cfg.yml + * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends` + * \param[in] model_format Model format of the loaded model, default is Paddle format + */ + PicoDet(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::PADDLE) + : PPDetBase(model_file, params_file, config_file, custom_option, + model_format) { + valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, + Backend::PDINFER, Backend::LITE}; + valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; + initialized = Initialize(); + } + + virtual std::string ModelName() const { return "PicoDet"; } +}; + +class 
FASTDEPLOY_DECL PPYOLOE : public PPDetBase { + public: + /** \brief Set path of model file and configuration file, and the configuration of runtime + * + * \param[in] model_file Path of model file, e.g ppyoloe/model.pdmodel + * \param[in] params_file Path of parameter file, e.g picodet/model.pdiparams, if the model format is ONNX, this parameter will be ignored + * \param[in] config_file Path of configuration file for deployment, e.g picodet/infer_cfg.yml + * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends` + * \param[in] model_format Model format of the loaded model, default is Paddle format + */ + PPYOLOE(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::PADDLE) + : PPDetBase(model_file, params_file, config_file, custom_option, + model_format) { + valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, + Backend::PDINFER, Backend::LITE}; + valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; + initialized = Initialize(); + } + + virtual std::string ModelName() const { return "PPYOLOE"; } +}; + +class FASTDEPLOY_DECL PPYOLO : public PPDetBase { + public: + /** \brief Set path of model file and configuration file, and the configuration of runtime + * + * \param[in] model_file Path of model file, e.g ppyolo/model.pdmodel + * \param[in] params_file Path of parameter file, e.g ppyolo/model.pdiparams, if the model format is ONNX, this parameter will be ignored + * \param[in] config_file Path of configuration file for deployment, e.g picodet/infer_cfg.yml + * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends` + * \param[in] model_format Model format of the loaded model, default is Paddle format + */ + PPYOLO(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::PADDLE) + : PPDetBase(model_file, params_file, config_file, custom_option, + model_format) { + valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::LITE}; + valid_gpu_backends = {Backend::PDINFER}; + initialized = Initialize(); + } + + virtual std::string ModelName() const { return "PaddleDetection/PP-YOLO"; } +}; + +class FASTDEPLOY_DECL YOLOv3 : public PPDetBase { + public: + YOLOv3(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::PADDLE) + : PPDetBase(model_file, params_file, config_file, custom_option, + model_format) { + valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER, + Backend::LITE}; + valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; + initialized = Initialize(); + } + + virtual std::string ModelName() const { return "PaddleDetection/YOLOv3"; } +}; + +class FASTDEPLOY_DECL PaddleYOLOX : public PPDetBase { + public: + PaddleYOLOX(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::PADDLE) + : PPDetBase(model_file, params_file, config_file, custom_option, + model_format) { + valid_cpu_backends = 
{Backend::OPENVINO, Backend::ORT, Backend::PDINFER, + Backend::LITE}; + valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; + initialized = Initialize(); + } + + virtual std::string ModelName() const { return "PaddleDetection/YOLOX"; } +}; + +class FASTDEPLOY_DECL FasterRCNN : public PPDetBase { + public: + FasterRCNN(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::PADDLE) + : PPDetBase(model_file, params_file, config_file, custom_option, + model_format) { + valid_cpu_backends = {Backend::PDINFER, Backend::LITE}; + valid_gpu_backends = {Backend::PDINFER}; + initialized = Initialize(); + } + + virtual std::string ModelName() const { return "PaddleDetection/FasterRCNN"; } +}; + +class FASTDEPLOY_DECL MaskRCNN : public PPDetBase { + public: + MaskRCNN(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::PADDLE) + : PPDetBase(model_file, params_file, config_file, custom_option, + model_format) { + valid_cpu_backends = {Backend::PDINFER, Backend::LITE}; + valid_gpu_backends = {Backend::PDINFER}; + initialized = Initialize(); + } + + virtual std::string ModelName() const { return "PaddleDetection/MaskRCNN"; } +}; + +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/picodet.cc b/fastdeploy/vision/detection/ppdet/picodet.cc deleted file mode 100644 index 9b67db4a7..000000000 --- a/fastdeploy/vision/detection/ppdet/picodet.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "fastdeploy/vision/detection/ppdet/picodet.h" -#include "yaml-cpp/yaml.h" - -namespace fastdeploy { -namespace vision { -namespace detection { - -PicoDet::PicoDet(const std::string& model_file, const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option, - const ModelFormat& model_format) { - config_file_ = config_file; - valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::ORT, Backend::LITE}; - valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; - runtime_option = custom_option; - runtime_option.model_format = model_format; - runtime_option.model_file = model_file; - runtime_option.params_file = params_file; - background_label = -1; - keep_top_k = 100; - nms_eta = 1; - nms_threshold = 0.6; - nms_top_k = 1000; - normalized = true; - score_threshold = 0.025; - CheckIfContainDecodeAndNMS(); - initialized = Initialize(); -} - -bool PicoDet::CheckIfContainDecodeAndNMS() { - YAML::Node cfg; - try { - cfg = YAML::LoadFile(config_file_); - } catch (YAML::BadFile& e) { - FDERROR << "Failed to load yaml file " << config_file_ - << ", maybe you should check this file." << std::endl; - return false; - } - - if (cfg["arch"].as() == "PicoDet") { - FDERROR << "The arch in config file is PicoDet, which means this model " - "doesn contain box decode and nms, please export model with " - "decode and nms." - << std::endl; - return false; - } - return true; -} - -} // namespace detection -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/picodet.h b/fastdeploy/vision/detection/ppdet/picodet.h deleted file mode 100644 index 5f85d2dd9..000000000 --- a/fastdeploy/vision/detection/ppdet/picodet.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "fastdeploy/vision/detection/ppdet/ppyoloe.h" - -namespace fastdeploy { -namespace vision { -namespace detection { - -class FASTDEPLOY_DECL PicoDet : public PPYOLOE { - public: - PicoDet(const std::string& model_file, const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option = RuntimeOption(), - const ModelFormat& model_format = ModelFormat::PADDLE); - - // Only support picodet contains decode and nms - bool CheckIfContainDecodeAndNMS(); - - virtual std::string ModelName() const { return "PicoDet"; } -}; -} // namespace detection -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/postprocessor.cc b/fastdeploy/vision/detection/ppdet/postprocessor.cc new file mode 100644 index 000000000..5e8312a7d --- /dev/null +++ b/fastdeploy/vision/detection/ppdet/postprocessor.cc @@ -0,0 +1,132 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/detection/ppdet/postprocessor.h"
+#include "fastdeploy/vision/utils/utils.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace detection {
+
+bool PaddleDetPostprocessor::ProcessMask(const FDTensor& tensor, std::vector<DetectionResult>* results) {
+  auto shape = tensor.Shape();
+  if (tensor.Dtype() != FDDataType::INT32) {
+    FDERROR << "The data type of out mask tensor should be INT32, but now it's " << tensor.Dtype() << std::endl;
+    return false;
+  }
+  int64_t out_mask_h = shape[1];
+  int64_t out_mask_w = shape[2];
+  int64_t out_mask_numel = shape[1] * shape[2];
+  const int32_t* data = reinterpret_cast<const int32_t*>(tensor.CpuData());
+  int index = 0;
+
+  for (int i = 0; i < results->size(); ++i) {
+    (*results)[i].contain_masks = true;
+    (*results)[i].masks.resize((*results)[i].boxes.size());
+    for (int j = 0; j < (*results)[i].boxes.size(); ++j) {
+      int x1 = static_cast<int>((*results)[i].boxes[j][0]);
+      int y1 = static_cast<int>((*results)[i].boxes[j][1]);
+      int x2 = static_cast<int>((*results)[i].boxes[j][2]);
+      int y2 = static_cast<int>((*results)[i].boxes[j][3]);
+      int keep_mask_h = y2 - y1;
+      int keep_mask_w = x2 - x1;
+      int keep_mask_numel = keep_mask_h * keep_mask_w;
+      (*results)[i].masks[j].Resize(keep_mask_numel);
+      (*results)[i].masks[j].shape = {keep_mask_h, keep_mask_w};
+      const int32_t* current_ptr = data + index * out_mask_numel;
+
+      int32_t* keep_mask_ptr = reinterpret_cast<int32_t*>((*results)[i].masks[j].Data());
+      for (int row = y1; row < y2; ++row) {
+        size_t keep_nbytes_in_col = keep_mask_w * sizeof(int32_t);
+        const int32_t* out_row_start_ptr = current_ptr + row * out_mask_w + x1;
+        int32_t* keep_row_start_ptr = keep_mask_ptr + (row - y1) * keep_mask_w;
+        std::memcpy(keep_row_start_ptr, out_row_start_ptr, keep_nbytes_in_col);
+      }
+      index += 1;
+    }
+  }
+  return true;
+}
+
+bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors, std::vector<DetectionResult>* results) {
+  if (tensors[0].shape[0] == 0) {
+    // No detected boxes
+    return true;
+  }
+
+  // Get number of boxes for each input image
+  std::vector<int> num_boxes(tensors[1].shape[0]);
+  int total_num_boxes = 0;
+  if (tensors[1].dtype == FDDataType::INT32) {
+    const int32_t* data = static_cast<const int32_t*>(tensors[1].CpuData());
+    for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
+      num_boxes[i] = static_cast<int>(data[i]);
+      total_num_boxes += num_boxes[i];
+    }
+  } else if (tensors[1].dtype == FDDataType::INT64) {
+    const int64_t* data = static_cast<const int64_t*>(tensors[1].CpuData());
+    for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
+      num_boxes[i] = static_cast<int>(data[i]);
+    }
+  }
+
+  // Special case for TensorRT, which has a fixed output shape for NMS,
+  // so there are invalid boxes in its output
+  int num_output_boxes = tensors[0].Shape()[0];
+  bool contain_invalid_boxes = false;
+  if (total_num_boxes != num_output_boxes) {
+    if (num_output_boxes % num_boxes.size() == 0) {
+      contain_invalid_boxes = true;
+    } else {
+      FDERROR << "Cannot handle the output data for this model, unexpected situation."
+              << std::endl;
+      return false;
+    }
+  }
+
+  // Get boxes for each input image
+  results->resize(num_boxes.size());
+  const float* box_data = static_cast<const float*>(tensors[0].CpuData());
+  int offset = 0;
+  for (size_t i = 0; i < num_boxes.size(); ++i) {
+    const float* ptr = box_data + offset;
+    (*results)[i].Reserve(num_boxes[i]);
+    for (size_t j = 0; j < num_boxes[i]; ++j) {
+      (*results)[i].label_ids.push_back(static_cast<int32_t>(round(ptr[j * 6])));
+      (*results)[i].scores.push_back(ptr[j * 6 + 1]);
+      (*results)[i].boxes.emplace_back(std::array<float, 4>({ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]}));
+    }
+    if (contain_invalid_boxes) {
+      offset += (num_output_boxes * 6 / num_boxes.size());
+    } else {
+      offset += (num_boxes[i] * 6);
+    }
+  }
+
+  // Only detection
+  if (tensors.size() <= 2) {
+    return true;
+  }
+
+  if (tensors[2].Shape()[0] != num_output_boxes) {
+    FDERROR << "The first dimension of output mask tensor:" << tensors[2].Shape()[0] << " is not equal to the first dimension of output boxes tensor:" << num_output_boxes << "." << std::endl;
+    return false;
+  }
+
+  // process for maskrcnn
+  return ProcessMask(tensors[2], results);
+}
+
+}  // namespace detection
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/ppdet/mask_rcnn.h b/fastdeploy/vision/detection/ppdet/postprocessor.h
similarity index 50%
rename from fastdeploy/vision/detection/ppdet/mask_rcnn.h
rename to fastdeploy/vision/detection/ppdet/postprocessor.h
index a24d1f42c..54be1bfd9 100644
--- a/fastdeploy/vision/detection/ppdet/mask_rcnn.h
+++ b/fastdeploy/vision/detection/ppdet/postprocessor.h
@@ -13,26 +13,30 @@
 // limitations under the License.
 #pragma once
-#include "fastdeploy/vision/detection/ppdet/rcnn.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
 namespace fastdeploy {
 namespace vision {
+
 namespace detection {
-
-class FASTDEPLOY_DECL MaskRCNN : public FasterRCNN {
+/*! @brief Postprocessor object for PaddleDet series models.
+ */
+class FASTDEPLOY_DECL PaddleDetPostprocessor {
  public:
-  MaskRCNN(const std::string& model_file, const std::string& params_file,
-           const std::string& config_file,
-           const RuntimeOption& custom_option = RuntimeOption(),
-           const ModelFormat& model_format = ModelFormat::PADDLE);
-
-  virtual std::string ModelName() const { return "PaddleDetection/MaskRCNN"; }
-
-  virtual bool Postprocess(std::vector<FDTensor>& infer_result,
-                           DetectionResult* result);
-
- protected:
-  MaskRCNN() {}
+  PaddleDetPostprocessor() = default;
+  /** \brief Process the result of runtime and fill to DetectionResult structure
+   *
+   * \param[in] tensors The inference result from runtime
+   * \param[in] result The output result of detection
+   * \return true if the postprocess succeeded, otherwise false
+   */
+  bool Run(const std::vector<FDTensor>& tensors,
+           std::vector<DetectionResult>* result);
+ private:
+  // Process mask tensor for MaskRCNN
+  bool ProcessMask(const FDTensor& tensor,
+                   std::vector<DetectionResult>* results);
 };
 }  // namespace detection
diff --git a/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc b/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc
index 01a6a8ce1..252097608 100644
--- a/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc
+++ b/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc
@@ -15,94 +15,94 @@ namespace fastdeploy {
 void BindPPDet(pybind11::module& m) {
-  pybind11::class_(m, "PPYOLOE")
-      .def(pybind11::init())
-      .def("predict",
-           [](vision::detection::PPYOLOE& self, pybind11::array& data) {
-             auto mat = PyArrayToCvMat(data);
-             vision::DetectionResult res;
-             self.Predict(&mat, &res);
-             return res;
-           });
+  pybind11::class_(
+      m, "PaddleDetPreprocessor")
+      .def(pybind11::init())
+      .def("run", [](vision::detection::PaddleDetPreprocessor& self, std::vector& im_list) {
+        std::vector images;
+        for (size_t i = 0; i < im_list.size(); ++i) {
+          images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
+        }
+        std::vector outputs;
+        if (!self.Run(&images, &outputs)) {
+          pybind11::eval("raise Exception('Failed to preprocess the input data in PaddleDetPreprocessor.')");
+        }
+        for (size_t i = 0; i < outputs.size(); ++i) {
+          outputs[i].StopSharing();
+        }
+        return outputs;
+      });
-  pybind11::class_(m, "PPYOLO")
-      .def(pybind11::init())
-      .def("predict",
-           [](vision::detection::PPYOLO& self, pybind11::array& data) {
-             auto mat = PyArrayToCvMat(data);
-             vision::DetectionResult res;
-             self.Predict(&mat, &res);
-             return res;
-           });
+  pybind11::class_(
+      m, "PaddleDetPostprocessor")
+      .def(pybind11::init<>())
+      .def("run", [](vision::detection::PaddleDetPostprocessor& self, std::vector& inputs) {
+        std::vector results;
+        if (!self.Run(inputs, &results)) {
+          pybind11::eval("raise Exception('Failed to postprocess the runtime result in PaddleDetPostprocessor.')");
+        }
+        return results;
+      })
+      .def("run", [](vision::detection::PaddleDetPostprocessor& self, std::vector& input_array) {
+        std::vector results;
+        std::vector inputs;
+        PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
+        if (!self.Run(inputs, &results)) {
+          pybind11::eval("raise Exception('Failed to postprocess the runtime result in PaddleDetPostprocessor.')");
+        }
+        return results;
+      });
-  pybind11::class_(m, "PPYOLOv2")
+  pybind11::class_(m, "PPDetBase")
       .def(pybind11::init())
       .def("predict",
-           [](vision::detection::PPYOLOv2& self, pybind11::array& data) {
+           [](vision::detection::PPDetBase& self, pybind11::array& data) {
             auto mat = PyArrayToCvMat(data);
             vision::DetectionResult res;
             self.Predict(&mat, &res);
             return res;
-           });
+           })
+      .def("batch_predict",
[](vision::detection::PPDetBase& self, std::vector& data) { + std::vector images; + for (size_t i = 0; i < data.size(); ++i) { + images.push_back(PyArrayToCvMat(data[i])); + } + std::vector results; + self.BatchPredict(images, &results); + return results; + }) + .def_property_readonly("preprocessor", &vision::detection::PPDetBase::GetPreprocessor) + .def_property_readonly("postprocessor", &vision::detection::PPDetBase::GetPostprocessor); - pybind11::class_(m, "PicoDet") - .def(pybind11::init()) - .def("predict", - [](vision::detection::PicoDet& self, pybind11::array& data) { - auto mat = PyArrayToCvMat(data); - vision::DetectionResult res; - self.Predict(&mat, &res); - return res; - }); - pybind11::class_( - m, "PaddleYOLOX") + pybind11::class_(m, "PPYOLO") .def(pybind11::init()) - .def("predict", - [](vision::detection::PaddleYOLOX& self, pybind11::array& data) { - auto mat = PyArrayToCvMat(data); - vision::DetectionResult res; - self.Predict(&mat, &res); - return res; - }); + ModelFormat>()); - pybind11::class_(m, - "FasterRCNN") + pybind11::class_(m, "PPYOLOE") .def(pybind11::init()) - .def("predict", - [](vision::detection::FasterRCNN& self, pybind11::array& data) { - auto mat = PyArrayToCvMat(data); - vision::DetectionResult res; - self.Predict(&mat, &res); - return res; - }); + ModelFormat>()); - pybind11::class_(m, "YOLOv3") + pybind11::class_(m, "PicoDet") .def(pybind11::init()) - .def("predict", - [](vision::detection::YOLOv3& self, pybind11::array& data) { - auto mat = PyArrayToCvMat(data); - vision::DetectionResult res; - self.Predict(&mat, &res); - return res; - }); + ModelFormat>()); - pybind11::class_(m, "MaskRCNN") + pybind11::class_(m, "PaddleYOLOX") .def(pybind11::init()) - .def("predict", - [](vision::detection::MaskRCNN& self, pybind11::array& data) { - auto mat = PyArrayToCvMat(data); - vision::DetectionResult res; - self.Predict(&mat, &res); - return res; - }); + ModelFormat>()); + + pybind11::class_(m, "FasterRCNN") + .def(pybind11::init()); + + pybind11::class_(m, "YOLOv3") + .def(pybind11::init()); + + pybind11::class_(m, "MaskRCNN") + .def(pybind11::init()); } } // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/ppyolo.cc b/fastdeploy/vision/detection/ppdet/ppyolo.cc deleted file mode 100644 index f0965e5f4..000000000 --- a/fastdeploy/vision/detection/ppdet/ppyolo.cc +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "fastdeploy/vision/detection/ppdet/ppyolo.h" - -namespace fastdeploy { -namespace vision { -namespace detection { - -PPYOLO::PPYOLO(const std::string& model_file, const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option, - const ModelFormat& model_format) { - config_file_ = config_file; - valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::LITE}; - valid_gpu_backends = {Backend::PDINFER}; - has_nms_ = true; - runtime_option = custom_option; - runtime_option.model_format = model_format; - runtime_option.model_file = model_file; - runtime_option.params_file = params_file; - initialized = Initialize(); -} - -bool PPYOLO::Initialize() { - if (!BuildPreprocessPipelineFromConfig()) { - FDERROR << "Failed to build preprocess pipeline from configuration file." - << std::endl; - return false; - } - if (!InitRuntime()) { - FDERROR << "Failed to initialize fastdeploy backend." << std::endl; - return false; - } - return true; -} - -bool PPYOLO::Preprocess(Mat* mat, std::vector* outputs) { - int origin_w = mat->Width(); - int origin_h = mat->Height(); - for (size_t i = 0; i < processors_.size(); ++i) { - if (!(*(processors_[i].get()))(mat)) { - FDERROR << "Failed to process image data in " << processors_[i]->Name() - << "." << std::endl; - return false; - } - } - - outputs->resize(3); - (*outputs)[0].Allocate({1, 2}, FDDataType::FP32, "im_shape"); - (*outputs)[2].Allocate({1, 2}, FDDataType::FP32, "scale_factor"); - float* ptr0 = static_cast((*outputs)[0].MutableData()); - ptr0[0] = mat->Height(); - ptr0[1] = mat->Width(); - float* ptr2 = static_cast((*outputs)[2].MutableData()); - ptr2[0] = mat->Height() * 1.0 / origin_h; - ptr2[1] = mat->Width() * 1.0 / origin_w; - (*outputs)[1].name = "image"; - mat->ShareWithTensor(&((*outputs)[1])); - // reshape to [1, c, h, w] - (*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1); - return true; -} - -} // namespace detection -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/ppyolo.h b/fastdeploy/vision/detection/ppdet/ppyolo.h deleted file mode 100644 index 6f288a9db..000000000 --- a/fastdeploy/vision/detection/ppdet/ppyolo.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once -#include "fastdeploy/vision/detection/ppdet/ppyoloe.h" - -namespace fastdeploy { -namespace vision { -namespace detection { - -class FASTDEPLOY_DECL PPYOLO : public PPYOLOE { - public: - PPYOLO(const std::string& model_file, const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option = RuntimeOption(), - const ModelFormat& model_format = ModelFormat::PADDLE); - - virtual std::string ModelName() const { return "PaddleDetection/PPYOLO"; } - - virtual bool Preprocess(Mat* mat, std::vector* outputs); - virtual bool Initialize(); - - protected: - PPYOLO() {} -}; - -class FASTDEPLOY_DECL PPYOLOv2 : public PPYOLO { - public: - PPYOLOv2(const std::string& model_file, const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option = RuntimeOption(), - const ModelFormat& model_format = ModelFormat::PADDLE) - : PPYOLO(model_file, params_file, config_file, custom_option, - model_format) {} - - virtual std::string ModelName() const { return "PaddleDetection/PPYOLOv2"; } -}; - -} // namespace detection -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/ppyoloe.cc b/fastdeploy/vision/detection/ppdet/ppyoloe.cc deleted file mode 100755 index 1ae6294ba..000000000 --- a/fastdeploy/vision/detection/ppdet/ppyoloe.cc +++ /dev/null @@ -1,274 +0,0 @@ -#include "fastdeploy/vision/detection/ppdet/ppyoloe.h" -#include "fastdeploy/vision/utils/utils.h" -#include "yaml-cpp/yaml.h" -#ifdef ENABLE_PADDLE_FRONTEND -#include "paddle2onnx/converter.h" -#endif - -namespace fastdeploy { -namespace vision { -namespace detection { - -PPYOLOE::PPYOLOE(const std::string& model_file, const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option, - const ModelFormat& model_format) { - config_file_ = config_file; - valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER, Backend::LITE}; - valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; - runtime_option = custom_option; - runtime_option.model_format = model_format; - runtime_option.model_file = model_file; - runtime_option.params_file = params_file; - initialized = Initialize(); -} - -void PPYOLOE::GetNmsInfo() { -#ifdef ENABLE_PADDLE_FRONTEND - if (runtime_option.model_format == ModelFormat::PADDLE) { - std::string contents; - if (!ReadBinaryFromFile(runtime_option.model_file, &contents)) { - return; - } - auto reader = paddle2onnx::PaddleReader(contents.c_str(), contents.size()); - if (reader.has_nms) { - has_nms_ = true; - background_label = reader.nms_params.background_label; - keep_top_k = reader.nms_params.keep_top_k; - nms_eta = reader.nms_params.nms_eta; - nms_threshold = reader.nms_params.nms_threshold; - score_threshold = reader.nms_params.score_threshold; - nms_top_k = reader.nms_params.nms_top_k; - normalized = reader.nms_params.normalized; - } - } -#endif -} - -bool PPYOLOE::Initialize() { - if (!BuildPreprocessPipelineFromConfig()) { - FDERROR << "Failed to build preprocess pipeline from configuration file." - << std::endl; - return false; - } - if (!InitRuntime()) { - FDERROR << "Failed to initialize fastdeploy backend." 
<< std::endl; - return false; - } - reused_input_tensors_.resize(2); - - return true; -} - -bool PPYOLOE::BuildPreprocessPipelineFromConfig() { - processors_.clear(); - YAML::Node cfg; - try { - cfg = YAML::LoadFile(config_file_); - } catch (YAML::BadFile& e) { - FDERROR << "Failed to load yaml file " << config_file_ - << ", maybe you should check this file." << std::endl; - return false; - } - - processors_.push_back(std::make_shared()); - - bool has_permute = false; - for (const auto& op : cfg["Preprocess"]) { - std::string op_name = op["type"].as(); - if (op_name == "NormalizeImage") { - auto mean = op["mean"].as>(); - auto std = op["std"].as>(); - bool is_scale = true; - if (op["is_scale"]) { - is_scale = op["is_scale"].as(); - } - std::string norm_type = "mean_std"; - if (op["norm_type"]) { - norm_type = op["norm_type"].as(); - } - if (norm_type != "mean_std") { - std::fill(mean.begin(), mean.end(), 0.0); - std::fill(std.begin(), std.end(), 1.0); - } - processors_.push_back(std::make_shared(mean, std, is_scale)); - } else if (op_name == "Resize") { - bool keep_ratio = op["keep_ratio"].as(); - auto target_size = op["target_size"].as>(); - int interp = op["interp"].as(); - FDASSERT(target_size.size() == 2, - "Require size of target_size be 2, but now it's %lu.", - target_size.size()); - if (!keep_ratio) { - int width = target_size[1]; - int height = target_size[0]; - processors_.push_back( - std::make_shared(width, height, -1.0, -1.0, interp, false)); - } else { - int min_target_size = std::min(target_size[0], target_size[1]); - int max_target_size = std::max(target_size[0], target_size[1]); - std::vector max_size; - if (max_target_size > 0) { - max_size.push_back(max_target_size); - max_size.push_back(max_target_size); - } - processors_.push_back(std::make_shared( - min_target_size, interp, true, max_size)); - } - } else if (op_name == "Permute") { - // Do nothing, do permute as the last operation - has_permute = true; - continue; - // processors_.push_back(std::make_shared()); - } else if (op_name == "Pad") { - auto size = op["size"].as>(); - auto value = op["fill_value"].as>(); - processors_.push_back(std::make_shared("float")); - processors_.push_back( - std::make_shared(size[1], size[0], value)); - } else if (op_name == "PadStride") { - auto stride = op["stride"].as(); - processors_.push_back( - std::make_shared(stride, std::vector(3, 0))); - } else { - FDERROR << "Unexcepted preprocess operator: " << op_name << "." - << std::endl; - return false; - } - } - if (has_permute) { - // permute = cast + HWC2CHW - processors_.push_back(std::make_shared("float")); - processors_.push_back(std::make_shared()); - } else { - processors_.push_back(std::make_shared()); - } - - // Fusion will improve performance - FuseTransforms(&processors_); - return true; -} - -bool PPYOLOE::Preprocess(Mat* mat, std::vector* outputs) { - int origin_w = mat->Width(); - int origin_h = mat->Height(); - for (size_t i = 0; i < processors_.size(); ++i) { - if (!(*(processors_[i].get()))(mat)) { - FDERROR << "Failed to process image data in " << processors_[i]->Name() - << "." 
<< std::endl; - return false; - } - } - - outputs->resize(2); - (*outputs)[0].name = InputInfoOfRuntime(0).name; - mat->ShareWithTensor(&((*outputs)[0])); - - // reshape to [1, c, h, w] - (*outputs)[0].shape.insert((*outputs)[0].shape.begin(), 1); - - (*outputs)[1].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(1).name); - float* ptr = static_cast((*outputs)[1].MutableData()); - ptr[0] = mat->Height() * 1.0 / origin_h; - ptr[1] = mat->Width() * 1.0 / origin_w; - return true; -} - -bool PPYOLOE::Postprocess(std::vector& infer_result, - DetectionResult* result) { - FDASSERT(infer_result[1].shape[0] == 1, - "Only support batch = 1 in FastDeploy now."); - - has_nms_ = true; - if (!has_nms_) { - int boxes_index = 0; - int scores_index = 1; - if (infer_result[0].shape[1] == infer_result[1].shape[2]) { - boxes_index = 0; - scores_index = 1; - } else if (infer_result[0].shape[2] == infer_result[1].shape[1]) { - boxes_index = 1; - scores_index = 0; - } else { - FDERROR << "The shape of boxes and scores should be [batch, boxes_num, " - "4], [batch, classes_num, boxes_num]" - << std::endl; - return false; - } - - backend::MultiClassNMS nms; - nms.background_label = background_label; - nms.keep_top_k = keep_top_k; - nms.nms_eta = nms_eta; - nms.nms_threshold = nms_threshold; - nms.score_threshold = score_threshold; - nms.nms_top_k = nms_top_k; - nms.normalized = normalized; - nms.Compute(static_cast(infer_result[boxes_index].Data()), - static_cast(infer_result[scores_index].Data()), - infer_result[boxes_index].shape, - infer_result[scores_index].shape); - if (nms.out_num_rois_data[0] > 0) { - result->Reserve(nms.out_num_rois_data[0]); - } - for (size_t i = 0; i < nms.out_num_rois_data[0]; ++i) { - result->label_ids.push_back(nms.out_box_data[i * 6]); - result->scores.push_back(nms.out_box_data[i * 6 + 1]); - result->boxes.emplace_back(std::array{ - nms.out_box_data[i * 6 + 2], nms.out_box_data[i * 6 + 3], - nms.out_box_data[i * 6 + 4], nms.out_box_data[i * 6 + 5]}); - } - } else { - std::vector num_boxes(infer_result[1].shape[0]); - if (infer_result[1].dtype == FDDataType::INT32) { - int32_t* data = static_cast(infer_result[1].Data()); - for (size_t i = 0; i < infer_result[1].shape[0]; ++i) { - num_boxes[i] = static_cast(data[i]); - } - } else if (infer_result[1].dtype == FDDataType::INT64) { - int64_t* data = static_cast(infer_result[1].Data()); - for (size_t i = 0; i < infer_result[1].shape[0]; ++i) { - num_boxes[i] = static_cast(data[i]); - } - } - - // Only support batch = 1 now - result->Reserve(num_boxes[0]); - float* box_data = static_cast(infer_result[0].Data()); - for (size_t i = 0; i < num_boxes[0]; ++i) { - result->label_ids.push_back(box_data[i * 6]); - result->scores.push_back(box_data[i * 6 + 1]); - result->boxes.emplace_back( - std::array{box_data[i * 6 + 2], box_data[i * 6 + 3], - box_data[i * 6 + 4], box_data[i * 6 + 5]}); - } - } - return true; -} - -bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result) { - Mat mat(*im); - - if (!Preprocess(&mat, &reused_input_tensors_)) { - FDERROR << "Failed to preprocess input data while using model:" - << ModelName() << "." << std::endl; - return false; - } - - if (!Infer()) { - FDERROR << "Failed to inference while using model:" << ModelName() << "." - << std::endl; - return false; - } - - if (!Postprocess(reused_output_tensors_, result)) { - FDERROR << "Failed to postprocess while using model:" << ModelName() << "." 
- << std::endl; - return false; - } - return true; -} - -} // namespace detection -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/preprocessor.cc b/fastdeploy/vision/detection/ppdet/preprocessor.cc new file mode 100644 index 000000000..b1179d036 --- /dev/null +++ b/fastdeploy/vision/detection/ppdet/preprocessor.cc @@ -0,0 +1,201 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/detection/ppdet/preprocessor.h" +#include "fastdeploy/function/concat.h" +#include "fastdeploy/function/pad.h" +#include "yaml-cpp/yaml.h" + +namespace fastdeploy { +namespace vision { +namespace detection { + +PaddleDetPreprocessor::PaddleDetPreprocessor(const std::string& config_file) { + FDASSERT(BuildPreprocessPipelineFromConfig(config_file), "Failed to create PaddleDetPreprocessor."); + initialized_ = true; +} + +bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig(const std::string& config_file) { + processors_.clear(); + YAML::Node cfg; + try { + cfg = YAML::LoadFile(config_file); + } catch (YAML::BadFile& e) { + FDERROR << "Failed to load yaml file " << config_file + << ", maybe you should check this file." 
<< std::endl;
+    return false;
+  }
+
+  processors_.push_back(std::make_shared<BGR2RGB>());
+
+  bool has_permute = false;
+  for (const auto& op : cfg["Preprocess"]) {
+    std::string op_name = op["type"].as<std::string>();
+    if (op_name == "NormalizeImage") {
+      auto mean = op["mean"].as<std::vector<float>>();
+      auto std = op["std"].as<std::vector<float>>();
+      bool is_scale = true;
+      if (op["is_scale"]) {
+        is_scale = op["is_scale"].as<bool>();
+      }
+      std::string norm_type = "mean_std";
+      if (op["norm_type"]) {
+        norm_type = op["norm_type"].as<std::string>();
+      }
+      if (norm_type != "mean_std") {
+        std::fill(mean.begin(), mean.end(), 0.0);
+        std::fill(std.begin(), std.end(), 1.0);
+      }
+      processors_.push_back(std::make_shared<Normalize>(mean, std, is_scale));
+    } else if (op_name == "Resize") {
+      bool keep_ratio = op["keep_ratio"].as<bool>();
+      auto target_size = op["target_size"].as<std::vector<int>>();
+      int interp = op["interp"].as<int>();
+      FDASSERT(target_size.size() == 2,
+               "Require size of target_size be 2, but now it's %lu.",
+               target_size.size());
+      if (!keep_ratio) {
+        int width = target_size[1];
+        int height = target_size[0];
+        processors_.push_back(
+            std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
+      } else {
+        int min_target_size = std::min(target_size[0], target_size[1]);
+        int max_target_size = std::max(target_size[0], target_size[1]);
+        std::vector<int> max_size;
+        if (max_target_size > 0) {
+          max_size.push_back(max_target_size);
+          max_size.push_back(max_target_size);
+        }
+        processors_.push_back(std::make_shared<ResizeByShort>(
+            min_target_size, interp, true, max_size));
+      }
+    } else if (op_name == "Permute") {
+      // Do nothing, do permute as the last operation
+      has_permute = true;
+      continue;
+      // processors_.push_back(std::make_shared<HWC2CHW>());
+    } else if (op_name == "Pad") {
+      auto size = op["size"].as<std::vector<int>>();
+      auto value = op["fill_value"].as<std::vector<float>>();
+      processors_.push_back(std::make_shared<Cast>("float"));
+      processors_.push_back(
+          std::make_shared<PadToSize>(size[1], size[0], value));
+    } else if (op_name == "PadStride") {
+      auto stride = op["stride"].as<int>();
+      processors_.push_back(
+          std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
+    } else {
+      FDERROR << "Unexpected preprocess operator: " << op_name << "."
+              << std::endl;
+      return false;
+    }
+  }
+  if (has_permute) {
+    // permute = cast + HWC2CHW
+    processors_.push_back(std::make_shared<Cast>("float"));
+    processors_.push_back(std::make_shared<HWC2CHW>());
+  } else {
+    processors_.push_back(std::make_shared<HWC2CHW>());
+  }
+
+  // Fusion will improve performance
+  FuseTransforms(&processors_);
+
+  return true;
+}
+
+bool PaddleDetPreprocessor::Run(std::vector<Mat>* images,
+                                std::vector<FDTensor>* outputs) {
+  if (!initialized_) {
+    FDERROR << "The preprocessor is not initialized." << std::endl;
+    return false;
+  }
+  if (images->size() == 0) {
+    FDERROR << "The size of input images should be greater than 0."
+            << std::endl;
+    return false;
+  }
+
+  // There are 3 outputs, image, scale_factor, im_shape
+  // But im_shape is not used for all the PaddleDetection models
+  // So preprocessor will output the 3 FDTensors, and how to use `im_shape`
+  // is decided by the model itself
+  outputs->resize(3);
+  int batch = static_cast<int>(images->size());
+  // Allocate memory for scale_factor
+  (*outputs)[1].Resize({batch, 2}, FDDataType::FP32);
+  // Allocate memory for im_shape
+  (*outputs)[2].Resize({batch, 2}, FDDataType::FP32);
+  // Record the max size for a batch of input images
+  // All the tensors will be padded to the max size to compose a batched tensor
+  std::vector<int> max_hw({-1, -1});
+
+  float* scale_factor_ptr =
+      reinterpret_cast<float*>((*outputs)[1].MutableData());
+  float* im_shape_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
+  for (size_t i = 0; i < images->size(); ++i) {
+    int origin_w = (*images)[i].Width();
+    int origin_h = (*images)[i].Height();
+    scale_factor_ptr[2 * i] = 1.0;
+    scale_factor_ptr[2 * i + 1] = 1.0;
+    for (size_t j = 0; j < processors_.size(); ++j) {
+      if (!(*(processors_[j].get()))(&((*images)[i]))) {
+        FDERROR << "Failed to process image:" << i << " in "
+                << processors_[j]->Name() << "." << std::endl;
+        return false;
+      }
+      if (processors_[j]->Name().find("Resize") != std::string::npos) {
+        scale_factor_ptr[2 * i] = (*images)[i].Height() * 1.0 / origin_h;
+        scale_factor_ptr[2 * i + 1] = (*images)[i].Width() * 1.0 / origin_w;
+      }
+    }
+    if ((*images)[i].Height() > max_hw[0]) {
+      max_hw[0] = (*images)[i].Height();
+    }
+    if ((*images)[i].Width() > max_hw[1]) {
+      max_hw[1] = (*images)[i].Width();
+    }
+    im_shape_ptr[2 * i] = max_hw[0];
+    im_shape_ptr[2 * i + 1] = max_hw[1];
+  }
+
+  // Concat all the preprocessed data to a batch tensor
+  std::vector<FDTensor> im_tensors(images->size());
+  for (size_t i = 0; i < images->size(); ++i) {
+    if ((*images)[i].Height() < max_hw[0] ||
+        (*images)[i].Width() < max_hw[1]) {
+      // If the size of the image is less than max_hw, pad it to max_hw
+      FDTensor tensor;
+      (*images)[i].ShareWithTensor(&tensor);
+      function::Pad(tensor, &(im_tensors[i]),
+                    {0, 0, max_hw[0] - (*images)[i].Height(),
+                     max_hw[1] - (*images)[i].Width()},
+                    0);
+    } else {
+      // No need to pad
+      (*images)[i].ShareWithTensor(&(im_tensors[i]));
+    }
+    // Reshape to 1xCxHxW
+    im_tensors[i].ExpandDim(0);
+  }
+
+  if (im_tensors.size() == 1) {
+    // If there's only 1 input, no need to concat
+    // skip memory copy
+    (*outputs)[0] = std::move(im_tensors[0]);
+  } else {
+    // Else concat the im tensor for each input image
+    // to compose a batched input tensor
+    function::Concat(im_tensors, &((*outputs)[0]), 0);
+  }
+
+  return true;
+}
+
+}  // namespace detection
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/ppdet/preprocessor.h b/fastdeploy/vision/detection/ppdet/preprocessor.h
new file mode 100644
index 000000000..2733c450e
--- /dev/null
+++ b/fastdeploy/vision/detection/ppdet/preprocessor.h
@@ -0,0 +1,50 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+namespace fastdeploy {
+namespace vision {
+
+namespace detection {
+/*! @brief Preprocessor object for PaddleDet series models.
+ */
+class FASTDEPLOY_DECL PaddleDetPreprocessor {
+ public:
+  PaddleDetPreprocessor() = default;
+  /** \brief Create a preprocessor instance for PaddleDet series models
+   *
+   * \param[in] config_file Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
+   */
+  explicit PaddleDetPreprocessor(const std::string& config_file);
+
+  /** \brief Process the input image and prepare input tensors for runtime
+   *
+   * \param[in] images The input image data list, all the elements are returned by cv::imread()
+   * \param[out] outputs The output tensors which will be fed to runtime, including image, scale_factor and im_shape
+   * \return true if the preprocess succeeded, otherwise false
+   */
+  bool Run(std::vector<Mat>* images, std::vector<FDTensor>* outputs);
+
+ private:
+  bool BuildPreprocessPipelineFromConfig(const std::string& config_file);
+  std::vector<std::shared_ptr<Processor>> processors_;
+  bool initialized_ = false;
+};
+
+}  // namespace detection
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/detection/ppdet/rcnn.cc b/fastdeploy/vision/detection/ppdet/rcnn.cc
deleted file mode 100644
index 53cbffa56..000000000
--- a/fastdeploy/vision/detection/ppdet/rcnn.cc
+++ /dev/null
@@ -1,84 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "fastdeploy/vision/detection/ppdet/rcnn.h"
-
-namespace fastdeploy {
-namespace vision {
-namespace detection {
-
-FasterRCNN::FasterRCNN(const std::string& model_file,
-                       const std::string& params_file,
-                       const std::string& config_file,
-                       const RuntimeOption& custom_option,
-                       const ModelFormat& model_format) {
-  config_file_ = config_file;
-  valid_cpu_backends = {Backend::PDINFER, Backend::LITE};
-  valid_gpu_backends = {Backend::PDINFER};
-  has_nms_ = true;
-  runtime_option = custom_option;
-  runtime_option.model_format = model_format;
-  runtime_option.model_file = model_file;
-  runtime_option.params_file = params_file;
-  initialized = Initialize();
-}
-
-bool FasterRCNN::Initialize() {
-  if (!BuildPreprocessPipelineFromConfig()) {
-    FDERROR << "Failed to build preprocess pipeline from configuration file."
-            << std::endl;
-    return false;
-  }
-  if (!InitRuntime()) {
-    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
-    return false;
-  }
-  return true;
-}
-
-bool FasterRCNN::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
-  int origin_w = mat->Width();
-  int origin_h = mat->Height();
-  float scale[2] = {1.0, 1.0};
-  for (size_t i = 0; i < processors_.size(); ++i) {
-    if (!(*(processors_[i].get()))(mat)) {
-      FDERROR << "Failed to process image data in " << processors_[i]->Name()
-              << "."
<< std::endl; - return false; - } - if (processors_[i]->Name().find("Resize") != std::string::npos) { - scale[0] = mat->Height() * 1.0 / origin_h; - scale[1] = mat->Width() * 1.0 / origin_w; - } - } - - outputs->resize(3); - (*outputs)[0].Allocate({1, 2}, FDDataType::FP32, "im_shape"); - (*outputs)[2].Allocate({1, 2}, FDDataType::FP32, "scale_factor"); - float* ptr0 = static_cast((*outputs)[0].MutableData()); - ptr0[0] = mat->Height(); - ptr0[1] = mat->Width(); - float* ptr2 = static_cast((*outputs)[2].MutableData()); - ptr2[0] = scale[0]; - ptr2[1] = scale[1]; - (*outputs)[1].name = "image"; - mat->ShareWithTensor(&((*outputs)[1])); - // reshape to [1, c, h, w] - (*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1); - return true; -} - -} // namespace detection -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/rcnn.h b/fastdeploy/vision/detection/ppdet/rcnn.h deleted file mode 100644 index df42b8efc..000000000 --- a/fastdeploy/vision/detection/ppdet/rcnn.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "fastdeploy/vision/detection/ppdet/ppyoloe.h" - -namespace fastdeploy { -namespace vision { -namespace detection { - -class FASTDEPLOY_DECL FasterRCNN : public PPYOLOE { - public: - FasterRCNN(const std::string& model_file, const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option = RuntimeOption(), - const ModelFormat& model_format = ModelFormat::PADDLE); - - virtual std::string ModelName() const { return "PaddleDetection/FasterRCNN"; } - - virtual bool Preprocess(Mat* mat, std::vector* outputs); - virtual bool Initialize(); - - protected: - FasterRCNN() {} -}; -} // namespace detection -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/yolov3.cc b/fastdeploy/vision/detection/ppdet/yolov3.cc deleted file mode 100644 index bcfb3aef9..000000000 --- a/fastdeploy/vision/detection/ppdet/yolov3.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "fastdeploy/vision/detection/ppdet/yolov3.h" - -namespace fastdeploy { -namespace vision { -namespace detection { - -YOLOv3::YOLOv3(const std::string& model_file, const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option, - const ModelFormat& model_format) { - config_file_ = config_file; - valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER, Backend::LITE}; - valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; - runtime_option = custom_option; - runtime_option.model_format = model_format; - runtime_option.model_file = model_file; - runtime_option.params_file = params_file; - initialized = Initialize(); -} - -bool YOLOv3::Preprocess(Mat* mat, std::vector* outputs) { - int origin_w = mat->Width(); - int origin_h = mat->Height(); - for (size_t i = 0; i < processors_.size(); ++i) { - if (!(*(processors_[i].get()))(mat)) { - FDERROR << "Failed to process image data in " << processors_[i]->Name() - << "." << std::endl; - return false; - } - } - - outputs->resize(3); - (*outputs)[0].Allocate({1, 2}, FDDataType::FP32, "im_shape"); - (*outputs)[2].Allocate({1, 2}, FDDataType::FP32, "scale_factor"); - float* ptr0 = static_cast((*outputs)[0].MutableData()); - ptr0[0] = mat->Height(); - ptr0[1] = mat->Width(); - float* ptr2 = static_cast((*outputs)[2].MutableData()); - ptr2[0] = mat->Height() * 1.0 / origin_h; - ptr2[1] = mat->Width() * 1.0 / origin_w; - (*outputs)[1].name = "image"; - mat->ShareWithTensor(&((*outputs)[1])); - // reshape to [1, c, h, w] - (*outputs)[1].shape.insert((*outputs)[1].shape.begin(), 1); - return true; -} - -} // namespace detection -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/yolov3.h b/fastdeploy/vision/detection/ppdet/yolov3.h deleted file mode 100644 index ebafa6edd..000000000 --- a/fastdeploy/vision/detection/ppdet/yolov3.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "fastdeploy/vision/detection/ppdet/ppyoloe.h" - -namespace fastdeploy { -namespace vision { -namespace detection { - -class FASTDEPLOY_DECL YOLOv3 : public PPYOLOE { - public: - YOLOv3(const std::string& model_file, const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option = RuntimeOption(), - const ModelFormat& model_format = ModelFormat::PADDLE); - - virtual std::string ModelName() const { return "PaddleDetection/YOLOv3"; } - - virtual bool Preprocess(Mat* mat, std::vector* outputs); -}; -} // namespace detection -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/yolox.cc b/fastdeploy/vision/detection/ppdet/yolox.cc deleted file mode 100644 index f7405d4de..000000000 --- a/fastdeploy/vision/detection/ppdet/yolox.cc +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "fastdeploy/vision/detection/ppdet/yolox.h" - -namespace fastdeploy { -namespace vision { -namespace detection { - -PaddleYOLOX::PaddleYOLOX(const std::string& model_file, - const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option, - const ModelFormat& model_format) { - config_file_ = config_file; - valid_cpu_backends = {Backend::ORT, Backend::PDINFER, Backend::LITE}; - valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; - runtime_option = custom_option; - runtime_option.model_format = model_format; - runtime_option.model_file = model_file; - runtime_option.params_file = params_file; - background_label = -1; - keep_top_k = 1000; - nms_eta = 1; - nms_threshold = 0.65; - nms_top_k = 10000; - normalized = true; - score_threshold = 0.001; - initialized = Initialize(); -} - -bool PaddleYOLOX::Preprocess(Mat* mat, std::vector* outputs) { - int origin_w = mat->Width(); - int origin_h = mat->Height(); - float scale[2] = {1.0, 1.0}; - for (size_t i = 0; i < processors_.size(); ++i) { - if (!(*(processors_[i].get()))(mat)) { - FDERROR << "Failed to process image data in " << processors_[i]->Name() - << "." << std::endl; - return false; - } - if (processors_[i]->Name().find("Resize") != std::string::npos) { - scale[0] = mat->Height() * 1.0 / origin_h; - scale[1] = mat->Width() * 1.0 / origin_w; - } - } - - outputs->resize(2); - (*outputs)[0].name = InputInfoOfRuntime(0).name; - mat->ShareWithTensor(&((*outputs)[0])); - - // reshape to [1, c, h, w] - (*outputs)[0].shape.insert((*outputs)[0].shape.begin(), 1); - - (*outputs)[1].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(1).name); - float* ptr = static_cast((*outputs)[1].MutableData()); - ptr[0] = scale[0]; - ptr[1] = scale[1]; - return true; -} -} // namespace detection -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/yolox.h b/fastdeploy/vision/detection/ppdet/yolox.h deleted file mode 100644 index dd0a11b57..000000000 --- a/fastdeploy/vision/detection/ppdet/yolox.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once -#include "fastdeploy/vision/detection/ppdet/ppyoloe.h" - -namespace fastdeploy { -namespace vision { -namespace detection { - -class FASTDEPLOY_DECL PaddleYOLOX : public PPYOLOE { - public: - PaddleYOLOX(const std::string& model_file, const std::string& params_file, - const std::string& config_file, - const RuntimeOption& custom_option = RuntimeOption(), - const ModelFormat& model_format = ModelFormat::PADDLE); - - virtual bool Preprocess(Mat* mat, std::vector* outputs); - - virtual std::string ModelName() const { return "PaddleDetection/YOLOX"; } -}; -} // namespace detection -} // namespace vision -} // namespace fastdeploy diff --git a/python/fastdeploy/vision/detection/__init__.py b/python/fastdeploy/vision/detection/__init__.py index 47d175af7..6de4a3fa6 100755 --- a/python/fastdeploy/vision/detection/__init__.py +++ b/python/fastdeploy/vision/detection/__init__.py @@ -23,5 +23,4 @@ from .contrib.yolov5lite import YOLOv5Lite from .contrib.yolov6 import YOLOv6 from .contrib.yolov7end2end_trt import YOLOv7End2EndTRT from .contrib.yolov7end2end_ort import YOLOv7End2EndORT -from .ppdet import PPYOLOE, PPYOLO, PPYOLOv2, PaddleYOLOX, PicoDet, FasterRCNN, YOLOv3, MaskRCNN -from .rknpu2 import RKPicoDet +from .ppdet import * diff --git a/python/fastdeploy/vision/detection/ppdet/__init__.py b/python/fastdeploy/vision/detection/ppdet/__init__.py index 4497c75ee..4341a9a1a 100644 --- a/python/fastdeploy/vision/detection/ppdet/__init__.py +++ b/python/fastdeploy/vision/detection/ppdet/__init__.py @@ -19,6 +19,40 @@ from .... import FastDeployModel, ModelFormat from .... import c_lib_wrap as C +class PaddleDetPreprocessor: + def __init__(self, config_file): + """Create a preprocessor for PaddleDetection Model from configuration file + + :param config_file: (str)Path of configuration file, e.g ppyoloe/infer_cfg.yml + """ + self._preprocessor = C.vision.detection.PaddleDetPreprocessor( + config_file) + + def run(self, input_ims): + """Preprocess input images for PaddleDetection Model + + :param: input_ims: (list of numpy.ndarray)The input image + :return: list of FDTensor, include image, scale_factor, im_shape + """ + return self._preprocessor.run(input_ims) + + +class PaddleDetPostprocessor: + def __init__(self): + """Create a postprocessor for PaddleDetection Model + + """ + self._postprocessor = C.vision.detection.PaddleDetPostprocessor() + + def run(self, runtime_results): + """Postprocess the runtime results for PaddleDetection Model + + :param: runtime_results: (list of FDTensor)The output FDTensor results from runtime + :return: list of ClassifyResult(If the runtime_results is predict by batched samples, the length of this list equals to the batch size) + """ + return self._postprocessor.run(runtime_results) + + class PPYOLOE(FastDeployModel): def __init__(self, model_file, @@ -52,6 +86,31 @@ class PPYOLOE(FastDeployModel): assert im is not None, "The input image data is None." 
return self._model.predict(im) + def batch_predict(self, images): + """Detect a batch of input image list + + :param im: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format + :return list of DetectionResult + """ + + return self._model.batch_predict(images) + + @property + def preprocessor(self): + """Get PaddleDetPreprocessor object of the loaded model + + :return PaddleDetPreprocessor + """ + return self._model.preprocessor + + @property + def postprocessor(self): + """Get PaddleDetPostprocessor object of the loaded model + + :return PaddleDetPostprocessor + """ + return self._model.postprocessor + class PPYOLO(PPYOLOE): def __init__(self, @@ -77,31 +136,6 @@ class PPYOLO(PPYOLOE): assert self.initialized, "PPYOLO model initialize failed." -class PPYOLOv2(PPYOLOE): - def __init__(self, - model_file, - params_file, - config_file, - runtime_option=None, - model_format=ModelFormat.PADDLE): - """Load a PPYOLOv2 model exported by PaddleDetection. - - :param model_file: (str)Path of model file, e.g ppyolov2/model.pdmodel - :param params_file: (str)Path of parameters file, e.g ppyolov2/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string - :param config_file: (str)Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml - :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU - :param model_format: (fastdeploy.ModelForamt)Model format of the loaded model - """ - - super(PPYOLOE, self).__init__(runtime_option) - - assert model_format == ModelFormat.PADDLE, "PPYOLOv2 model only support model format of ModelFormat.Paddle now." - self._model = C.vision.detection.PPYOLOv2( - model_file, params_file, config_file, self._runtime_option, - model_format) - assert self.initialized, "PPYOLOv2 model initialize failed." - - class PaddleYOLOX(PPYOLOE): def __init__(self, model_file, @@ -202,7 +236,7 @@ class YOLOv3(PPYOLOE): assert self.initialized, "YOLOv3 model initialize failed." -class MaskRCNN(FastDeployModel): +class MaskRCNN(PPYOLOE): def __init__(self, model_file, params_file, @@ -211,14 +245,14 @@ class MaskRCNN(FastDeployModel): model_format=ModelFormat.PADDLE): """Load a MaskRCNN model exported by PaddleDetection. - :param model_file: (str)Path of model file, e.g maskrcnn/model.pdmodel - :param params_file: (str)Path of parameters file, e.g maskrcnn/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string + :param model_file: (str)Path of model file, e.g fasterrcnn/model.pdmodel + :param params_file: (str)Path of parameters file, e.g fasterrcnn/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string :param config_file: (str)Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU :param model_format: (fastdeploy.ModelForamt)Model format of the loaded model """ - super(MaskRCNN, self).__init__(runtime_option) + super(PPYOLOE, self).__init__(runtime_option) assert model_format == ModelFormat.PADDLE, "MaskRCNN model only support model format of ModelFormat.Paddle now." self._model = C.vision.detection.MaskRCNN( @@ -226,6 +260,12 @@ class MaskRCNN(FastDeployModel): model_format) assert self.initialized, "MaskRCNN model initialize failed." 
- def predict(self, input_image): - assert input_image is not None, "The input image data is None." - return self._model.predict(input_image) + def batch_predict(self, images): + """Detect a batch of input image list, batch_predict is not supported for maskrcnn now. + + :param im: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format + :return list of DetectionResult + """ + + raise Exception( + "batch_predict is not supported for MaskRCNN model now.") diff --git a/tests/models/test_faster_rcnn.py b/tests/models/test_faster_rcnn.py new file mode 100755 index 000000000..b7ab217a2 --- /dev/null +++ b/tests/models/test_faster_rcnn.py @@ -0,0 +1,70 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fastdeploy as fd +print(fd.__path__) +import cv2 +import os +import pickle +import numpy as np +import runtime_config as rc + + +def test_detection_faster_rcnn(): + model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/faster_rcnn_r50_vd_fpn_2x_coco.tgz" + input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg" + result_url = "https://bj.bcebos.com/fastdeploy/tests/data/faster_rcnn_baseline.pkl" + fd.download_and_decompress(model_url, "resources") + fd.download(input_url1, "resources") + fd.download(result_url, "resources") + model_path = "resources/faster_rcnn_r50_vd_fpn_2x_coco" + + model_file = os.path.join(model_path, "model.pdmodel") + params_file = os.path.join(model_path, "model.pdiparams") + config_file = os.path.join(model_path, "infer_cfg.yml") + rc.test_option.use_paddle_backend() + model = fd.vision.detection.FasterRCNN( + model_file, params_file, config_file, runtime_option=rc.test_option) + + # compare diff + im1 = cv2.imread("./resources/000000014439.jpg") + for i in range(2): + result = model.predict(im1) + with open("resources/faster_rcnn_baseline.pkl", "rb") as f: + boxes, scores, label_ids = pickle.load(f) + pred_boxes = np.array(result.boxes) + pred_scores = np.array(result.scores) + pred_label_ids = np.array(result.label_ids) + + diff_boxes = np.fabs(boxes - pred_boxes) + diff_scores = np.fabs(scores - pred_scores) + diff_label_ids = np.fabs(label_ids - pred_label_ids) + + print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max()) + + score_threshold = 0.0 + assert diff_boxes[scores > score_threshold].max( + ) < 1e-04, "There's diff in boxes." + assert diff_scores[scores > score_threshold].max( + ) < 1e-04, "There's diff in scores." + assert diff_label_ids[scores > score_threshold].max( + ) < 1e-04, "There's diff in label_ids." 
+ + +# result = model.predict(im1) +# with open("faster_rcnn_baseline.pkl", "wb") as f: +# pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f) + +if __name__ == "__main__": + test_detection_faster_rcnn() diff --git a/tests/models/test_mask_rcnn.py b/tests/models/test_mask_rcnn.py new file mode 100755 index 000000000..df8641af1 --- /dev/null +++ b/tests/models/test_mask_rcnn.py @@ -0,0 +1,68 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fastdeploy as fd +import cv2 +import os +import pickle +import numpy as np +import runtime_config as rc + + +def test_detection_mask_rcnn(): + model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/mask_rcnn_r50_1x_coco.tgz" + input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg" + result_url = "https://bj.bcebos.com/fastdeploy/tests/data/mask_rcnn_baseline.pkl" + fd.download_and_decompress(model_url, "resources") + fd.download(input_url1, "resources") + fd.download(result_url, "resources") + model_path = "resources/mask_rcnn_r50_1x_coco" + + model_file = os.path.join(model_path, "model.pdmodel") + params_file = os.path.join(model_path, "model.pdiparams") + config_file = os.path.join(model_path, "infer_cfg.yml") + model = fd.vision.detection.MaskRCNN( + model_file, params_file, config_file, runtime_option=rc.test_option) + + # compare diff + im1 = cv2.imread("./resources/000000014439.jpg") + for i in range(2): + result = model.predict(im1) + with open("resources/mask_rcnn_baseline.pkl", "rb") as f: + boxes, scores, label_ids = pickle.load(f) + pred_boxes = np.array(result.boxes) + pred_scores = np.array(result.scores) + pred_label_ids = np.array(result.label_ids) + + diff_boxes = np.fabs(boxes - pred_boxes) + diff_scores = np.fabs(scores - pred_scores) + diff_label_ids = np.fabs(label_ids - pred_label_ids) + + print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max()) + + score_threshold = 0.0 + assert diff_boxes[scores > score_threshold].max( + ) < 1e-04, "There's diff in boxes." + assert diff_scores[scores > score_threshold].max( + ) < 1e-04, "There's diff in scores." + assert diff_label_ids[scores > score_threshold].max( + ) < 1e-04, "There's diff in label_ids." + + +# result = model.predict(im1) +# with open("mask_rcnn_baseline.pkl", "wb") as f: +# pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f) + +if __name__ == "__main__": + test_detection_mask_rcnn() diff --git a/tests/models/test_picodet.py b/tests/models/test_picodet.py new file mode 100755 index 000000000..59e4bc784 --- /dev/null +++ b/tests/models/test_picodet.py @@ -0,0 +1,69 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fastdeploy as fd +import cv2 +import os +import pickle +import numpy as np +import runtime_config as rc + + +def test_detection_picodet(): + model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/picodet_l_320_coco_lcnet.tgz" + input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg" + result_url = "https://bj.bcebos.com/fastdeploy/tests/data/picodet_baseline.pkl" + fd.download_and_decompress(model_url, "resources") + fd.download(input_url1, "resources") + fd.download(result_url, "resources") + model_path = "resources/picodet_l_320_coco_lcnet" + + model_file = os.path.join(model_path, "model.pdmodel") + params_file = os.path.join(model_path, "model.pdiparams") + config_file = os.path.join(model_path, "infer_cfg.yml") + rc.test_option.use_paddle_backend() + model = fd.vision.detection.PicoDet( + model_file, params_file, config_file, runtime_option=rc.test_option) + + # compare diff + im1 = cv2.imread("./resources/000000014439.jpg") + for i in range(2): + result = model.predict(im1) + with open("resources/picodet_baseline.pkl", "rb") as f: + boxes, scores, label_ids = pickle.load(f) + pred_boxes = np.array(result.boxes) + pred_scores = np.array(result.scores) + pred_label_ids = np.array(result.label_ids) + + diff_boxes = np.fabs(boxes - pred_boxes) + diff_scores = np.fabs(scores - pred_scores) + diff_label_ids = np.fabs(label_ids - pred_label_ids) + + print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max()) + + score_threshold = 0.0 + assert diff_boxes[scores > score_threshold].max( + ) < 1e-02, "There's diff in boxes." + assert diff_scores[scores > score_threshold].max( + ) < 1e-04, "There's diff in scores." + assert diff_label_ids[scores > score_threshold].max( + ) < 1e-04, "There's diff in label_ids." + + +# result = model.predict(im1) +# with open("picodet_baseline.pkl", "wb") as f: +# pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f) + +if __name__ == "__main__": + test_detection_picodet() diff --git a/tests/models/test_pp_yolox.py b/tests/models/test_pp_yolox.py new file mode 100755 index 000000000..57aee0b99 --- /dev/null +++ b/tests/models/test_pp_yolox.py @@ -0,0 +1,70 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import fastdeploy as fd +print(fd.__path__) +import cv2 +import os +import pickle +import numpy as np +import runtime_config as rc + + +def test_detection_yolox(): + model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/yolox_s_300e_coco.tgz" + input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg" + result_url = "https://bj.bcebos.com/fastdeploy/tests/data/ppyolox_baseline.pkl" + fd.download_and_decompress(model_url, "resources") + fd.download(input_url1, "resources") + fd.download(result_url, "resources") + model_path = "resources/yolox_s_300e_coco" + + model_file = os.path.join(model_path, "model.pdmodel") + params_file = os.path.join(model_path, "model.pdiparams") + config_file = os.path.join(model_path, "infer_cfg.yml") + rc.test_option.use_ort_backend() + model = fd.vision.detection.PaddleYOLOX( + model_file, params_file, config_file, runtime_option=rc.test_option) + + # compare diff + im1 = cv2.imread("./resources/000000014439.jpg") + for i in range(2): + result = model.predict(im1) + with open("resources/ppyolox_baseline.pkl", "rb") as f: + boxes, scores, label_ids = pickle.load(f) + pred_boxes = np.array(result.boxes) + pred_scores = np.array(result.scores) + pred_label_ids = np.array(result.label_ids) + + diff_boxes = np.fabs(boxes - pred_boxes) + diff_scores = np.fabs(scores - pred_scores) + diff_label_ids = np.fabs(label_ids - pred_label_ids) + + print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max()) + + score_threshold = 0.1 + assert diff_boxes[scores > score_threshold].max( + ) < 1e-04, "There's diff in boxes." + assert diff_scores[scores > score_threshold].max( + ) < 1e-04, "There's diff in scores." + assert diff_label_ids[scores > score_threshold].max( + ) < 1e-04, "There's diff in label_ids." + + +# result = model.predict(im1) +# with open("ppyolox_baseline.pkl", "wb") as f: +# pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f) + +if __name__ == "__main__": + test_detection_yolox() diff --git a/tests/models/test_ppyolo.py b/tests/models/test_ppyolo.py new file mode 100755 index 000000000..ede1c1550 --- /dev/null +++ b/tests/models/test_ppyolo.py @@ -0,0 +1,69 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import fastdeploy as fd +import cv2 +import os +import pickle +import numpy as np +import runtime_config as rc + + +def test_detection_ppyolo(): + model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/ppyolov2_r101vd_dcn_365e_coco.tgz" + input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg" + result_url = "https://bj.bcebos.com/fastdeploy/tests/data/ppyolo_baseline.pkl" + fd.download_and_decompress(model_url, "resources") + fd.download(input_url1, "resources") + fd.download(result_url, "resources") + model_path = "resources/ppyolov2_r101vd_dcn_365e_coco" + + model_file = os.path.join(model_path, "model.pdmodel") + params_file = os.path.join(model_path, "model.pdiparams") + config_file = os.path.join(model_path, "infer_cfg.yml") + rc.test_option.use_paddle_backend() + model = fd.vision.detection.PPYOLO( + model_file, params_file, config_file, runtime_option=rc.test_option) + + # compare diff + im1 = cv2.imread("./resources/000000014439.jpg") + for i in range(2): + result = model.predict(im1) + with open("resources/ppyolo_baseline.pkl", "rb") as f: + boxes, scores, label_ids = pickle.load(f) + pred_boxes = np.array(result.boxes) + pred_scores = np.array(result.scores) + pred_label_ids = np.array(result.label_ids) + + diff_boxes = np.fabs(boxes - pred_boxes) + diff_scores = np.fabs(scores - pred_scores) + diff_label_ids = np.fabs(label_ids - pred_label_ids) + + print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max()) + + score_threshold = 0.0 + assert diff_boxes[scores > score_threshold].max( + ) < 1e-04, "There's diff in boxes." + assert diff_scores[scores > score_threshold].max( + ) < 1e-04, "There's diff in scores." + assert diff_label_ids[scores > score_threshold].max( + ) < 1e-04, "There's diff in label_ids." + + +# result = model.predict(im1) +# with open("ppyolo_baseline.pkl", "wb") as f: +# pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f) + +if __name__ == "__main__": + test_detection_ppyolo() diff --git a/tests/models/test_ppyoloe.py b/tests/models/test_ppyoloe.py new file mode 100755 index 000000000..b75f34670 --- /dev/null +++ b/tests/models/test_ppyoloe.py @@ -0,0 +1,68 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import fastdeploy as fd +import cv2 +import os +import pickle +import numpy as np +import runtime_config as rc + + +def test_detection_ppyoloe(): + model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/ppyoloe_crn_l_300e_coco.tgz" + input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg" + result_url = "https://bj.bcebos.com/fastdeploy/tests/data/ppyoloe_baseline.pkl" + fd.download_and_decompress(model_url, "resources") + fd.download(input_url1, "resources") + fd.download(result_url, "resources") + model_path = "resources/ppyoloe_crn_l_300e_coco" + + model_file = os.path.join(model_path, "model.pdmodel") + params_file = os.path.join(model_path, "model.pdiparams") + config_file = os.path.join(model_path, "infer_cfg.yml") + rc.test_option.use_ort_backend() + model = fd.vision.detection.PPYOLOE( + model_file, params_file, config_file, runtime_option=rc.test_option) + + # compare diff + im1 = cv2.imread("./resources/000000014439.jpg") + for i in range(2): + result = model.predict(im1) + with open("resources/ppyoloe_baseline.pkl", "rb") as f: + boxes, scores, label_ids = pickle.load(f) + pred_boxes = np.array(result.boxes) + pred_scores = np.array(result.scores) + pred_label_ids = np.array(result.label_ids) + + diff_boxes = np.fabs(boxes - pred_boxes) + diff_scores = np.fabs(scores - pred_scores) + diff_label_ids = np.fabs(label_ids - pred_label_ids) + + print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max()) + + score_threshold = 0.0 + assert diff_boxes[scores > score_threshold].max( + ) < 1e-04, "There's diff in boxes." + assert diff_scores[scores > score_threshold].max( + ) < 1e-04, "There's diff in scores." + assert diff_label_ids[scores > score_threshold].max( + ) < 1e-04, "There's diff in label_ids." + + +# with open("ppyoloe_baseline.pkl", "wb") as f: +# pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f) + +if __name__ == "__main__": + test_detection_ppyoloe() diff --git a/tests/models/test_yolov3.py b/tests/models/test_yolov3.py new file mode 100755 index 000000000..e9c2faa3e --- /dev/null +++ b/tests/models/test_yolov3.py @@ -0,0 +1,69 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import fastdeploy as fd +import cv2 +import os +import pickle +import numpy as np +import runtime_config as rc + + +def test_detection_yolov3(): + model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/yolov3_darknet53_270e_coco.tgz" + input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg" + result_url = "https://bj.bcebos.com/fastdeploy/tests/data/yolov3_baseline.pkl" + fd.download_and_decompress(model_url, "resources") + fd.download(input_url1, "resources") + fd.download(result_url, "resources") + model_path = "resources/yolov3_darknet53_270e_coco" + + model_file = os.path.join(model_path, "model.pdmodel") + params_file = os.path.join(model_path, "model.pdiparams") + config_file = os.path.join(model_path, "infer_cfg.yml") + rc.test_option.use_ort_backend() + model = fd.vision.detection.YOLOv3( + model_file, params_file, config_file, runtime_option=rc.test_option) + + # compare diff + im1 = cv2.imread("./resources/000000014439.jpg") + for i in range(2): + result = model.predict(im1) + with open("resources/yolov3_baseline.pkl", "rb") as f: + boxes, scores, label_ids = pickle.load(f) + pred_boxes = np.array(result.boxes) + pred_scores = np.array(result.scores) + pred_label_ids = np.array(result.label_ids) + + diff_boxes = np.fabs(boxes - pred_boxes) + diff_scores = np.fabs(scores - pred_scores) + diff_label_ids = np.fabs(label_ids - pred_label_ids) + + print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max()) + + score_threshold = 0.1 + assert diff_boxes[scores > score_threshold].max( + ) < 1e-04, "There's diff in boxes." + assert diff_scores[scores > score_threshold].max( + ) < 1e-04, "There's diff in scores." + assert diff_label_ids[scores > score_threshold].max( + ) < 1e-04, "There's diff in label_ids." + + +# result = model.predict(im1) +# with open("yolov3_baseline.pkl", "wb") as f: +# pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f) + +if __name__ == "__main__": + test_detection_yolov3()
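
Usage note (not part of the diff): the changes above decouple PaddleDetection pre/postprocessing from the model classes and add batched prediction. The following is a minimal sketch of how the new Python interface might be used, based only on the docstrings added in python/fastdeploy/vision/detection/ppdet/__init__.py. The model directory, config file and image path are placeholders; wiring the standalone preprocessor into a runtime is only described in comments, since this diff does not show that code.

# Sketch of the decoupled PaddleDetection Python API added in this PR.
# Paths below are placeholders for an exported PP-YOLOE model.
import cv2
import fastdeploy as fd

option = fd.RuntimeOption()  # default backend on CPU

model = fd.vision.detection.PPYOLOE(
    "ppyoloe_crn_l_300e_coco/model.pdmodel",
    "ppyoloe_crn_l_300e_coco/model.pdiparams",
    "ppyoloe_crn_l_300e_coco/infer_cfg.yml",
    runtime_option=option)

im = cv2.imread("000000014439.jpg")

# Single-image prediction, unchanged interface.
result = model.predict(im)
print(result)

# New in this diff: batched prediction over a list of images.
# (MaskRCNN is the exception: its batch_predict raises for now.)
results = model.batch_predict([im, im])

# New in this diff: the preprocessing stage can also be used on its own,
# e.g. to inspect the tensors (image, scale_factor, im_shape) fed to runtime.
preprocessor = fd.vision.detection.PaddleDetPreprocessor(
    "ppyoloe_crn_l_300e_coco/infer_cfg.yml")
input_tensors = preprocessor.run([im])

# A standalone postprocessor turns raw runtime outputs back into detection
# results; runtime_results would come from running the backend on the
# tensors above, which is outside the scope of this diff.
postprocessor = fd.vision.detection.PaddleDetPostprocessor()
# detections = postprocessor.run(runtime_results)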