mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Model] Refactor PaddleDetection module (#575)
* Add namespace for functions * Refactor PaddleDetection module * finish all the single image test * Update preprocessor.cc * fix some litte detail * add python api * Update postprocessor.cc
This commit is contained in:
@@ -66,4 +66,3 @@ print(result)
|
||||
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
|
||||
cv2.imwrite("visualized_result.jpg", vis_im)
|
||||
print("Visualized result save in ./visualized_result.jpg")
|
||||
print(runtime_option)
|
||||
|
@@ -17,7 +17,7 @@
|
||||
namespace fastdeploy {
|
||||
namespace pipeline {
|
||||
PPTinyPose::PPTinyPose(
|
||||
fastdeploy::vision::detection::PPYOLOE* det_model,
|
||||
fastdeploy::vision::detection::PicoDet* det_model,
|
||||
fastdeploy::vision::keypointdetection::PPTinyPose* pptinypose_model)
|
||||
: detector_(det_model), pptinypose_model_(pptinypose_model) {}
|
||||
|
||||
|
@@ -35,7 +35,7 @@ class FASTDEPLOY_DECL PPTinyPose {
|
||||
* \param[in] pptinypose_model Initialized pptinypose model object
|
||||
*/
|
||||
PPTinyPose(
|
||||
fastdeploy::vision::detection::PPYOLOE* det_model,
|
||||
fastdeploy::vision::detection::PicoDet* det_model,
|
||||
fastdeploy::vision::keypointdetection::PPTinyPose* pptinypose_model);
|
||||
|
||||
/** \brief Predict the keypoint detection result for an input image
|
||||
@@ -52,7 +52,7 @@ class FASTDEPLOY_DECL PPTinyPose {
|
||||
float detection_model_score_threshold = 0;
|
||||
|
||||
protected:
|
||||
fastdeploy::vision::detection::PPYOLOE* detector_ = nullptr;
|
||||
fastdeploy::vision::detection::PicoDet* detector_ = nullptr;
|
||||
fastdeploy::vision::keypointdetection::PPTinyPose* pptinypose_model_ =
|
||||
nullptr;
|
||||
|
||||
|
@@ -18,31 +18,8 @@ namespace fastdeploy {
|
||||
void BindPPTinyPosePipeline(pybind11::module& m) {
|
||||
pybind11::class_<pipeline::PPTinyPose>(m, "PPTinyPose")
|
||||
|
||||
// explicitly pybind more kinds of detection models here
|
||||
.def(pybind11::init<fastdeploy::vision::detection::PPYOLOE*,
|
||||
fastdeploy::vision::keypointdetection::PPTinyPose*>())
|
||||
|
||||
.def(pybind11::init<fastdeploy::vision::detection::PicoDet*,
|
||||
fastdeploy::vision::keypointdetection::PPTinyPose*>())
|
||||
|
||||
.def(pybind11::init<fastdeploy::vision::detection::PPYOLO*,
|
||||
fastdeploy::vision::keypointdetection::PPTinyPose*>())
|
||||
|
||||
.def(pybind11::init<fastdeploy::vision::detection::PPYOLOv2*,
|
||||
fastdeploy::vision::keypointdetection::PPTinyPose*>())
|
||||
|
||||
.def(pybind11::init<fastdeploy::vision::detection::PaddleYOLOX*,
|
||||
fastdeploy::vision::keypointdetection::PPTinyPose*>())
|
||||
|
||||
.def(pybind11::init<fastdeploy::vision::detection::FasterRCNN*,
|
||||
fastdeploy::vision::keypointdetection::PPTinyPose*>())
|
||||
|
||||
.def(pybind11::init<fastdeploy::vision::detection::YOLOv3*,
|
||||
fastdeploy::vision::keypointdetection::PPTinyPose*>())
|
||||
|
||||
.def(pybind11::init<fastdeploy::vision::detection::MaskRCNN*,
|
||||
fastdeploy::vision::keypointdetection::PPTinyPose*>())
|
||||
|
||||
.def("predict", [](pipeline::PPTinyPose& self,
|
||||
pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
|
@@ -29,7 +29,6 @@
|
||||
#include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
|
||||
#include "fastdeploy/vision/detection/contrib/yolox.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/model.h"
|
||||
#include "fastdeploy/vision/detection/contrib/rknpu2/model.h"
|
||||
#include "fastdeploy/vision/facedet/contrib/retinaface.h"
|
||||
#include "fastdeploy/vision/facedet/contrib/scrfd.h"
|
||||
#include "fastdeploy/vision/facedet/contrib/ultraface.h"
|
||||
|
@@ -1,16 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h"
|
@@ -1,29 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include "fastdeploy/pybind/main.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
/// Registers the Python class `RKPicoDet` on module `m`.
///
/// The binding exposes the five-argument constructor of
/// vision::detection::RKPicoDet and a `predict` method that converts the
/// incoming array with PyArrayToCvMat, runs RKPicoDet::Predict on it and
/// returns the DetectionResult by value.
void BindRKDet(pybind11::module& m) {
  using RKPicoDet = vision::detection::RKPicoDet;

  pybind11::class_<RKPicoDet, FastDeployModel> binding(m, "RKPicoDet");
  binding.def(pybind11::init<std::string, std::string, std::string,
                             RuntimeOption, ModelFormat>());
  binding.def("predict", [](RKPicoDet& self, pybind11::array& data) {
    auto image = PyArrayToCvMat(data);
    vision::DetectionResult detections;
    // The status returned by Predict is not inspected here; the result
    // object is handed back to Python exactly as Predict left it.
    self.Predict(&image, &detections);
    return detections;
  });
}
|
||||
} // namespace fastdeploy
|
@@ -1,201 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h"
|
||||
#include "yaml-cpp/yaml.h"
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
/// Builds an RKPicoDet model.
///
/// \param[in] model_file    Stored into runtime_option.model_file
/// \param[in] params_file   Stored into runtime_option.params_file
/// \param[in] config_file   Deployment configuration, read later by
///                          BuildPreprocessPipelineFromConfig()
/// \param[in] custom_option Base runtime options; the model fields are
///                          overwritten with the arguments above
/// \param[in] model_format  Format of the model; RKNN and ONNX switch
///                          `has_nms_` off
///
/// The constructor finishes by calling Initialize() and recording the outcome
/// in `initialized`.
RKPicoDet::RKPicoDet(const std::string& model_file,
                     const std::string& params_file,
                     const std::string& config_file,
                     const RuntimeOption& custom_option,
                     const ModelFormat& model_format) {
  config_file_ = config_file;
  valid_cpu_backends = {Backend::ORT};
  valid_rknpu_backends = {Backend::RKNPU2};

  const bool rknn_or_onnx = model_format == ModelFormat::RKNN ||
                            model_format == ModelFormat::ONNX;
  if (rknn_or_onnx) {
    has_nms_ = false;
  }

  // Copy the caller's options first, then pin the model-specific fields.
  runtime_option = custom_option;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;
  runtime_option.model_format = model_format;

  // NMS parameters come from RKPicoDet_s_nms
  score_threshold = 0.3;
  nms_threshold = 0.5;
  nms_top_k = 1000;
  keep_top_k = 100;
  nms_eta = 1;
  normalized = true;
  background_label = -1;

  initialized = Initialize();
}
|
||||
|
||||
bool RKPicoDet::Initialize() {
|
||||
if (!BuildPreprocessPipelineFromConfig()) {
|
||||
FDERROR << "Failed to build preprocess pipeline from configuration file."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
if (!InitRuntime()) {
|
||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Runs the preprocessing pipeline on `mat` and exposes it as the single
/// input tensor.
///
/// \param[in,out] mat     Image to preprocess; it is modified in place
/// \param[out]    outputs Resized to one tensor that shares data with `mat`
/// \return false as soon as one processor fails, otherwise true
///
/// As a side effect `scale_factor` is updated to
/// {processed height / original height, processed width / original width}.
bool RKPicoDet::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
  // Remember the size before any processor touches the image.
  const int source_width = mat->Width();
  const int source_height = mat->Height();

  for (const auto& step : processors_) {
    if ((*step)(mat)) {
      continue;
    }
    FDERROR << "Failed to process image data in " << step->Name() << "."
            << std::endl;
    return false;
  }

  Cast::Run(mat, "float");

  scale_factor.resize(2);
  scale_factor[0] = mat->Height() * 1.0 / source_height;
  scale_factor[1] = mat->Width() * 1.0 / source_width;

  outputs->resize(1);
  FDTensor& image_tensor = (*outputs)[0];
  image_tensor.name = InputInfoOfRuntime(0).name;
  mat->ShareWithTensor(&image_tensor);
  // Prepend a batch dimension of 1 to the tensor shape.
  image_tensor.shape.insert(image_tensor.shape.begin(), 1);
  return true;
}
|
||||
|
||||
bool RKPicoDet::BuildPreprocessPipelineFromConfig() {
|
||||
processors_.clear();
|
||||
YAML::Node cfg;
|
||||
try {
|
||||
cfg = YAML::LoadFile(config_file_);
|
||||
} catch (YAML::BadFile& e) {
|
||||
FDERROR << "Failed to load yaml file " << config_file_
|
||||
<< ", maybe you should check this file." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
processors_.push_back(std::make_shared<BGR2RGB>());
|
||||
|
||||
for (const auto& op : cfg["Preprocess"]) {
|
||||
std::string op_name = op["type"].as<std::string>();
|
||||
if (op_name == "NormalizeImage") {
|
||||
continue;
|
||||
} else if (op_name == "Resize") {
|
||||
bool keep_ratio = op["keep_ratio"].as<bool>();
|
||||
auto target_size = op["target_size"].as<std::vector<int>>();
|
||||
int interp = op["interp"].as<int>();
|
||||
FDASSERT(target_size.size() == 2,
|
||||
"Require size of target_size be 2, but now it's %lu.",
|
||||
target_size.size());
|
||||
if (!keep_ratio) {
|
||||
int width = target_size[1];
|
||||
int height = target_size[0];
|
||||
processors_.push_back(
|
||||
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
|
||||
} else {
|
||||
int min_target_size = std::min(target_size[0], target_size[1]);
|
||||
int max_target_size = std::max(target_size[0], target_size[1]);
|
||||
std::vector<int> max_size;
|
||||
if (max_target_size > 0) {
|
||||
max_size.push_back(max_target_size);
|
||||
max_size.push_back(max_target_size);
|
||||
}
|
||||
processors_.push_back(std::make_shared<ResizeByShort>(
|
||||
min_target_size, interp, true, max_size));
|
||||
}
|
||||
} else if (op_name == "Permute") {
|
||||
continue;
|
||||
} else if (op_name == "Pad") {
|
||||
auto size = op["size"].as<std::vector<int>>();
|
||||
auto value = op["fill_value"].as<std::vector<float>>();
|
||||
processors_.push_back(std::make_shared<Cast>("float"));
|
||||
processors_.push_back(
|
||||
std::make_shared<PadToSize>(size[1], size[0], value));
|
||||
} else if (op_name == "PadStride") {
|
||||
auto stride = op["stride"].as<int>();
|
||||
processors_.push_back(
|
||||
std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
|
||||
} else {
|
||||
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RKPicoDet::Postprocess(std::vector<FDTensor>& infer_result,
|
||||
DetectionResult* result) {
|
||||
FDASSERT(infer_result[1].shape[0] == 1,
|
||||
"Only support batch = 1 in FastDeploy now.");
|
||||
|
||||
if (!has_nms_) {
|
||||
int boxes_index = 0;
|
||||
int scores_index = 1;
|
||||
if (infer_result[0].shape[1] == infer_result[1].shape[2]) {
|
||||
boxes_index = 0;
|
||||
scores_index = 1;
|
||||
} else if (infer_result[0].shape[2] == infer_result[1].shape[1]) {
|
||||
boxes_index = 1;
|
||||
scores_index = 0;
|
||||
} else {
|
||||
FDERROR << "The shape of boxes and scores should be [batch, boxes_num, "
|
||||
"4], [batch, classes_num, boxes_num]"
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
backend::MultiClassNMS nms;
|
||||
nms.background_label = background_label;
|
||||
nms.keep_top_k = keep_top_k;
|
||||
nms.nms_eta = nms_eta;
|
||||
nms.nms_threshold = nms_threshold;
|
||||
nms.score_threshold = score_threshold;
|
||||
nms.nms_top_k = nms_top_k;
|
||||
nms.normalized = normalized;
|
||||
nms.Compute(static_cast<float*>(infer_result[boxes_index].Data()),
|
||||
static_cast<float*>(infer_result[scores_index].Data()),
|
||||
infer_result[boxes_index].shape,
|
||||
infer_result[scores_index].shape);
|
||||
if (nms.out_num_rois_data[0] > 0) {
|
||||
result->Reserve(nms.out_num_rois_data[0]);
|
||||
}
|
||||
for (size_t i = 0; i < nms.out_num_rois_data[0]; ++i) {
|
||||
result->label_ids.push_back(nms.out_box_data[i * 6]);
|
||||
result->scores.push_back(nms.out_box_data[i * 6 + 1]);
|
||||
result->boxes.emplace_back(
|
||||
std::array<float, 4>{nms.out_box_data[i * 6 + 2] / scale_factor[1],
|
||||
nms.out_box_data[i * 6 + 3] / scale_factor[0],
|
||||
nms.out_box_data[i * 6 + 4] / scale_factor[1],
|
||||
nms.out_box_data[i * 6 + 5] / scale_factor[0]});
|
||||
}
|
||||
} else {
|
||||
FDERROR << "Picodet in Backend::RKNPU2 don't support NMS" << std::endl;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -1,46 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
/*! @brief PicoDet variant derived from PPYOLOE whose constructor lists
 * Backend::RKNPU2 as a valid backend and defaults to the RKNN model format
 */
class FASTDEPLOY_DECL RKPicoDet : public PPYOLOE {
 public:
  /** \brief Set path of model file and configuration file, and the configuration of runtime
   *
   * \param[in] model_file Path of model file
   * \param[in] params_file Path of parameter file
   * \param[in] config_file Path of configuration file for deployment
   * \param[in] custom_option RuntimeOption for inference
   * \param[in] model_format Model format of the loaded model, default is RKNN format
   */
  RKPicoDet(const std::string& model_file,
            const std::string& params_file,
            const std::string& config_file,
            const RuntimeOption& custom_option = RuntimeOption(),
            const ModelFormat& model_format = ModelFormat::RKNN);

  /// Get model's name
  std::string ModelName() const override { return "RKPicoDet"; }

 protected:
  /// Build the preprocess pipeline from the loaded model
  bool BuildPreprocessPipelineFromConfig() override;

  /// Preprocess an input image, and set the preprocessed results to `outputs`
  bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs) override;

  /// Postprocess the inferenced results, and set the final result to `result`
  bool Postprocess(std::vector<FDTensor>& infer_result,
                   DetectionResult* result) override;

  /// Build the preprocess pipeline and initialize the runtime
  bool Initialize() override;

 private:
  // Ratio between the preprocessed and the original image size, stored as
  // {height ratio, width ratio}; written by Preprocess, read by Postprocess.
  std::vector<float> scale_factor{1.0, 1.0};
};
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -27,8 +27,6 @@ void BindNanoDetPlus(pybind11::module& m);
|
||||
void BindPPDet(pybind11::module& m);
|
||||
void BindYOLOv7End2EndTRT(pybind11::module& m);
|
||||
void BindYOLOv7End2EndORT(pybind11::module& m);
|
||||
void BindRKDet(pybind11::module& m);
|
||||
|
||||
|
||||
void BindDetection(pybind11::module& m) {
|
||||
auto detection_module =
|
||||
@@ -44,6 +42,5 @@ void BindDetection(pybind11::module& m) {
|
||||
BindNanoDetPlus(detection_module);
|
||||
BindYOLOv7End2EndTRT(detection_module);
|
||||
BindYOLOv7End2EndORT(detection_module);
|
||||
BindRKDet(detection_module);
|
||||
}
|
||||
} // namespace fastdeploy
|
||||
|
68
fastdeploy/vision/detection/ppdet/base.cc
Executable file
68
fastdeploy/vision/detection/ppdet/base.cc
Executable file
@@ -0,0 +1,68 @@
|
||||
#include "fastdeploy/vision/detection/ppdet/base.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
#include "yaml-cpp/yaml.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
/// Stores the model paths and format into `runtime_option` and constructs the
/// preprocessor from `config_file`.
///
/// The runtime itself is not initialized here; that is done by Initialize().
PPDetBase::PPDetBase(const std::string& model_file,
                     const std::string& params_file,
                     const std::string& config_file,
                     const RuntimeOption& custom_option,
                     const ModelFormat& model_format)
    : preprocessor_(config_file) {
  // Take the caller's options as the baseline, then overwrite the fields
  // that describe this particular model.
  RuntimeOption option = custom_option;
  option.params_file = params_file;
  option.model_file = model_file;
  option.model_format = model_format;
  runtime_option = option;
}
|
||||
|
||||
bool PPDetBase::Initialize() {
|
||||
if (!InitRuntime()) {
|
||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Pointer-based overload kept for backward compatibility; forwards to the
/// const-reference overload.
///
/// \param[in]  im     Input image; must not be null
/// \param[out] result Receives the detection result
/// \return false when `im` is null or the prediction fails, otherwise true
bool PPDetBase::Predict(cv::Mat* im, DetectionResult* result) {
  if (im == nullptr) {
    // Dereferencing a null image would be undefined behavior.
    FDERROR << "The input image pointer is null." << std::endl;
    return false;
  }
  return Predict(*im, result);
}
|
||||
|
||||
/// Predicts the detection result for a single image by running a batch of
/// one through BatchPredict.
///
/// \param[in]  im     Input image
/// \param[out] result Receives the detection result; must not be null
/// \return false when `result` is null, when BatchPredict fails, or when it
///         yields no result; otherwise true
bool PPDetBase::Predict(const cv::Mat& im, DetectionResult* result) {
  if (result == nullptr) {
    FDERROR << "The output result pointer is null." << std::endl;
    return false;
  }

  std::vector<DetectionResult> results;
  if (!BatchPredict({im}, &results)) {
    return false;
  }
  if (results.empty()) {
    // Reading results[0] from an empty vector would be undefined behavior.
    FDERROR << "No detection result was produced for the input image."
            << std::endl;
    return false;
  }

  *result = std::move(results[0]);
  return true;
}
|
||||
|
||||
/// Runs preprocessing, inference and postprocessing on a list of images.
///
/// \param[in]  imgs    Input images
/// \param[out] results Filled by the postprocessor
/// \return false as soon as one of the three stages fails (an error is
///         logged), otherwise true
bool PPDetBase::BatchPredict(const std::vector<cv::Mat>& imgs,
                             std::vector<DetectionResult>* results) {
  std::vector<FDMat> wrapped_images = WrapMat(imgs);

  const bool preprocessed =
      preprocessor_.Run(&wrapped_images, &reused_input_tensors_);
  if (!preprocessed) {
    FDERROR << "Failed to preprocess the input image." << std::endl;
    return false;
  }

  // Name the first three tensors produced by the preprocessor.
  static const char* const kInputNames[] = {"image", "scale_factor",
                                            "im_shape"};
  for (size_t idx = 0; idx < 3; ++idx) {
    reused_input_tensors_[idx].name = kInputNames[idx];
  }

  // Some models don't need im_shape as input
  if (NumInputsOfRuntime() == 2) {
    reused_input_tensors_.pop_back();
  }

  const bool inferred = Infer(reused_input_tensors_, &reused_output_tensors_);
  if (!inferred) {
    FDERROR << "Failed to inference by runtime." << std::endl;
    return false;
  }

  const bool postprocessed =
      postprocessor_.Run(reused_output_tensors_, results);
  if (!postprocessed) {
    FDERROR << "Failed to postprocess the inference results by runtime."
            << std::endl;
    return false;
  }

  return true;
}
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -14,6 +14,8 @@
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/fastdeploy_model.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/preprocessor.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/postprocessor.h"
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
|
||||
@@ -26,9 +28,9 @@ namespace vision {
|
||||
*/
|
||||
namespace detection {
|
||||
|
||||
/*! @brief PPYOLOE model object used when to load a PPYOLOE model exported by PaddleDetection
|
||||
/*! @brief Base model object used when to load a model exported by PaddleDetection
|
||||
*/
|
||||
class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel {
|
||||
class FASTDEPLOY_DECL PPDetBase : public FastDeployModel {
|
||||
public:
|
||||
/** \brief Set path of model file and configuration file, and the configuration of runtime
|
||||
*
|
||||
@@ -38,49 +40,49 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel {
|
||||
* \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
|
||||
* \param[in] model_format Model format of the loaded model, default is Paddle format
|
||||
*/
|
||||
PPYOLOE(const std::string& model_file, const std::string& params_file,
|
||||
PPDetBase(const std::string& model_file, const std::string& params_file,
|
||||
const std::string& config_file,
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const ModelFormat& model_format = ModelFormat::PADDLE);
|
||||
|
||||
/// Get model's name
|
||||
virtual std::string ModelName() const { return "PaddleDetection/PPYOLOE"; }
|
||||
virtual std::string ModelName() const { return "PaddleDetection/BaseModel"; }
|
||||
|
||||
/** \brief Predict the detection result for an input image
|
||||
/** \brief DEPRECATED Predict the detection result for an input image
|
||||
*
|
||||
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
|
||||
* \param[in] result The output detection result will be writen to this structure
|
||||
* \param[in] result The output detection result
|
||||
* \return true if the prediction successed, otherwise false
|
||||
*/
|
||||
virtual bool Predict(cv::Mat* im, DetectionResult* result);
|
||||
|
||||
/** \brief Predict the detection result for an input image
|
||||
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
|
||||
* \param[in] result The output detection result
|
||||
* \return true if the prediction successed, otherwise false
|
||||
*/
|
||||
virtual bool Predict(const cv::Mat& im, DetectionResult* result);
|
||||
|
||||
/** \brief Predict the detection result for an input image list
|
||||
* \param[in] im The input image list, all the elements come from cv::imread(), is a 3-D array with layout HWC, BGR format
|
||||
* \param[in] results The output detection result list
|
||||
* \return true if the prediction successed, otherwise false
|
||||
*/
|
||||
virtual bool BatchPredict(const std::vector<cv::Mat>& imgs,
|
||||
std::vector<DetectionResult>* results);
|
||||
|
||||
PaddleDetPreprocessor& GetPreprocessor() {
|
||||
return preprocessor_;
|
||||
}
|
||||
|
||||
PaddleDetPostprocessor& GetPostprocessor() {
|
||||
return postprocessor_;
|
||||
}
|
||||
|
||||
protected:
|
||||
PPYOLOE() {}
|
||||
virtual bool Initialize();
|
||||
/// Build the preprocess pipeline from the loaded model
|
||||
virtual bool BuildPreprocessPipelineFromConfig();
|
||||
/// Preprocess an input image, and set the preprocessed results to `outputs`
|
||||
virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
|
||||
|
||||
/// Postprocess the inferenced results, and set the final result to `result`
|
||||
virtual bool Postprocess(std::vector<FDTensor>& infer_result,
|
||||
DetectionResult* result);
|
||||
|
||||
std::vector<std::shared_ptr<Processor>> processors_;
|
||||
std::string config_file_;
|
||||
// configuration for nms
|
||||
int64_t background_label = -1;
|
||||
int64_t keep_top_k = 300;
|
||||
float nms_eta = 1.0;
|
||||
float nms_threshold = 0.7;
|
||||
float score_threshold = 0.01;
|
||||
int64_t nms_top_k = 10000;
|
||||
bool normalized = true;
|
||||
bool has_nms_ = true;
|
||||
|
||||
// This function will used to check if this model contains multiclass_nms
|
||||
// and get parameters from the operator
|
||||
void GetNmsInfo();
|
||||
PaddleDetPreprocessor preprocessor_;
|
||||
PaddleDetPostprocessor postprocessor_;
|
||||
};
|
||||
|
||||
} // namespace detection
|
@@ -1,120 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/detection/ppdet/mask_rcnn.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
/// Builds a MaskRCNN model.
///
/// Records the configuration file, restricts the valid backends
/// (PDINFER and LITE on CPU, PDINFER on GPU), stores the model paths and
/// format into `runtime_option`, and finally calls Initialize(), keeping its
/// outcome in `initialized`.
MaskRCNN::MaskRCNN(const std::string& model_file,
                   const std::string& params_file,
                   const std::string& config_file,
                   const RuntimeOption& custom_option,
                   const ModelFormat& model_format) {
  valid_gpu_backends = {Backend::PDINFER};
  valid_cpu_backends = {Backend::PDINFER, Backend::LITE};
  config_file_ = config_file;

  // The caller's options are the baseline; the model fields override them.
  runtime_option = custom_option;
  runtime_option.params_file = params_file;
  runtime_option.model_file = model_file;
  runtime_option.model_format = model_format;

  initialized = Initialize();
}
|
||||
|
||||
bool MaskRCNN::Postprocess(std::vector<FDTensor>& infer_result,
|
||||
DetectionResult* result) {
|
||||
// index 0: bbox_data [N, 6] float32
|
||||
// index 1: bbox_num [B=1] int32
|
||||
// index 2: mask_data [N, h, w] int32
|
||||
FDASSERT(infer_result[1].shape[0] == 1,
|
||||
"Only support batch = 1 in FastDeploy now.");
|
||||
FDASSERT(infer_result.size() == 3,
|
||||
"The infer_result must contains 3 otuput Tensors, but found %lu",
|
||||
infer_result.size());
|
||||
|
||||
FDTensor& box_tensor = infer_result[0];
|
||||
FDTensor& box_num_tensor = infer_result[1];
|
||||
FDTensor& mask_tensor = infer_result[2];
|
||||
|
||||
int box_num = 0;
|
||||
if (box_num_tensor.dtype == FDDataType::INT32) {
|
||||
box_num = *(static_cast<int32_t*>(box_num_tensor.Data()));
|
||||
} else if (box_num_tensor.dtype == FDDataType::INT64) {
|
||||
box_num = *(static_cast<int64_t*>(box_num_tensor.Data()));
|
||||
} else {
|
||||
FDASSERT(false,
|
||||
"The output box_num of PaddleDetection/MaskRCNN model should be "
|
||||
"type of int32/int64.");
|
||||
}
|
||||
if (box_num <= 0) {
|
||||
return true; // no object detected.
|
||||
}
|
||||
result->Resize(box_num);
|
||||
float* box_data = static_cast<float*>(box_tensor.Data());
|
||||
for (size_t i = 0; i < box_num; ++i) {
|
||||
result->label_ids[i] = static_cast<int>(box_data[i * 6]);
|
||||
result->scores[i] = box_data[i * 6 + 1];
|
||||
result->boxes[i] =
|
||||
std::array<float, 4>{box_data[i * 6 + 2], box_data[i * 6 + 3],
|
||||
box_data[i * 6 + 4], box_data[i * 6 + 5]};
|
||||
}
|
||||
result->contain_masks = true;
|
||||
// TODO(qiuyanjun): Cast int64/int8 to int32.
|
||||
FDASSERT(mask_tensor.dtype == FDDataType::INT32,
|
||||
"The dtype of mask Tensor must be int32 now!");
|
||||
// In PaddleDetection/MaskRCNN, the mask_h and mask_w
|
||||
// are already aligned with original input image. So,
|
||||
// we need to crop it from output mask according to
|
||||
// the detected bounding box.
|
||||
// +-----------------------+
|
||||
// | x1,y1 |
|
||||
// | +---------------+ |
|
||||
// | | | |
|
||||
// | | Crop | |
|
||||
// | | | |
|
||||
// | | | |
|
||||
// | +---------------+ |
|
||||
// | x2,y2 |
|
||||
// +-----------------------+
|
||||
int64_t out_mask_h = mask_tensor.shape[1];
|
||||
int64_t out_mask_w = mask_tensor.shape[2];
|
||||
int64_t out_mask_numel = out_mask_h * out_mask_w;
|
||||
int32_t* out_mask_data = static_cast<int32_t*>(mask_tensor.Data());
|
||||
for (size_t i = 0; i < box_num; ++i) {
|
||||
// crop instance mask according to box
|
||||
int64_t x1 = static_cast<int64_t>(result->boxes[i][0]);
|
||||
int64_t y1 = static_cast<int64_t>(result->boxes[i][1]);
|
||||
int64_t x2 = static_cast<int64_t>(result->boxes[i][2]);
|
||||
int64_t y2 = static_cast<int64_t>(result->boxes[i][3]);
|
||||
int64_t keep_mask_h = y2 - y1;
|
||||
int64_t keep_mask_w = x2 - x1;
|
||||
int64_t keep_mask_numel = keep_mask_h * keep_mask_w;
|
||||
result->masks[i].Resize(keep_mask_numel); // int32
|
||||
result->masks[i].shape = {keep_mask_h, keep_mask_w};
|
||||
int32_t* mask_start_ptr = out_mask_data + i * out_mask_numel;
|
||||
int32_t* keep_mask_ptr = static_cast<int32_t*>(result->masks[i].Data());
|
||||
for (size_t row = y1; row < y2; ++row) {
|
||||
size_t keep_nbytes_in_col = keep_mask_w * sizeof(int32_t);
|
||||
int32_t* out_row_start_ptr = mask_start_ptr + row * out_mask_w + x1;
|
||||
int32_t* keep_row_start_ptr = keep_mask_ptr + (row - y1) * keep_mask_w;
|
||||
std::memcpy(keep_row_start_ptr, out_row_start_ptr, keep_nbytes_in_col);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -13,10 +13,152 @@
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/detection/ppdet/mask_rcnn.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/picodet.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/ppyolo.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/rcnn.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/yolov3.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/yolox.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/base.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
/*! @brief PicoDet model object built on PPDetBase
 */
class FASTDEPLOY_DECL PicoDet : public PPDetBase {
 public:
  /** \brief Set path of model file and configuration file, and the configuration of runtime
   *
   * \param[in] model_file Path of model file, e.g picodet/model.pdmodel
   * \param[in] params_file Path of parameter file, e.g picodet/model.pdiparams, if the model format is ONNX, this parameter will be ignored
   * \param[in] config_file Path of configuration file for deployment, e.g picodet/infer_cfg.yml
   * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
   * \param[in] model_format Model format of the loaded model, default is Paddle format
   */
  PicoDet(const std::string& model_file, const std::string& params_file,
          const std::string& config_file,
          const RuntimeOption& custom_option = RuntimeOption(),
          const ModelFormat& model_format = ModelFormat::PADDLE)
      : PPDetBase(model_file, params_file, config_file, custom_option,
                  model_format) {
    // The backend lists must be in place before the runtime is initialized.
    valid_cpu_backends = {Backend::OPENVINO, Backend::ORT,
                          Backend::PDINFER, Backend::LITE};
    valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
    initialized = Initialize();
  }

  /// Get model's name
  std::string ModelName() const override { return "PicoDet"; }
};
|
||||
|
||||
/*! @brief PPYOLOE model object built on PPDetBase
 */
class FASTDEPLOY_DECL PPYOLOE : public PPDetBase {
 public:
  /** \brief Set path of model file and configuration file, and the configuration of runtime
   *
   * \param[in] model_file Path of model file, e.g ppyoloe/model.pdmodel
   * \param[in] params_file Path of parameter file, e.g ppyoloe/model.pdiparams, if the model format is ONNX, this parameter will be ignored
   * \param[in] config_file Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
   * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
   * \param[in] model_format Model format of the loaded model, default is Paddle format
   */
  PPYOLOE(const std::string& model_file, const std::string& params_file,
          const std::string& config_file,
          const RuntimeOption& custom_option = RuntimeOption(),
          const ModelFormat& model_format = ModelFormat::PADDLE)
      : PPDetBase(model_file, params_file, config_file, custom_option,
                  model_format) {
    // The backend lists must be in place before the runtime is initialized.
    valid_cpu_backends = {Backend::OPENVINO, Backend::ORT,
                          Backend::PDINFER, Backend::LITE};
    valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
    initialized = Initialize();
  }

  /// Get model's name
  std::string ModelName() const override { return "PPYOLOE"; }
};
|
||||
|
||||
/*! @brief PPYOLO model object built on top of PPDetBase.
 */
class FASTDEPLOY_DECL PPYOLO : public PPDetBase {
 public:
  /** \brief Set path of model file and configuration file, and the configuration of runtime
   *
   * \param[in] model_file Path of model file, e.g ppyolo/model.pdmodel
   * \param[in] params_file Path of parameter file, e.g ppyolo/model.pdiparams, if the model format is ONNX, this parameter will be ignored
   * \param[in] config_file Path of configuration file for deployment, e.g ppyolo/infer_cfg.yml
   * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
   * \param[in] model_format Model format of the loaded model, default is Paddle format
   */
  PPYOLO(const std::string& model_file, const std::string& params_file,
         const std::string& config_file,
         const RuntimeOption& custom_option = RuntimeOption(),
         const ModelFormat& model_format = ModelFormat::PADDLE)
      : PPDetBase(model_file, params_file, config_file, custom_option,
                  model_format) {
    // Backends this model may run on, listed per device type.
    valid_gpu_backends = {Backend::PDINFER};
    valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::LITE};
    initialized = Initialize();
  }

  /// Name reported for this model.
  virtual std::string ModelName() const {
    return "PaddleDetection/PP-YOLO";
  }
};
|
||||
|
||||
/*! @brief YOLOv3 model object built on top of PPDetBase.
 */
class FASTDEPLOY_DECL YOLOv3 : public PPDetBase {
 public:
  /** \brief Set path of model file and configuration file, and the configuration of runtime
   *
   * \param[in] model_file Path of model file
   * \param[in] params_file Path of parameter file
   * \param[in] config_file Path of configuration file for deployment
   * \param[in] custom_option RuntimeOption for inference
   * \param[in] model_format Model format of the loaded model, default is Paddle format
   */
  YOLOv3(const std::string& model_file, const std::string& params_file,
         const std::string& config_file,
         const RuntimeOption& custom_option = RuntimeOption(),
         const ModelFormat& model_format = ModelFormat::PADDLE)
      : PPDetBase(model_file, params_file, config_file, custom_option,
                  model_format) {
    // Backends this model may run on, listed per device type.
    valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
    valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER,
                          Backend::LITE};
    initialized = Initialize();
  }

  /// Name reported for this model.
  virtual std::string ModelName() const {
    return "PaddleDetection/YOLOv3";
  }
};
|
||||
|
||||
/*! @brief PaddleYOLOX model object built on top of PPDetBase.
 */
class FASTDEPLOY_DECL PaddleYOLOX : public PPDetBase {
 public:
  /** \brief Set path of model file and configuration file, and the configuration of runtime
   *
   * \param[in] model_file Path of model file
   * \param[in] params_file Path of parameter file
   * \param[in] config_file Path of configuration file for deployment
   * \param[in] custom_option RuntimeOption for inference
   * \param[in] model_format Model format of the loaded model, default is Paddle format
   */
  PaddleYOLOX(const std::string& model_file, const std::string& params_file,
              const std::string& config_file,
              const RuntimeOption& custom_option = RuntimeOption(),
              const ModelFormat& model_format = ModelFormat::PADDLE)
      : PPDetBase(model_file, params_file, config_file, custom_option,
                  model_format) {
    // Backends this model may run on, listed per device type.
    valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
    valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER,
                          Backend::LITE};
    initialized = Initialize();
  }

  /// Name reported for this model.
  virtual std::string ModelName() const {
    return "PaddleDetection/YOLOX";
  }
};
|
||||
|
||||
/*! @brief FasterRCNN model object built on top of PPDetBase.
 */
class FASTDEPLOY_DECL FasterRCNN : public PPDetBase {
 public:
  /** \brief Set path of model file and configuration file, and the configuration of runtime
   *
   * \param[in] model_file Path of model file
   * \param[in] params_file Path of parameter file
   * \param[in] config_file Path of configuration file for deployment
   * \param[in] custom_option RuntimeOption for inference
   * \param[in] model_format Model format of the loaded model, default is Paddle format
   */
  FasterRCNN(const std::string& model_file, const std::string& params_file,
             const std::string& config_file,
             const RuntimeOption& custom_option = RuntimeOption(),
             const ModelFormat& model_format = ModelFormat::PADDLE)
      : PPDetBase(model_file, params_file, config_file, custom_option,
                  model_format) {
    // Backends this model may run on, listed per device type.
    valid_gpu_backends = {Backend::PDINFER};
    valid_cpu_backends = {Backend::PDINFER, Backend::LITE};
    initialized = Initialize();
  }

  /// Name reported for this model.
  virtual std::string ModelName() const {
    return "PaddleDetection/FasterRCNN";
  }
};
|
||||
|
||||
/*! @brief MaskRCNN model object built on top of PPDetBase.
 */
class FASTDEPLOY_DECL MaskRCNN : public PPDetBase {
 public:
  /** \brief Set path of model file and configuration file, and the configuration of runtime
   *
   * \param[in] model_file Path of model file
   * \param[in] params_file Path of parameter file
   * \param[in] config_file Path of configuration file for deployment
   * \param[in] custom_option RuntimeOption for inference
   * \param[in] model_format Model format of the loaded model, default is Paddle format
   */
  MaskRCNN(const std::string& model_file, const std::string& params_file,
           const std::string& config_file,
           const RuntimeOption& custom_option = RuntimeOption(),
           const ModelFormat& model_format = ModelFormat::PADDLE)
      : PPDetBase(model_file, params_file, config_file, custom_option,
                  model_format) {
    // Backends this model may run on, listed per device type.
    valid_gpu_backends = {Backend::PDINFER};
    valid_cpu_backends = {Backend::PDINFER, Backend::LITE};
    initialized = Initialize();
  }

  /// Name reported for this model.
  virtual std::string ModelName() const {
    return "PaddleDetection/MaskRCNN";
  }
};
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
||||
|
@@ -1,66 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/detection/ppdet/picodet.h"
|
||||
#include "yaml-cpp/yaml.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
// Construct a PicoDet model: store the runtime settings and file paths, set
// the NMS-related members, check the config file, then initialize.
PicoDet::PicoDet(const std::string& model_file, const std::string& params_file,
                 const std::string& config_file,
                 const RuntimeOption& custom_option,
                 const ModelFormat& model_format) {
  // Start from the caller's option, then overwrite the model description
  // fields with the constructor arguments.
  runtime_option = custom_option;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;
  runtime_option.model_format = model_format;
  config_file_ = config_file;

  // Backends this model may run on, listed per device type.
  valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
  valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::ORT,
                        Backend::LITE};

  // NMS-related members.
  score_threshold = 0.025;
  nms_threshold = 0.6;
  nms_top_k = 1000;
  nms_eta = 1;
  keep_top_k = 100;
  normalized = true;
  background_label = -1;

  // NOTE(review): the result of this check is discarded, so initialization
  // proceeds even when the check reports a problem -- confirm this is intended.
  CheckIfContainDecodeAndNMS();
  initialized = Initialize();
}
|
||||
|
||||
bool PicoDet::CheckIfContainDecodeAndNMS() {
|
||||
YAML::Node cfg;
|
||||
try {
|
||||
cfg = YAML::LoadFile(config_file_);
|
||||
} catch (YAML::BadFile& e) {
|
||||
FDERROR << "Failed to load yaml file " << config_file_
|
||||
<< ", maybe you should check this file." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cfg["arch"].as<std::string>() == "PicoDet") {
|
||||
FDERROR << "The arch in config file is PicoDet, which means this model "
|
||||
"doesn contain box decode and nms, please export model with "
|
||||
"decode and nms."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -1,36 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
/*! @brief PicoDet model object, derived from PPYOLOE.
 */
class FASTDEPLOY_DECL PicoDet : public PPYOLOE {
 public:
  /** \brief Set path of model file and configuration file, and the configuration of runtime
   *
   * \param[in] model_file Path of model file
   * \param[in] params_file Path of parameter file
   * \param[in] config_file Path of configuration file for deployment
   * \param[in] custom_option RuntimeOption for inference
   * \param[in] model_format Model format of the loaded model, default is Paddle format
   */
  PicoDet(const std::string& model_file, const std::string& params_file,
          const std::string& config_file,
          const RuntimeOption& custom_option = RuntimeOption(),
          const ModelFormat& model_format = ModelFormat::PADDLE);

  /// Only PicoDet models that contain decode and NMS are supported; this
  /// performs that check.
  bool CheckIfContainDecodeAndNMS();

  /// Name reported for this model.
  virtual std::string ModelName() const {
    return "PicoDet";
  }
};
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
132
fastdeploy/vision/detection/ppdet/postprocessor.cc
Normal file
132
fastdeploy/vision/detection/ppdet/postprocessor.cc
Normal file
@@ -0,0 +1,132 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/detection/ppdet/postprocessor.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
bool PaddleDetPostprocessor::ProcessMask(const FDTensor& tensor, std::vector<DetectionResult>* results) {
|
||||
auto shape = tensor.Shape();
|
||||
if (tensor.Dtype() != FDDataType::INT32) {
|
||||
FDERROR << "The data type of out mask tensor should be INT32, but now it's " << tensor.Dtype() << std::endl;
|
||||
return false;
|
||||
}
|
||||
int64_t out_mask_h = shape[1];
|
||||
int64_t out_mask_w = shape[2];
|
||||
int64_t out_mask_numel = shape[1] * shape[2];
|
||||
const int32_t* data = reinterpret_cast<const int32_t*>(tensor.CpuData());
|
||||
int index = 0;
|
||||
|
||||
for (int i = 0; i < results->size(); ++i) {
|
||||
(*results)[i].contain_masks = true;
|
||||
(*results)[i].masks.resize((*results)[i].boxes.size());
|
||||
for (int j = 0; j < (*results)[i].boxes.size(); ++j) {
|
||||
int x1 = static_cast<int>((*results)[i].boxes[j][0]);
|
||||
int y1 = static_cast<int>((*results)[i].boxes[j][1]);
|
||||
int x2 = static_cast<int>((*results)[i].boxes[j][2]);
|
||||
int y2 = static_cast<int>((*results)[i].boxes[j][3]);
|
||||
int keep_mask_h = y2 - y1;
|
||||
int keep_mask_w = x2 - x1;
|
||||
int keep_mask_numel = keep_mask_h * keep_mask_w;
|
||||
(*results)[i].masks[j].Resize(keep_mask_numel);
|
||||
(*results)[i].masks[j].shape = {keep_mask_h, keep_mask_w};
|
||||
const int32_t* current_ptr = data + index * out_mask_numel;
|
||||
|
||||
int32_t* keep_mask_ptr = reinterpret_cast<int32_t*>((*results)[i].masks[j].Data());
|
||||
for (int row = y1; row < y2; ++row) {
|
||||
size_t keep_nbytes_in_col = keep_mask_w * sizeof(int32_t);
|
||||
const int32_t* out_row_start_ptr = current_ptr + row * out_mask_w + x1;
|
||||
int32_t* keep_row_start_ptr = keep_mask_ptr + (row - y1) * keep_mask_w;
|
||||
std::memcpy(keep_row_start_ptr, out_row_start_ptr, keep_nbytes_in_col);
|
||||
}
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors, std::vector<DetectionResult>* results) {
|
||||
if (tensors[0].shape[0] == 0) {
|
||||
// No detected boxes
|
||||
return true;
|
||||
}
|
||||
|
||||
// Get number of boxes for each input image
|
||||
std::vector<int> num_boxes(tensors[1].shape[0]);
|
||||
int total_num_boxes = 0;
|
||||
if (tensors[1].dtype == FDDataType::INT32) {
|
||||
const int32_t* data = static_cast<const int32_t*>(tensors[1].CpuData());
|
||||
for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
|
||||
num_boxes[i] = static_cast<int>(data[i]);
|
||||
total_num_boxes += num_boxes[i];
|
||||
}
|
||||
} else if (tensors[1].dtype == FDDataType::INT64) {
|
||||
const int64_t* data = static_cast<const int64_t*>(tensors[1].CpuData());
|
||||
for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
|
||||
num_boxes[i] = static_cast<int>(data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Special case for TensorRT, it has fixed output shape of NMS
|
||||
// So there's invalid boxes in its' output boxes
|
||||
int num_output_boxes = tensors[0].Shape()[0];
|
||||
bool contain_invalid_boxes = false;
|
||||
if (total_num_boxes != num_output_boxes) {
|
||||
if (num_output_boxes % num_boxes.size() == 0) {
|
||||
contain_invalid_boxes = true;
|
||||
} else {
|
||||
FDERROR << "Cannot handle the output data for this model, unexpected situation." << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Get boxes for each input image
|
||||
results->resize(num_boxes.size());
|
||||
const float* box_data = static_cast<const float*>(tensors[0].CpuData());
|
||||
int offset = 0;
|
||||
for (size_t i = 0; i < num_boxes.size(); ++i) {
|
||||
const float* ptr = box_data + offset;
|
||||
(*results)[i].Reserve(num_boxes[i]);
|
||||
for (size_t j = 0; j < num_boxes[i]; ++j) {
|
||||
(*results)[i].label_ids.push_back(static_cast<int32_t>(round(ptr[j * 6])));
|
||||
(*results)[i].scores.push_back(ptr[j * 6 + 1]);
|
||||
(*results)[i].boxes.emplace_back(std::array<float, 4>({ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]}));
|
||||
}
|
||||
if (contain_invalid_boxes) {
|
||||
offset += (num_output_boxes * 6 / num_boxes.size());
|
||||
} else {
|
||||
offset += (num_boxes[i] * 6);
|
||||
}
|
||||
}
|
||||
|
||||
// Only detection
|
||||
if (tensors.size() <= 2) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (tensors[2].Shape()[0] != num_output_boxes) {
|
||||
FDERROR << "The first dimension of output mask tensor:" << tensors[2].Shape()[0] << " is not equal to the first dimension of output boxes tensor:" << num_output_boxes << "." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// process for maskrcnn
|
||||
return ProcessMask(tensors[2], results);
|
||||
}
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -13,26 +13,30 @@
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/detection/ppdet/rcnn.h"
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
namespace detection {
|
||||
|
||||
class FASTDEPLOY_DECL MaskRCNN : public FasterRCNN {
|
||||
/*! @brief Postprocessor object for PaddleDet series models.
|
||||
*/
|
||||
class FASTDEPLOY_DECL PaddleDetPostprocessor {
|
||||
public:
|
||||
MaskRCNN(const std::string& model_file, const std::string& params_file,
|
||||
const std::string& config_file,
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const ModelFormat& model_format = ModelFormat::PADDLE);
|
||||
|
||||
virtual std::string ModelName() const { return "PaddleDetection/MaskRCNN"; }
|
||||
|
||||
virtual bool Postprocess(std::vector<FDTensor>& infer_result,
|
||||
DetectionResult* result);
|
||||
|
||||
protected:
|
||||
MaskRCNN() {}
|
||||
PaddleDetPostprocessor() = default;
|
||||
/** \brief Process the result of runtime and fill it into the DetectionResult structure
|
||||
*
|
||||
* \param[in] tensors The inference result from runtime
|
||||
* \param[out] result The output result of detection
|
||||
* \return true if the postprocess succeeded, otherwise false
|
||||
*/
|
||||
bool Run(const std::vector<FDTensor>& tensors,
|
||||
std::vector<DetectionResult>* result);
|
||||
private:
|
||||
// Process mask tensor for MaskRCNN
|
||||
bool ProcessMask(const FDTensor& tensor,
|
||||
std::vector<DetectionResult>* results);
|
||||
};
|
||||
|
||||
} // namespace detection
|
@@ -15,94 +15,94 @@
|
||||
|
||||
namespace fastdeploy {
|
||||
void BindPPDet(pybind11::module& m) {
|
||||
pybind11::class_<vision::detection::PPYOLOE, FastDeployModel>(m, "PPYOLOE")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def("predict",
|
||||
[](vision::detection::PPYOLOE& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
pybind11::class_<vision::detection::PaddleDetPreprocessor>(
|
||||
m, "PaddleDetPreprocessor")
|
||||
.def(pybind11::init<std::string>())
|
||||
.def("run", [](vision::detection::PaddleDetPreprocessor& self, std::vector<pybind11::array>& im_list) {
|
||||
std::vector<vision::FDMat> images;
|
||||
for (size_t i = 0; i < im_list.size(); ++i) {
|
||||
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
|
||||
}
|
||||
std::vector<FDTensor> outputs;
|
||||
if (!self.Run(&images, &outputs)) {
|
||||
pybind11::eval("raise Exception('Failed to preprocess the input data in PaddleDetPreprocessor.')");
|
||||
}
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
outputs[i].StopSharing();
|
||||
}
|
||||
return outputs;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::detection::PPYOLO, FastDeployModel>(m, "PPYOLO")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def("predict",
|
||||
[](vision::detection::PPYOLO& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
pybind11::class_<vision::detection::PaddleDetPostprocessor>(
|
||||
m, "PaddleDetPostprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def("run", [](vision::detection::PaddleDetPostprocessor& self, std::vector<FDTensor>& inputs) {
|
||||
std::vector<vision::DetectionResult> results;
|
||||
if (!self.Run(inputs, &results)) {
|
||||
pybind11::eval("raise Exception('Failed to postprocess the runtime result in PaddleDetPostprocessor.')");
|
||||
}
|
||||
return results;
|
||||
})
|
||||
.def("run", [](vision::detection::PaddleDetPostprocessor& self, std::vector<pybind11::array>& input_array) {
|
||||
std::vector<vision::DetectionResult> results;
|
||||
std::vector<FDTensor> inputs;
|
||||
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
|
||||
if (!self.Run(inputs, &results)) {
|
||||
pybind11::eval("raise Exception('Failed to postprocess the runtime result in PaddleDetPostprocessor.')");
|
||||
}
|
||||
return results;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::detection::PPYOLOv2, FastDeployModel>(m, "PPYOLOv2")
|
||||
pybind11::class_<vision::detection::PPDetBase, FastDeployModel>(m, "PPDetBase")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def("predict",
|
||||
[](vision::detection::PPYOLOv2& self, pybind11::array& data) {
|
||||
[](vision::detection::PPDetBase& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
})
|
||||
.def("batch_predict",
|
||||
[](vision::detection::PPDetBase& self, std::vector<pybind11::array>& data) {
|
||||
std::vector<cv::Mat> images;
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
images.push_back(PyArrayToCvMat(data[i]));
|
||||
}
|
||||
std::vector<vision::DetectionResult> results;
|
||||
self.BatchPredict(images, &results);
|
||||
return results;
|
||||
})
|
||||
.def_property_readonly("preprocessor", &vision::detection::PPDetBase::GetPreprocessor)
|
||||
.def_property_readonly("postprocessor", &vision::detection::PPDetBase::GetPostprocessor);
|
||||
|
||||
pybind11::class_<vision::detection::PicoDet, FastDeployModel>(m, "PicoDet")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def("predict",
|
||||
[](vision::detection::PicoDet& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::detection::PaddleYOLOX, FastDeployModel>(
|
||||
m, "PaddleYOLOX")
|
||||
pybind11::class_<vision::detection::PPYOLO, vision::detection::PPDetBase>(m, "PPYOLO")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def("predict",
|
||||
[](vision::detection::PaddleYOLOX& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::FasterRCNN, FastDeployModel>(m,
|
||||
"FasterRCNN")
|
||||
pybind11::class_<vision::detection::PPYOLOE, vision::detection::PPDetBase>(m, "PPYOLOE")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def("predict",
|
||||
[](vision::detection::FasterRCNN& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::YOLOv3, FastDeployModel>(m, "YOLOv3")
|
||||
pybind11::class_<vision::detection::PicoDet, vision::detection::PPDetBase>(m, "PicoDet")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def("predict",
|
||||
[](vision::detection::YOLOv3& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::MaskRCNN, FastDeployModel>(m, "MaskRCNN")
|
||||
pybind11::class_<vision::detection::PaddleYOLOX, vision::detection::PPDetBase>(m, "PaddleYOLOX")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def("predict",
|
||||
[](vision::detection::MaskRCNN& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::FasterRCNN, vision::detection::PPDetBase>(m, "FasterRCNN")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::YOLOv3, vision::detection::PPDetBase>(m, "YOLOv3")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
|
||||
pybind11::class_<vision::detection::MaskRCNN, vision::detection::PPDetBase>(m, "MaskRCNN")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>());
|
||||
}
|
||||
} // namespace fastdeploy
|
||||
|
@@ -1,78 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/detection/ppdet/ppyolo.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
// Construct a PPYOLO model: store the runtime settings and file paths, mark
// the model as containing NMS, then initialize.
PPYOLO::PPYOLO(const std::string& model_file, const std::string& params_file,
               const std::string& config_file,
               const RuntimeOption& custom_option,
               const ModelFormat& model_format) {
  // Start from the caller's option, then overwrite the model description
  // fields with the constructor arguments.
  runtime_option = custom_option;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;
  runtime_option.model_format = model_format;
  config_file_ = config_file;

  // Backends this model may run on, listed per device type.
  valid_gpu_backends = {Backend::PDINFER};
  valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::LITE};

  has_nms_ = true;
  initialized = Initialize();
}
|
||||
|
||||
bool PPYOLO::Initialize() {
|
||||
if (!BuildPreprocessPipelineFromConfig()) {
|
||||
FDERROR << "Failed to build preprocess pipeline from configuration file."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
if (!InitRuntime()) {
|
||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Run the preprocess pipeline on `mat` and produce the three model inputs.
//
// \param[in,out] mat Input image; modified in place by the processors.
// \param[out] outputs Resized to 3 tensors: [0] "im_shape" (height, width of
//             the processed image), [1] "image" (shares the data of `mat`),
//             [2] "scale_factor" (processed size divided by original size).
// \return false if any processor fails, otherwise true.
bool PPYOLO::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
  // Remember the input size before any processor modifies the image.
  const int origin_w = mat->Width();
  const int origin_h = mat->Height();

  for (const auto& processor : processors_) {
    if ((*processor)(mat)) {
      continue;
    }
    FDERROR << "Failed to process image data in " << processor->Name()
            << "." << std::endl;
    return false;
  }

  outputs->resize(3);
  FDTensor& im_shape = (*outputs)[0];
  FDTensor& image = (*outputs)[1];
  FDTensor& scale_factor = (*outputs)[2];

  im_shape.Allocate({1, 2}, FDDataType::FP32, "im_shape");
  scale_factor.Allocate({1, 2}, FDDataType::FP32, "scale_factor");

  // Size of the image after preprocessing.
  float* im_shape_data = static_cast<float*>(im_shape.MutableData());
  im_shape_data[0] = mat->Height();
  im_shape_data[1] = mat->Width();

  // Ratio between the processed size and the original size.
  float* scale_data = static_cast<float*>(scale_factor.MutableData());
  scale_data[0] = mat->Height() * 1.0 / origin_h;
  scale_data[1] = mat->Width() * 1.0 / origin_w;

  image.name = "image";
  mat->ShareWithTensor(&image);
  // reshape to [1, c, h, w]
  image.shape.insert(image.shape.begin(), 1);
  return true;
}
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -1,52 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
/*! @brief PPYOLO model object, derived from PPYOLOE.
 */
class FASTDEPLOY_DECL PPYOLO : public PPYOLOE {
 public:
  /** \brief Set path of model file and configuration file, and the configuration of runtime
   *
   * \param[in] model_file Path of model file
   * \param[in] params_file Path of parameter file
   * \param[in] config_file Path of configuration file for deployment
   * \param[in] custom_option RuntimeOption for inference
   * \param[in] model_format Model format of the loaded model, default is Paddle format
   */
  PPYOLO(const std::string& model_file, const std::string& params_file,
         const std::string& config_file,
         const RuntimeOption& custom_option = RuntimeOption(),
         const ModelFormat& model_format = ModelFormat::PADDLE);

  /// Name reported for this model.
  virtual std::string ModelName() const {
    return "PaddleDetection/PPYOLO";
  }

  /// Run the preprocess pipeline on `mat` and fill the model input tensors.
  virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);

  /// Build the preprocess pipeline and initialize the runtime.
  virtual bool Initialize();

 protected:
  // Default construction is only available to derived classes.
  PPYOLO() {}
};
|
||||
|
||||
/*! @brief PPYOLOv2 model object; reuses the PPYOLO implementation and only
 * changes the reported model name.
 */
class FASTDEPLOY_DECL PPYOLOv2 : public PPYOLO {
 public:
  /** \brief Forward all arguments to the PPYOLO constructor
   *
   * \param[in] model_file Path of model file
   * \param[in] params_file Path of parameter file
   * \param[in] config_file Path of configuration file for deployment
   * \param[in] custom_option RuntimeOption for inference
   * \param[in] model_format Model format of the loaded model, default is Paddle format
   */
  PPYOLOv2(const std::string& model_file, const std::string& params_file,
           const std::string& config_file,
           const RuntimeOption& custom_option = RuntimeOption(),
           const ModelFormat& model_format = ModelFormat::PADDLE)
      : PPYOLO(model_file, params_file, config_file, custom_option,
               model_format) {}

  /// Name reported for this model.
  virtual std::string ModelName() const {
    return "PaddleDetection/PPYOLOv2";
  }
};
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -1,274 +0,0 @@
|
||||
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
#include "yaml-cpp/yaml.h"
|
||||
#ifdef ENABLE_PADDLE_FRONTEND
|
||||
#include "paddle2onnx/converter.h"
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
// Construct a PPYOLOE model: store the runtime settings and file paths, then
// initialize.
PPYOLOE::PPYOLOE(const std::string& model_file, const std::string& params_file,
                 const std::string& config_file,
                 const RuntimeOption& custom_option,
                 const ModelFormat& model_format) {
  // Start from the caller's option, then overwrite the model description
  // fields with the constructor arguments.
  runtime_option = custom_option;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;
  runtime_option.model_format = model_format;
  config_file_ = config_file;

  // Backends this model may run on, listed per device type.
  valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
  valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER,
                        Backend::LITE};

  initialized = Initialize();
}
|
||||
|
||||
// Read the NMS parameters from the Paddle model file into the NMS-related
// members. Does nothing when built without ENABLE_PADDLE_FRONTEND, when the
// model is not in Paddle format, when the model file cannot be read, or when
// the reader reports no NMS.
void PPYOLOE::GetNmsInfo() {
#ifdef ENABLE_PADDLE_FRONTEND
  if (runtime_option.model_format != ModelFormat::PADDLE) {
    return;
  }

  std::string contents;
  if (!ReadBinaryFromFile(runtime_option.model_file, &contents)) {
    return;
  }

  auto reader = paddle2onnx::PaddleReader(contents.c_str(), contents.size());
  if (!reader.has_nms) {
    return;
  }

  // Copy the parameters reported by the reader.
  const auto& params = reader.nms_params;
  has_nms_ = true;
  background_label = params.background_label;
  keep_top_k = params.keep_top_k;
  nms_eta = params.nms_eta;
  nms_threshold = params.nms_threshold;
  score_threshold = params.score_threshold;
  nms_top_k = params.nms_top_k;
  normalized = params.normalized;
#endif
}
|
||||
|
||||
bool PPYOLOE::Initialize() {
|
||||
if (!BuildPreprocessPipelineFromConfig()) {
|
||||
FDERROR << "Failed to build preprocess pipeline from configuration file."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
if (!InitRuntime()) {
|
||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||
return false;
|
||||
}
|
||||
reused_input_tensors_.resize(2);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PPYOLOE::BuildPreprocessPipelineFromConfig() {
|
||||
processors_.clear();
|
||||
YAML::Node cfg;
|
||||
try {
|
||||
cfg = YAML::LoadFile(config_file_);
|
||||
} catch (YAML::BadFile& e) {
|
||||
FDERROR << "Failed to load yaml file " << config_file_
|
||||
<< ", maybe you should check this file." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
processors_.push_back(std::make_shared<BGR2RGB>());
|
||||
|
||||
bool has_permute = false;
|
||||
for (const auto& op : cfg["Preprocess"]) {
|
||||
std::string op_name = op["type"].as<std::string>();
|
||||
if (op_name == "NormalizeImage") {
|
||||
auto mean = op["mean"].as<std::vector<float>>();
|
||||
auto std = op["std"].as<std::vector<float>>();
|
||||
bool is_scale = true;
|
||||
if (op["is_scale"]) {
|
||||
is_scale = op["is_scale"].as<bool>();
|
||||
}
|
||||
std::string norm_type = "mean_std";
|
||||
if (op["norm_type"]) {
|
||||
norm_type = op["norm_type"].as<std::string>();
|
||||
}
|
||||
if (norm_type != "mean_std") {
|
||||
std::fill(mean.begin(), mean.end(), 0.0);
|
||||
std::fill(std.begin(), std.end(), 1.0);
|
||||
}
|
||||
processors_.push_back(std::make_shared<Normalize>(mean, std, is_scale));
|
||||
} else if (op_name == "Resize") {
|
||||
bool keep_ratio = op["keep_ratio"].as<bool>();
|
||||
auto target_size = op["target_size"].as<std::vector<int>>();
|
||||
int interp = op["interp"].as<int>();
|
||||
FDASSERT(target_size.size() == 2,
|
||||
"Require size of target_size be 2, but now it's %lu.",
|
||||
target_size.size());
|
||||
if (!keep_ratio) {
|
||||
int width = target_size[1];
|
||||
int height = target_size[0];
|
||||
processors_.push_back(
|
||||
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
|
||||
} else {
|
||||
int min_target_size = std::min(target_size[0], target_size[1]);
|
||||
int max_target_size = std::max(target_size[0], target_size[1]);
|
||||
std::vector<int> max_size;
|
||||
if (max_target_size > 0) {
|
||||
max_size.push_back(max_target_size);
|
||||
max_size.push_back(max_target_size);
|
||||
}
|
||||
processors_.push_back(std::make_shared<ResizeByShort>(
|
||||
min_target_size, interp, true, max_size));
|
||||
}
|
||||
} else if (op_name == "Permute") {
|
||||
// Do nothing, do permute as the last operation
|
||||
has_permute = true;
|
||||
continue;
|
||||
// processors_.push_back(std::make_shared<HWC2CHW>());
|
||||
} else if (op_name == "Pad") {
|
||||
auto size = op["size"].as<std::vector<int>>();
|
||||
auto value = op["fill_value"].as<std::vector<float>>();
|
||||
processors_.push_back(std::make_shared<Cast>("float"));
|
||||
processors_.push_back(
|
||||
std::make_shared<PadToSize>(size[1], size[0], value));
|
||||
} else if (op_name == "PadStride") {
|
||||
auto stride = op["stride"].as<int>();
|
||||
processors_.push_back(
|
||||
std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
|
||||
} else {
|
||||
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (has_permute) {
|
||||
// permute = cast<float> + HWC2CHW
|
||||
processors_.push_back(std::make_shared<Cast>("float"));
|
||||
processors_.push_back(std::make_shared<HWC2CHW>());
|
||||
} else {
|
||||
processors_.push_back(std::make_shared<HWC2CHW>());
|
||||
}
|
||||
|
||||
// Fusion will improve performance
|
||||
FuseTransforms(&processors_);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Runs every configured processor on `mat` in place and fills two input
// tensors in `outputs`:
//   [0] the processed image, sharing memory with `mat`, with a leading
//       dimension of 1 inserted into its shape;
//   [1] a {1, 2} FP32 tensor holding (processed height / original height,
//       processed width / original width).
// Tensor names are taken from the runtime's input info at the same index.
// Returns false (after logging) if any processor fails.
bool PPYOLOE::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
  // Remember the input size before any processor changes it.
  const int origin_w = mat->Width();
  const int origin_h = mat->Height();

  for (const auto& processor : processors_) {
    if (!(*processor)(mat)) {
      FDERROR << "Failed to process image data in " << processor->Name()
              << "." << std::endl;
      return false;
    }
  }

  outputs->resize(2);
  FDTensor& image_tensor = (*outputs)[0];
  FDTensor& ratio_tensor = (*outputs)[1];

  image_tensor.name = InputInfoOfRuntime(0).name;
  mat->ShareWithTensor(&image_tensor);
  // reshape to [1, c, h, w]
  image_tensor.shape.insert(image_tensor.shape.begin(), 1);

  ratio_tensor.Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(1).name);
  float* ratio = static_cast<float*>(ratio_tensor.MutableData());
  ratio[0] = mat->Height() * 1.0 / origin_h;
  ratio[1] = mat->Width() * 1.0 / origin_w;
  return true;
}
|
||||
|
||||
bool PPYOLOE::Postprocess(std::vector<FDTensor>& infer_result,
|
||||
DetectionResult* result) {
|
||||
FDASSERT(infer_result[1].shape[0] == 1,
|
||||
"Only support batch = 1 in FastDeploy now.");
|
||||
|
||||
has_nms_ = true;
|
||||
if (!has_nms_) {
|
||||
int boxes_index = 0;
|
||||
int scores_index = 1;
|
||||
if (infer_result[0].shape[1] == infer_result[1].shape[2]) {
|
||||
boxes_index = 0;
|
||||
scores_index = 1;
|
||||
} else if (infer_result[0].shape[2] == infer_result[1].shape[1]) {
|
||||
boxes_index = 1;
|
||||
scores_index = 0;
|
||||
} else {
|
||||
FDERROR << "The shape of boxes and scores should be [batch, boxes_num, "
|
||||
"4], [batch, classes_num, boxes_num]"
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
backend::MultiClassNMS nms;
|
||||
nms.background_label = background_label;
|
||||
nms.keep_top_k = keep_top_k;
|
||||
nms.nms_eta = nms_eta;
|
||||
nms.nms_threshold = nms_threshold;
|
||||
nms.score_threshold = score_threshold;
|
||||
nms.nms_top_k = nms_top_k;
|
||||
nms.normalized = normalized;
|
||||
nms.Compute(static_cast<float*>(infer_result[boxes_index].Data()),
|
||||
static_cast<float*>(infer_result[scores_index].Data()),
|
||||
infer_result[boxes_index].shape,
|
||||
infer_result[scores_index].shape);
|
||||
if (nms.out_num_rois_data[0] > 0) {
|
||||
result->Reserve(nms.out_num_rois_data[0]);
|
||||
}
|
||||
for (size_t i = 0; i < nms.out_num_rois_data[0]; ++i) {
|
||||
result->label_ids.push_back(nms.out_box_data[i * 6]);
|
||||
result->scores.push_back(nms.out_box_data[i * 6 + 1]);
|
||||
result->boxes.emplace_back(std::array<float, 4>{
|
||||
nms.out_box_data[i * 6 + 2], nms.out_box_data[i * 6 + 3],
|
||||
nms.out_box_data[i * 6 + 4], nms.out_box_data[i * 6 + 5]});
|
||||
}
|
||||
} else {
|
||||
std::vector<int> num_boxes(infer_result[1].shape[0]);
|
||||
if (infer_result[1].dtype == FDDataType::INT32) {
|
||||
int32_t* data = static_cast<int32_t*>(infer_result[1].Data());
|
||||
for (size_t i = 0; i < infer_result[1].shape[0]; ++i) {
|
||||
num_boxes[i] = static_cast<int>(data[i]);
|
||||
}
|
||||
} else if (infer_result[1].dtype == FDDataType::INT64) {
|
||||
int64_t* data = static_cast<int64_t*>(infer_result[1].Data());
|
||||
for (size_t i = 0; i < infer_result[1].shape[0]; ++i) {
|
||||
num_boxes[i] = static_cast<int>(data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Only support batch = 1 now
|
||||
result->Reserve(num_boxes[0]);
|
||||
float* box_data = static_cast<float*>(infer_result[0].Data());
|
||||
for (size_t i = 0; i < num_boxes[0]; ++i) {
|
||||
result->label_ids.push_back(box_data[i * 6]);
|
||||
result->scores.push_back(box_data[i * 6 + 1]);
|
||||
result->boxes.emplace_back(
|
||||
std::array<float, 4>{box_data[i * 6 + 2], box_data[i * 6 + 3],
|
||||
box_data[i * 6 + 4], box_data[i * 6 + 5]});
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Runs the full pipeline on one image: preprocess into the reusable input
// tensors, run inference, then postprocess the reusable output tensors into
// `result`. Returns false (after logging the failing stage) on any error.
bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result) {
  Mat mat(*im);

  const bool preprocessed = Preprocess(&mat, &reused_input_tensors_);
  if (!preprocessed) {
    FDERROR << "Failed to preprocess input data while using model:"
            << ModelName() << "." << std::endl;
    return false;
  }

  const bool inferred = Infer();
  if (!inferred) {
    FDERROR << "Failed to inference while using model:" << ModelName() << "."
            << std::endl;
    return false;
  }

  const bool postprocessed = Postprocess(reused_output_tensors_, result);
  if (!postprocessed) {
    FDERROR << "Failed to postprocess while using model:" << ModelName() << "."
            << std::endl;
    return false;
  }
  return true;
}
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
201
fastdeploy/vision/detection/ppdet/preprocessor.cc
Normal file
201
fastdeploy/vision/detection/ppdet/preprocessor.cc
Normal file
@@ -0,0 +1,201 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/detection/ppdet/preprocessor.h"
|
||||
#include "fastdeploy/function/concat.h"
|
||||
#include "fastdeploy/function/pad.h"
|
||||
#include "yaml-cpp/yaml.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
// Builds the preprocessing pipeline described by `config_file` and marks the
// object as initialized. A failed build trips the assertion.
PaddleDetPreprocessor::PaddleDetPreprocessor(const std::string& config_file) {
  // Run the build step outside the assertion macro so the call itself is
  // plainly visible.
  const bool pipeline_built = BuildPreprocessPipelineFromConfig(config_file);
  FDASSERT(pipeline_built, "Failed to create PaddleDetPreprocessor.");
  initialized_ = true;
}
|
||||
|
||||
bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig(const std::string& config_file) {
|
||||
processors_.clear();
|
||||
YAML::Node cfg;
|
||||
try {
|
||||
cfg = YAML::LoadFile(config_file);
|
||||
} catch (YAML::BadFile& e) {
|
||||
FDERROR << "Failed to load yaml file " << config_file
|
||||
<< ", maybe you should check this file." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
processors_.push_back(std::make_shared<BGR2RGB>());
|
||||
|
||||
bool has_permute = false;
|
||||
for (const auto& op : cfg["Preprocess"]) {
|
||||
std::string op_name = op["type"].as<std::string>();
|
||||
if (op_name == "NormalizeImage") {
|
||||
auto mean = op["mean"].as<std::vector<float>>();
|
||||
auto std = op["std"].as<std::vector<float>>();
|
||||
bool is_scale = true;
|
||||
if (op["is_scale"]) {
|
||||
is_scale = op["is_scale"].as<bool>();
|
||||
}
|
||||
std::string norm_type = "mean_std";
|
||||
if (op["norm_type"]) {
|
||||
norm_type = op["norm_type"].as<std::string>();
|
||||
}
|
||||
if (norm_type != "mean_std") {
|
||||
std::fill(mean.begin(), mean.end(), 0.0);
|
||||
std::fill(std.begin(), std.end(), 1.0);
|
||||
}
|
||||
processors_.push_back(std::make_shared<Normalize>(mean, std, is_scale));
|
||||
} else if (op_name == "Resize") {
|
||||
bool keep_ratio = op["keep_ratio"].as<bool>();
|
||||
auto target_size = op["target_size"].as<std::vector<int>>();
|
||||
int interp = op["interp"].as<int>();
|
||||
FDASSERT(target_size.size() == 2,
|
||||
"Require size of target_size be 2, but now it's %lu.",
|
||||
target_size.size());
|
||||
if (!keep_ratio) {
|
||||
int width = target_size[1];
|
||||
int height = target_size[0];
|
||||
processors_.push_back(
|
||||
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
|
||||
} else {
|
||||
int min_target_size = std::min(target_size[0], target_size[1]);
|
||||
int max_target_size = std::max(target_size[0], target_size[1]);
|
||||
std::vector<int> max_size;
|
||||
if (max_target_size > 0) {
|
||||
max_size.push_back(max_target_size);
|
||||
max_size.push_back(max_target_size);
|
||||
}
|
||||
processors_.push_back(std::make_shared<ResizeByShort>(
|
||||
min_target_size, interp, true, max_size));
|
||||
}
|
||||
} else if (op_name == "Permute") {
|
||||
// Do nothing, do permute as the last operation
|
||||
has_permute = true;
|
||||
continue;
|
||||
// processors_.push_back(std::make_shared<HWC2CHW>());
|
||||
} else if (op_name == "Pad") {
|
||||
auto size = op["size"].as<std::vector<int>>();
|
||||
auto value = op["fill_value"].as<std::vector<float>>();
|
||||
processors_.push_back(std::make_shared<Cast>("float"));
|
||||
processors_.push_back(
|
||||
std::make_shared<PadToSize>(size[1], size[0], value));
|
||||
} else if (op_name == "PadStride") {
|
||||
auto stride = op["stride"].as<int>();
|
||||
processors_.push_back(
|
||||
std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
|
||||
} else {
|
||||
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (has_permute) {
|
||||
// permute = cast<float> + HWC2CHW
|
||||
processors_.push_back(std::make_shared<Cast>("float"));
|
||||
processors_.push_back(std::make_shared<HWC2CHW>());
|
||||
} else {
|
||||
processors_.push_back(std::make_shared<HWC2CHW>());
|
||||
}
|
||||
|
||||
// Fusion will improve performance
|
||||
FuseTransforms(&processors_);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
|
||||
if (!initialized_) {
|
||||
FDERROR << "The preprocessor is not initialized." << std::endl;
|
||||
return false;
|
||||
}
|
||||
if (images->size() == 0) {
|
||||
FDERROR << "The size of input images should be greater than 0." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// There are 3 outputs, image, scale_factor, im_shape
|
||||
// But im_shape is not used for all the PaddleDetection models
|
||||
// So preprocessor will output the 3 FDTensors, and how to use `im_shape`
|
||||
// is decided by the model itself
|
||||
outputs->resize(3);
|
||||
int batch = static_cast<int>(images->size());
|
||||
// Allocate memory for scale_factor
|
||||
(*outputs)[1].Resize({batch, 2}, FDDataType::FP32);
|
||||
// Allocate memory for im_shape
|
||||
(*outputs)[2].Resize({batch, 2}, FDDataType::FP32);
|
||||
// Record the max size for a batch of input image
|
||||
// All the tensor will pad to the max size to compose a batched tensor
|
||||
std::vector<int> max_hw({-1, -1});
|
||||
|
||||
float* scale_factor_ptr = reinterpret_cast<float*>((*outputs)[1].MutableData());
|
||||
float* im_shape_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
|
||||
for (size_t i = 0; i < images->size(); ++i) {
|
||||
int origin_w = (*images)[i].Width();
|
||||
int origin_h = (*images)[i].Height();
|
||||
scale_factor_ptr[2 * i] = 1.0;
|
||||
scale_factor_ptr[2 * i + 1] = 1.0;
|
||||
for (size_t j = 0; j < processors_.size(); ++j) {
|
||||
if (!(*(processors_[j].get()))(&((*images)[i]))) {
|
||||
FDERROR << "Failed to processs image:" << i << " in " << processors_[i]->Name() << "." << std::endl;
|
||||
return false;
|
||||
}
|
||||
if (processors_[j]->Name().find("Resize") != std::string::npos) {
|
||||
scale_factor_ptr[2 * i] = (*images)[i].Height() * 1.0 / origin_h;
|
||||
scale_factor_ptr[2 * i + 1] = (*images)[i].Width() * 1.0 / origin_w;
|
||||
}
|
||||
}
|
||||
if ((*images)[i].Height() > max_hw[0]) {
|
||||
max_hw[0] = (*images)[i].Height();
|
||||
}
|
||||
if ((*images)[i].Width() > max_hw[1]) {
|
||||
max_hw[1] = (*images)[i].Width();
|
||||
}
|
||||
im_shape_ptr[2 * i] = max_hw[0];
|
||||
im_shape_ptr[2 * i + 1] = max_hw[1];
|
||||
}
|
||||
|
||||
// Concat all the preprocessed data to a batch tensor
|
||||
std::vector<FDTensor> im_tensors(images->size());
|
||||
for (size_t i = 0; i < images->size(); ++i) {
|
||||
if ((*images)[i].Height() < max_hw[0] || (*images)[i].Width() < max_hw[1]) {
|
||||
// if the size of image less than max_hw, pad to max_hw
|
||||
FDTensor tensor;
|
||||
(*images)[i].ShareWithTensor(&tensor);
|
||||
function::Pad(tensor, &(im_tensors[i]), {0, 0, max_hw[0] - (*images)[i].Height(), max_hw[1] - (*images)[i].Width()}, 0);
|
||||
} else {
|
||||
// No need pad
|
||||
(*images)[i].ShareWithTensor(&(im_tensors[i]));
|
||||
}
|
||||
// Reshape to 1xCxHxW
|
||||
im_tensors[i].ExpandDim(0);
|
||||
}
|
||||
|
||||
if (im_tensors.size() == 1) {
|
||||
// If there's only 1 input, no need to concat
|
||||
// skip memory copy
|
||||
(*outputs)[0] = std::move(im_tensors[0]);
|
||||
} else {
|
||||
// Else concat the im tensor for each input image
|
||||
// compose a batched input tensor
|
||||
function::Concat(im_tensors, &((*outputs)[0]), 0);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
50
fastdeploy/vision/detection/ppdet/preprocessor.h
Normal file
50
fastdeploy/vision/detection/ppdet/preprocessor.h
Normal file
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
namespace detection {
|
||||
/*! @brief Preprocessor object for PaddleDet serials model.
|
||||
*/
|
||||
class FASTDEPLOY_DECL PaddleDetPreprocessor {
|
||||
public:
|
||||
PaddleDetPreprocessor() = default;
|
||||
/** \brief Create a preprocessor instance for PaddleDet serials model
|
||||
*
|
||||
* \param[in] config_file Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
|
||||
*/
|
||||
explicit PaddleDetPreprocessor(const std::string& config_file);
|
||||
|
||||
/** \brief Process the input image and prepare input tensors for runtime
|
||||
*
|
||||
* \param[in] images The input image data list, all the elements are returned by cv::imread()
|
||||
* \param[in] outputs The output tensors which will feed in runtime, include image, scale_factor, im_shape
|
||||
* \return true if the preprocess successed, otherwise false
|
||||
*/
|
||||
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
|
||||
|
||||
private:
|
||||
bool BuildPreprocessPipelineFromConfig(const std::string& config_file);
|
||||
std::vector<std::shared_ptr<Processor>> processors_;
|
||||
bool initialized_ = false;
|
||||
};
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -1,84 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/detection/ppdet/rcnn.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
// Stores the configuration path, declares the supported backends, copies the
// caller's runtime option with the model files filled in, and initializes
// the model. The outcome of Initialize() is kept in `initialized`.
FasterRCNN::FasterRCNN(const std::string& model_file,
                       const std::string& params_file,
                       const std::string& config_file,
                       const RuntimeOption& custom_option,
                       const ModelFormat& model_format) {
  config_file_ = config_file;

  // Backends this model may run on.
  valid_cpu_backends = {Backend::PDINFER, Backend::LITE};
  valid_gpu_backends = {Backend::PDINFER};

  has_nms_ = true;

  // Start from the caller's option, then fill in the model files.
  runtime_option = custom_option;
  runtime_option.model_format = model_format;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;

  initialized = Initialize();
}
|
||||
|
||||
bool FasterRCNN::Initialize() {
|
||||
if (!BuildPreprocessPipelineFromConfig()) {
|
||||
FDERROR << "Failed to build preprocess pipeline from configuration file."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
if (!InitRuntime()) {
|
||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Runs every configured processor on `mat` in place and fills three input
// tensors in `outputs`:
//   [0] "im_shape"     {1, 2} FP32: processed height and width;
//   [1] "image"        the processed image, sharing memory with `mat`, with a
//                      leading dimension of 1 inserted into its shape;
//   [2] "scale_factor" {1, 2} FP32: (height ratio, width ratio) captured after
//                      each processor whose name contains "Resize"; stays
//                      (1, 1) if no such processor runs.
// Returns false (after logging) if any processor fails.
bool FasterRCNN::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
  const int origin_w = mat->Width();
  const int origin_h = mat->Height();
  float scale[2] = {1.0, 1.0};

  for (const auto& processor : processors_) {
    if (!(*processor)(mat)) {
      FDERROR << "Failed to process image data in " << processor->Name()
              << "." << std::endl;
      return false;
    }
    const bool is_resize =
        processor->Name().find("Resize") != std::string::npos;
    if (is_resize) {
      scale[0] = mat->Height() * 1.0 / origin_h;
      scale[1] = mat->Width() * 1.0 / origin_w;
    }
  }

  outputs->resize(3);
  FDTensor& im_shape_tensor = (*outputs)[0];
  FDTensor& image_tensor = (*outputs)[1];
  FDTensor& scale_tensor = (*outputs)[2];

  im_shape_tensor.Allocate({1, 2}, FDDataType::FP32, "im_shape");
  scale_tensor.Allocate({1, 2}, FDDataType::FP32, "scale_factor");

  float* im_shape = static_cast<float*>(im_shape_tensor.MutableData());
  im_shape[0] = mat->Height();
  im_shape[1] = mat->Width();

  float* scale_factor = static_cast<float*>(scale_tensor.MutableData());
  scale_factor[0] = scale[0];
  scale_factor[1] = scale[1];

  image_tensor.name = "image";
  mat->ShareWithTensor(&image_tensor);
  // reshape to [1, c, h, w]
  image_tensor.shape.insert(image_tensor.shape.begin(), 1);
  return true;
}
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -1,39 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
/*! @brief PaddleDetection FasterRCNN model, derived from PPYOLOE with its own
 * preprocessing and initialization.
 */
class FASTDEPLOY_DECL FasterRCNN : public PPYOLOE {
 public:
  /** \brief Create the model from its files and initialize it
   *
   * \param[in] model_file Path of the model file
   * \param[in] params_file Path of the parameter file
   * \param[in] config_file Path of the deployment configuration file
   * \param[in] custom_option Runtime option to start from
   * \param[in] model_format Format of the loaded model
   */
  FasterRCNN(const std::string& model_file, const std::string& params_file,
             const std::string& config_file,
             const RuntimeOption& custom_option = RuntimeOption(),
             const ModelFormat& model_format = ModelFormat::PADDLE);

  /// Name reported for this model.
  virtual std::string ModelName() const { return "PaddleDetection/FasterRCNN"; }

  /// Prepares the runtime input tensors from one image.
  virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);

  /// Builds the preprocessing pipeline and starts the runtime.
  virtual bool Initialize();

 protected:
  /// Default construction is reserved for derived classes.
  FasterRCNN() {}
};
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -1,64 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/detection/ppdet/yolov3.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
// Stores the configuration path, declares the supported backends, copies the
// caller's runtime option with the model files filled in, and initializes
// the model. The outcome of Initialize() is kept in `initialized`.
YOLOv3::YOLOv3(const std::string& model_file, const std::string& params_file,
               const std::string& config_file,
               const RuntimeOption& custom_option,
               const ModelFormat& model_format) {
  config_file_ = config_file;

  // Backends this model may run on.
  valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER,
                        Backend::LITE};
  valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};

  // Start from the caller's option, then fill in the model files.
  runtime_option = custom_option;
  runtime_option.model_format = model_format;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;

  initialized = Initialize();
}
|
||||
|
||||
// Runs every configured processor on `mat` in place and fills three input
// tensors in `outputs`:
//   [0] "im_shape"     {1, 2} FP32: processed height and width;
//   [1] "image"        the processed image, sharing memory with `mat`, with a
//                      leading dimension of 1 inserted into its shape;
//   [2] "scale_factor" {1, 2} FP32: (processed height / original height,
//                      processed width / original width), measured after the
//                      whole pipeline has run.
// Returns false (after logging) if any processor fails.
bool YOLOv3::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
  const int origin_w = mat->Width();
  const int origin_h = mat->Height();

  for (const auto& processor : processors_) {
    if (!(*processor)(mat)) {
      FDERROR << "Failed to process image data in " << processor->Name()
              << "." << std::endl;
      return false;
    }
  }

  outputs->resize(3);
  FDTensor& im_shape_tensor = (*outputs)[0];
  FDTensor& image_tensor = (*outputs)[1];
  FDTensor& scale_tensor = (*outputs)[2];

  im_shape_tensor.Allocate({1, 2}, FDDataType::FP32, "im_shape");
  scale_tensor.Allocate({1, 2}, FDDataType::FP32, "scale_factor");

  float* im_shape = static_cast<float*>(im_shape_tensor.MutableData());
  im_shape[0] = mat->Height();
  im_shape[1] = mat->Width();

  float* scale_factor = static_cast<float*>(scale_tensor.MutableData());
  scale_factor[0] = mat->Height() * 1.0 / origin_h;
  scale_factor[1] = mat->Width() * 1.0 / origin_w;

  image_tensor.name = "image";
  mat->ShareWithTensor(&image_tensor);
  // reshape to [1, c, h, w]
  image_tensor.shape.insert(image_tensor.shape.begin(), 1);
  return true;
}
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -1,35 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
/*! @brief PaddleDetection YOLOv3 model, derived from PPYOLOE with its own
 * preprocessing.
 */
class FASTDEPLOY_DECL YOLOv3 : public PPYOLOE {
 public:
  /** \brief Create the model from its files and initialize it
   *
   * \param[in] model_file Path of the model file
   * \param[in] params_file Path of the parameter file
   * \param[in] config_file Path of the deployment configuration file
   * \param[in] custom_option Runtime option to start from
   * \param[in] model_format Format of the loaded model
   */
  YOLOv3(const std::string& model_file, const std::string& params_file,
         const std::string& config_file,
         const RuntimeOption& custom_option = RuntimeOption(),
         const ModelFormat& model_format = ModelFormat::PADDLE);

  /// Name reported for this model.
  virtual std::string ModelName() const { return "PaddleDetection/YOLOv3"; }

  /// Prepares the runtime input tensors from one image.
  virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
};
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -1,74 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/detection/ppdet/yolox.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
// Stores the configuration path, declares the supported backends, copies the
// caller's runtime option with the model files filled in, sets this model's
// NMS parameters, and initializes the model. The outcome of Initialize() is
// kept in `initialized`.
PaddleYOLOX::PaddleYOLOX(const std::string& model_file,
                         const std::string& params_file,
                         const std::string& config_file,
                         const RuntimeOption& custom_option,
                         const ModelFormat& model_format) {
  config_file_ = config_file;

  // Backends this model may run on.
  valid_cpu_backends = {Backend::ORT, Backend::PDINFER, Backend::LITE};
  valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};

  // Start from the caller's option, then fill in the model files.
  runtime_option = custom_option;
  runtime_option.model_format = model_format;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;

  // NMS parameters used by this model.
  background_label = -1;
  keep_top_k = 1000;
  nms_eta = 1;
  nms_threshold = 0.65;
  nms_top_k = 10000;
  normalized = true;
  score_threshold = 0.001;

  initialized = Initialize();
}
|
||||
|
||||
// Runs every configured processor on `mat` in place and fills two input
// tensors in `outputs`:
//   [0] the processed image, sharing memory with `mat`, with a leading
//       dimension of 1 inserted into its shape;
//   [1] a {1, 2} FP32 tensor holding (height ratio, width ratio) captured
//       after each processor whose name contains "Resize"; stays (1, 1) if no
//       such processor runs.
// Tensor names are taken from the runtime's input info at the same index.
// Returns false (after logging) if any processor fails.
bool PaddleYOLOX::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
  const int origin_w = mat->Width();
  const int origin_h = mat->Height();
  float scale[2] = {1.0, 1.0};

  for (const auto& processor : processors_) {
    if (!(*processor)(mat)) {
      FDERROR << "Failed to process image data in " << processor->Name()
              << "." << std::endl;
      return false;
    }
    const bool is_resize =
        processor->Name().find("Resize") != std::string::npos;
    if (is_resize) {
      scale[0] = mat->Height() * 1.0 / origin_h;
      scale[1] = mat->Width() * 1.0 / origin_w;
    }
  }

  outputs->resize(2);
  FDTensor& image_tensor = (*outputs)[0];
  FDTensor& ratio_tensor = (*outputs)[1];

  image_tensor.name = InputInfoOfRuntime(0).name;
  mat->ShareWithTensor(&image_tensor);
  // reshape to [1, c, h, w]
  image_tensor.shape.insert(image_tensor.shape.begin(), 1);

  ratio_tensor.Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(1).name);
  float* ratio = static_cast<float*>(ratio_tensor.MutableData());
  ratio[0] = scale[0];
  ratio[1] = scale[1];
  return true;
}
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -1,35 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
/*! @brief PaddleDetection YOLOX model, derived from PPYOLOE with its own
 * preprocessing.
 */
class FASTDEPLOY_DECL PaddleYOLOX : public PPYOLOE {
 public:
  /** \brief Create the model from its files and initialize it
   *
   * \param[in] model_file Path of the model file
   * \param[in] params_file Path of the parameter file
   * \param[in] config_file Path of the deployment configuration file
   * \param[in] custom_option Runtime option to start from
   * \param[in] model_format Format of the loaded model
   */
  PaddleYOLOX(const std::string& model_file, const std::string& params_file,
              const std::string& config_file,
              const RuntimeOption& custom_option = RuntimeOption(),
              const ModelFormat& model_format = ModelFormat::PADDLE);

  /// Prepares the runtime input tensors from one image.
  virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);

  /// Name reported for this model.
  virtual std::string ModelName() const { return "PaddleDetection/YOLOX"; }
};
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -23,5 +23,4 @@ from .contrib.yolov5lite import YOLOv5Lite
|
||||
from .contrib.yolov6 import YOLOv6
|
||||
from .contrib.yolov7end2end_trt import YOLOv7End2EndTRT
|
||||
from .contrib.yolov7end2end_ort import YOLOv7End2EndORT
|
||||
from .ppdet import PPYOLOE, PPYOLO, PPYOLOv2, PaddleYOLOX, PicoDet, FasterRCNN, YOLOv3, MaskRCNN
|
||||
from .rknpu2 import RKPicoDet
|
||||
from .ppdet import *
|
||||
|
@@ -19,6 +19,40 @@ from .... import FastDeployModel, ModelFormat
|
||||
from .... import c_lib_wrap as C
|
||||
|
||||
|
||||
class PaddleDetPreprocessor:
    """Python wrapper around the C++ PaddleDetPreprocessor."""

    def __init__(self, config_file):
        """Create a preprocessor for a PaddleDetection model from a configuration file

        :param config_file: (str)Path of configuration file, e.g ppyoloe/infer_cfg.yml
        """
        native_cls = C.vision.detection.PaddleDetPreprocessor
        self._preprocessor = native_cls(config_file)

    def run(self, input_ims):
        """Preprocess input images for a PaddleDetection model

        :param input_ims: (list of numpy.ndarray)The input images
        :return: list of FDTensor, including image, scale_factor, im_shape
        """
        tensors = self._preprocessor.run(input_ims)
        return tensors
|
||||
|
||||
|
||||
class PaddleDetPostprocessor:
    def __init__(self):
        """Build the postprocessor of a PaddleDetection model
        """
        detection_lib = C.vision.detection
        self._postprocessor = detection_lib.PaddleDetPostprocessor()

    def run(self, runtime_results):
        """Run postprocessing on the runtime outputs of a PaddleDetection model

        :param: runtime_results: (list of FDTensor)The output FDTensor results from runtime
        :return: list of DetectionResult(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
                 NOTE(review): result type inferred from the detection module; confirm against the C++ binding
        """
        # Pure delegation to the wrapped C++ postprocessor object.
        results = self._postprocessor.run(runtime_results)
        return results
|
||||
|
||||
|
||||
class PPYOLOE(FastDeployModel):
|
||||
def __init__(self,
|
||||
model_file,
|
||||
@@ -52,6 +86,31 @@ class PPYOLOE(FastDeployModel):
|
||||
assert im is not None, "The input image data is None."
|
||||
return self._model.predict(im)
|
||||
|
||||
def batch_predict(self, images):
    """Run detection on a list of input images in one call

    :param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
    :return list of DetectionResult
    """
    # Delegates directly to the wrapped model object.
    detections = self._model.batch_predict(images)
    return detections
|
||||
|
||||
@property
def preprocessor(self):
    """Expose the PaddleDetPreprocessor object of the loaded model

    :return PaddleDetPreprocessor
    """
    wrapped_model = self._model
    return wrapped_model.preprocessor
|
||||
|
||||
@property
def postprocessor(self):
    """Expose the PaddleDetPostprocessor object of the loaded model

    :return PaddleDetPostprocessor
    """
    wrapped_model = self._model
    return wrapped_model.postprocessor
|
||||
|
||||
|
||||
class PPYOLO(PPYOLOE):
|
||||
def __init__(self,
|
||||
@@ -77,31 +136,6 @@ class PPYOLO(PPYOLOE):
|
||||
assert self.initialized, "PPYOLO model initialize failed."
|
||||
|
||||
|
||||
class PPYOLOv2(PPYOLOE):
    def __init__(self,
                 model_file,
                 params_file,
                 config_file,
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
        """Load a PPYOLOv2 model exported by PaddleDetection.

        :param model_file: (str)Path of model file, e.g ppyolov2/model.pdmodel
        :param params_file: (str)Path of parameters file, e.g ppyolov2/model.pdiparams, if the model_format is ModelFormat.ONNX, this param will be ignored, can be set as empty string
        :param config_file: (str)Path of configuration file for deployment, e.g ppyolov2/infer_cfg.yml
        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
        """
        # Starts the MRO lookup after PPYOLOE, so PPYOLOE.__init__ itself is
        # not executed here.
        super(PPYOLOE, self).__init__(runtime_option)

        is_paddle_format = model_format == ModelFormat.PADDLE
        assert is_paddle_format, "PPYOLOv2 model only support model format of ModelFormat.Paddle now."

        model_cls = C.vision.detection.PPYOLOv2
        self._model = model_cls(model_file, params_file, config_file,
                                self._runtime_option, model_format)
        assert self.initialized, "PPYOLOv2 model initialize failed."
|
||||
|
||||
|
||||
class PaddleYOLOX(PPYOLOE):
|
||||
def __init__(self,
|
||||
model_file,
|
||||
@@ -202,7 +236,7 @@ class YOLOv3(PPYOLOE):
|
||||
assert self.initialized, "YOLOv3 model initialize failed."
|
||||
|
||||
|
||||
class MaskRCNN(FastDeployModel):
|
||||
class MaskRCNN(PPYOLOE):
|
||||
def __init__(self,
|
||||
model_file,
|
||||
params_file,
|
||||
@@ -211,14 +245,14 @@ class MaskRCNN(FastDeployModel):
|
||||
model_format=ModelFormat.PADDLE):
|
||||
"""Load a MaskRCNN model exported by PaddleDetection.
|
||||
|
||||
:param model_file: (str)Path of model file, e.g maskrcnn/model.pdmodel
|
||||
:param params_file: (str)Path of parameters file, e.g maskrcnn/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
|
||||
:param model_file: (str)Path of model file, e.g fasterrcnn/model.pdmodel
|
||||
:param params_file: (str)Path of parameters file, e.g fasterrcnn/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
|
||||
:param config_file: (str)Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
|
||||
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
|
||||
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model
|
||||
"""
|
||||
|
||||
super(MaskRCNN, self).__init__(runtime_option)
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
assert model_format == ModelFormat.PADDLE, "MaskRCNN model only support model format of ModelFormat.Paddle now."
|
||||
self._model = C.vision.detection.MaskRCNN(
|
||||
@@ -226,6 +260,12 @@ class MaskRCNN(FastDeployModel):
|
||||
model_format)
|
||||
assert self.initialized, "MaskRCNN model initialize failed."
|
||||
|
||||
def predict(self, input_image):
    """Run the wrapped model on a single image

    :param input_image: The input image data, must not be None
    :return whatever the wrapped model's predict returns for this image
    """
    has_image = input_image is not None
    assert has_image, "The input image data is None."
    prediction = self._model.predict(input_image)
    return prediction
|
||||
def batch_predict(self, images):
    """Batch detection entry point, batch_predict is not supported for maskrcnn now.

    :param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
    :return never returns; always raises Exception
    """
    message = "batch_predict is not supported for MaskRCNN model now."
    raise Exception(message)
|
||||
|
70
tests/models/test_faster_rcnn.py
Executable file
70
tests/models/test_faster_rcnn.py
Executable file
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fastdeploy as fd
|
||||
print(fd.__path__)
|
||||
import cv2
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
import runtime_config as rc
|
||||
|
||||
|
||||
def test_detection_faster_rcnn():
    """Compare FasterRCNN predictions with a pickled baseline.

    Downloads the model archive, a demo image and the baseline pickle into
    "resources", selects the Paddle backend on the shared test option, then
    predicts twice and checks boxes, scores and label ids against the baseline.
    """
    fd.download_and_decompress(
        "https://bj.bcebos.com/paddlehub/fastdeploy/faster_rcnn_r50_vd_fpn_2x_coco.tgz",
        "resources")
    fd.download(
        "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg",
        "resources")
    fd.download(
        "https://bj.bcebos.com/fastdeploy/tests/data/faster_rcnn_baseline.pkl",
        "resources")

    model_dir = "resources/faster_rcnn_r50_vd_fpn_2x_coco"
    model_file = os.path.join(model_dir, "model.pdmodel")
    params_file = os.path.join(model_dir, "model.pdiparams")
    config_file = os.path.join(model_dir, "infer_cfg.yml")

    rc.test_option.use_paddle_backend()
    model = fd.vision.detection.FasterRCNN(
        model_file, params_file, config_file, runtime_option=rc.test_option)

    # compare diff
    image = cv2.imread("./resources/000000014439.jpg")
    for _ in range(2):
        result = model.predict(image)
        # NOTE(review): pickle.load on a downloaded file is only safe while
        # the baseline host is trusted.
        with open("resources/faster_rcnn_baseline.pkl", "rb") as baseline:
            boxes, scores, label_ids = pickle.load(baseline)

        diff_boxes = np.fabs(boxes - np.array(result.boxes))
        diff_scores = np.fabs(scores - np.array(result.scores))
        diff_label_ids = np.fabs(label_ids - np.array(result.label_ids))

        print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())

        # Only entries whose baseline score exceeds the threshold are compared.
        keep = scores > 0.0
        assert diff_boxes[keep].max() < 1e-04, "There's diff in boxes."
        assert diff_scores[keep].max() < 1e-04, "There's diff in scores."
        assert diff_label_ids[keep].max() < 1e-04, "There's diff in label_ids."

    # Kept for reference: snippet that writes a new baseline pickle.
    # result = model.predict(image)
    # with open("faster_rcnn_baseline.pkl", "wb") as f:
    #     pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)


if __name__ == "__main__":
    test_detection_faster_rcnn()
|
68
tests/models/test_mask_rcnn.py
Executable file
68
tests/models/test_mask_rcnn.py
Executable file
@@ -0,0 +1,68 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fastdeploy as fd
|
||||
import cv2
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
import runtime_config as rc
|
||||
|
||||
|
||||
def test_detection_mask_rcnn():
    """Compare MaskRCNN predictions with a pickled baseline.

    Downloads the model archive, a demo image and the baseline pickle into
    "resources", builds the model with the shared test option as-is (no
    backend is selected here), then predicts twice and checks boxes, scores
    and label ids against the baseline.
    """
    fd.download_and_decompress(
        "https://bj.bcebos.com/paddlehub/fastdeploy/mask_rcnn_r50_1x_coco.tgz",
        "resources")
    fd.download(
        "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg",
        "resources")
    fd.download(
        "https://bj.bcebos.com/fastdeploy/tests/data/mask_rcnn_baseline.pkl",
        "resources")

    model_dir = "resources/mask_rcnn_r50_1x_coco"
    model_file = os.path.join(model_dir, "model.pdmodel")
    params_file = os.path.join(model_dir, "model.pdiparams")
    config_file = os.path.join(model_dir, "infer_cfg.yml")

    model = fd.vision.detection.MaskRCNN(
        model_file, params_file, config_file, runtime_option=rc.test_option)

    # compare diff
    image = cv2.imread("./resources/000000014439.jpg")
    for _ in range(2):
        result = model.predict(image)
        # NOTE(review): pickle.load on a downloaded file is only safe while
        # the baseline host is trusted.
        with open("resources/mask_rcnn_baseline.pkl", "rb") as baseline:
            boxes, scores, label_ids = pickle.load(baseline)

        diff_boxes = np.fabs(boxes - np.array(result.boxes))
        diff_scores = np.fabs(scores - np.array(result.scores))
        diff_label_ids = np.fabs(label_ids - np.array(result.label_ids))

        print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())

        # Only entries whose baseline score exceeds the threshold are compared.
        keep = scores > 0.0
        assert diff_boxes[keep].max() < 1e-04, "There's diff in boxes."
        assert diff_scores[keep].max() < 1e-04, "There's diff in scores."
        assert diff_label_ids[keep].max() < 1e-04, "There's diff in label_ids."

    # Kept for reference: snippet that writes a new baseline pickle.
    # result = model.predict(image)
    # with open("mask_rcnn_baseline.pkl", "wb") as f:
    #     pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)


if __name__ == "__main__":
    test_detection_mask_rcnn()
|
69
tests/models/test_picodet.py
Executable file
69
tests/models/test_picodet.py
Executable file
@@ -0,0 +1,69 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fastdeploy as fd
|
||||
import cv2
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
import runtime_config as rc
|
||||
|
||||
|
||||
def test_detection_picodet():
    """Compare PicoDet predictions with a pickled baseline.

    Downloads the model archive, a demo image and the baseline pickle into
    "resources", selects the Paddle backend on the shared test option, then
    predicts twice and checks boxes (tolerance 1e-02), scores and label ids
    (tolerance 1e-04) against the baseline.
    """
    fd.download_and_decompress(
        "https://bj.bcebos.com/paddlehub/fastdeploy/picodet_l_320_coco_lcnet.tgz",
        "resources")
    fd.download(
        "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg",
        "resources")
    fd.download(
        "https://bj.bcebos.com/fastdeploy/tests/data/picodet_baseline.pkl",
        "resources")

    model_dir = "resources/picodet_l_320_coco_lcnet"
    model_file = os.path.join(model_dir, "model.pdmodel")
    params_file = os.path.join(model_dir, "model.pdiparams")
    config_file = os.path.join(model_dir, "infer_cfg.yml")

    rc.test_option.use_paddle_backend()
    model = fd.vision.detection.PicoDet(
        model_file, params_file, config_file, runtime_option=rc.test_option)

    # compare diff
    image = cv2.imread("./resources/000000014439.jpg")
    for _ in range(2):
        result = model.predict(image)
        # NOTE(review): pickle.load on a downloaded file is only safe while
        # the baseline host is trusted.
        with open("resources/picodet_baseline.pkl", "rb") as baseline:
            boxes, scores, label_ids = pickle.load(baseline)

        diff_boxes = np.fabs(boxes - np.array(result.boxes))
        diff_scores = np.fabs(scores - np.array(result.scores))
        diff_label_ids = np.fabs(label_ids - np.array(result.label_ids))

        print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())

        # Only entries whose baseline score exceeds the threshold are compared.
        keep = scores > 0.0
        assert diff_boxes[keep].max() < 1e-02, "There's diff in boxes."
        assert diff_scores[keep].max() < 1e-04, "There's diff in scores."
        assert diff_label_ids[keep].max() < 1e-04, "There's diff in label_ids."

    # Kept for reference: snippet that writes a new baseline pickle.
    # result = model.predict(image)
    # with open("picodet_baseline.pkl", "wb") as f:
    #     pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)


if __name__ == "__main__":
    test_detection_picodet()
|
70
tests/models/test_pp_yolox.py
Executable file
70
tests/models/test_pp_yolox.py
Executable file
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fastdeploy as fd
|
||||
print(fd.__path__)
|
||||
import cv2
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
import runtime_config as rc
|
||||
|
||||
|
||||
def test_detection_yolox():
    """Compare PaddleYOLOX predictions with a pickled baseline.

    Downloads the model archive, a demo image and the baseline pickle into
    "resources", selects the ORT backend on the shared test option, then
    predicts twice and checks boxes, scores and label ids against the
    baseline for entries whose baseline score is above 0.1.
    """
    fd.download_and_decompress(
        "https://bj.bcebos.com/paddlehub/fastdeploy/yolox_s_300e_coco.tgz",
        "resources")
    fd.download(
        "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg",
        "resources")
    fd.download(
        "https://bj.bcebos.com/fastdeploy/tests/data/ppyolox_baseline.pkl",
        "resources")

    model_dir = "resources/yolox_s_300e_coco"
    model_file = os.path.join(model_dir, "model.pdmodel")
    params_file = os.path.join(model_dir, "model.pdiparams")
    config_file = os.path.join(model_dir, "infer_cfg.yml")

    rc.test_option.use_ort_backend()
    model = fd.vision.detection.PaddleYOLOX(
        model_file, params_file, config_file, runtime_option=rc.test_option)

    # compare diff
    image = cv2.imread("./resources/000000014439.jpg")
    for _ in range(2):
        result = model.predict(image)
        # NOTE(review): pickle.load on a downloaded file is only safe while
        # the baseline host is trusted.
        with open("resources/ppyolox_baseline.pkl", "rb") as baseline:
            boxes, scores, label_ids = pickle.load(baseline)

        diff_boxes = np.fabs(boxes - np.array(result.boxes))
        diff_scores = np.fabs(scores - np.array(result.scores))
        diff_label_ids = np.fabs(label_ids - np.array(result.label_ids))

        print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())

        # Only entries whose baseline score exceeds the threshold are compared.
        keep = scores > 0.1
        assert diff_boxes[keep].max() < 1e-04, "There's diff in boxes."
        assert diff_scores[keep].max() < 1e-04, "There's diff in scores."
        assert diff_label_ids[keep].max() < 1e-04, "There's diff in label_ids."

    # Kept for reference: snippet that writes a new baseline pickle.
    # result = model.predict(image)
    # with open("ppyolox_baseline.pkl", "wb") as f:
    #     pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)


if __name__ == "__main__":
    test_detection_yolox()
|
69
tests/models/test_ppyolo.py
Executable file
69
tests/models/test_ppyolo.py
Executable file
@@ -0,0 +1,69 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fastdeploy as fd
|
||||
import cv2
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
import runtime_config as rc
|
||||
|
||||
|
||||
def test_detection_ppyolo():
    """Compare PPYOLO predictions with a pickled baseline.

    Downloads the model archive (a ppyolov2 export), a demo image and the
    baseline pickle into "resources", selects the Paddle backend on the
    shared test option, then predicts twice and checks boxes, scores and
    label ids against the baseline.
    """
    fd.download_and_decompress(
        "https://bj.bcebos.com/paddlehub/fastdeploy/ppyolov2_r101vd_dcn_365e_coco.tgz",
        "resources")
    fd.download(
        "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg",
        "resources")
    fd.download(
        "https://bj.bcebos.com/fastdeploy/tests/data/ppyolo_baseline.pkl",
        "resources")

    model_dir = "resources/ppyolov2_r101vd_dcn_365e_coco"
    model_file = os.path.join(model_dir, "model.pdmodel")
    params_file = os.path.join(model_dir, "model.pdiparams")
    config_file = os.path.join(model_dir, "infer_cfg.yml")

    rc.test_option.use_paddle_backend()
    model = fd.vision.detection.PPYOLO(
        model_file, params_file, config_file, runtime_option=rc.test_option)

    # compare diff
    image = cv2.imread("./resources/000000014439.jpg")
    for _ in range(2):
        result = model.predict(image)
        # NOTE(review): pickle.load on a downloaded file is only safe while
        # the baseline host is trusted.
        with open("resources/ppyolo_baseline.pkl", "rb") as baseline:
            boxes, scores, label_ids = pickle.load(baseline)

        diff_boxes = np.fabs(boxes - np.array(result.boxes))
        diff_scores = np.fabs(scores - np.array(result.scores))
        diff_label_ids = np.fabs(label_ids - np.array(result.label_ids))

        print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())

        # Only entries whose baseline score exceeds the threshold are compared.
        keep = scores > 0.0
        assert diff_boxes[keep].max() < 1e-04, "There's diff in boxes."
        assert diff_scores[keep].max() < 1e-04, "There's diff in scores."
        assert diff_label_ids[keep].max() < 1e-04, "There's diff in label_ids."

    # Kept for reference: snippet that writes a new baseline pickle.
    # result = model.predict(image)
    # with open("ppyolo_baseline.pkl", "wb") as f:
    #     pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)


if __name__ == "__main__":
    test_detection_ppyolo()
|
68
tests/models/test_ppyoloe.py
Executable file
68
tests/models/test_ppyoloe.py
Executable file
@@ -0,0 +1,68 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fastdeploy as fd
|
||||
import cv2
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
import runtime_config as rc
|
||||
|
||||
|
||||
def test_detection_ppyoloe():
    """Compare PPYOLOE predictions with a pickled baseline.

    Downloads the model archive, a demo image and the baseline pickle into
    "resources", selects the ORT backend on the shared test option, then
    predicts twice and checks boxes, scores and label ids against the baseline.
    """
    fd.download_and_decompress(
        "https://bj.bcebos.com/paddlehub/fastdeploy/ppyoloe_crn_l_300e_coco.tgz",
        "resources")
    fd.download(
        "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg",
        "resources")
    fd.download(
        "https://bj.bcebos.com/fastdeploy/tests/data/ppyoloe_baseline.pkl",
        "resources")

    model_dir = "resources/ppyoloe_crn_l_300e_coco"
    model_file = os.path.join(model_dir, "model.pdmodel")
    params_file = os.path.join(model_dir, "model.pdiparams")
    config_file = os.path.join(model_dir, "infer_cfg.yml")

    rc.test_option.use_ort_backend()
    model = fd.vision.detection.PPYOLOE(
        model_file, params_file, config_file, runtime_option=rc.test_option)

    # compare diff
    image = cv2.imread("./resources/000000014439.jpg")
    for _ in range(2):
        result = model.predict(image)
        # NOTE(review): pickle.load on a downloaded file is only safe while
        # the baseline host is trusted.
        with open("resources/ppyoloe_baseline.pkl", "rb") as baseline:
            boxes, scores, label_ids = pickle.load(baseline)

        diff_boxes = np.fabs(boxes - np.array(result.boxes))
        diff_scores = np.fabs(scores - np.array(result.scores))
        diff_label_ids = np.fabs(label_ids - np.array(result.label_ids))

        print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())

        # Only entries whose baseline score exceeds the threshold are compared.
        keep = scores > 0.0
        assert diff_boxes[keep].max() < 1e-04, "There's diff in boxes."
        assert diff_scores[keep].max() < 1e-04, "There's diff in scores."
        assert diff_label_ids[keep].max() < 1e-04, "There's diff in label_ids."

    # Kept for reference: snippet that writes a new baseline pickle.
    # with open("ppyoloe_baseline.pkl", "wb") as f:
    #     pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)


if __name__ == "__main__":
    test_detection_ppyoloe()
|
69
tests/models/test_yolov3.py
Executable file
69
tests/models/test_yolov3.py
Executable file
@@ -0,0 +1,69 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fastdeploy as fd
|
||||
import cv2
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
import runtime_config as rc
|
||||
|
||||
|
||||
def test_detection_yolov3():
    """Compare YOLOv3 predictions with a pickled baseline.

    Downloads the model archive, a demo image and the baseline pickle into
    "resources", selects the ORT backend on the shared test option, then
    predicts twice and checks boxes, scores and label ids against the
    baseline for entries whose baseline score is above 0.1.
    """
    fd.download_and_decompress(
        "https://bj.bcebos.com/paddlehub/fastdeploy/yolov3_darknet53_270e_coco.tgz",
        "resources")
    fd.download(
        "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg",
        "resources")
    fd.download(
        "https://bj.bcebos.com/fastdeploy/tests/data/yolov3_baseline.pkl",
        "resources")

    model_dir = "resources/yolov3_darknet53_270e_coco"
    model_file = os.path.join(model_dir, "model.pdmodel")
    params_file = os.path.join(model_dir, "model.pdiparams")
    config_file = os.path.join(model_dir, "infer_cfg.yml")

    rc.test_option.use_ort_backend()
    model = fd.vision.detection.YOLOv3(
        model_file, params_file, config_file, runtime_option=rc.test_option)

    # compare diff
    image = cv2.imread("./resources/000000014439.jpg")
    for _ in range(2):
        result = model.predict(image)
        # NOTE(review): pickle.load on a downloaded file is only safe while
        # the baseline host is trusted.
        with open("resources/yolov3_baseline.pkl", "rb") as baseline:
            boxes, scores, label_ids = pickle.load(baseline)

        diff_boxes = np.fabs(boxes - np.array(result.boxes))
        diff_scores = np.fabs(scores - np.array(result.scores))
        diff_label_ids = np.fabs(label_ids - np.array(result.label_ids))

        print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())

        # Only entries whose baseline score exceeds the threshold are compared.
        keep = scores > 0.1
        assert diff_boxes[keep].max() < 1e-04, "There's diff in boxes."
        assert diff_scores[keep].max() < 1e-04, "There's diff in scores."
        assert diff_label_ids[keep].max() < 1e-04, "There's diff in label_ids."

    # Kept for reference: snippet that writes a new baseline pickle.
    # result = model.predict(image)
    # with open("yolov3_baseline.pkl", "wb") as f:
    #     pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)


if __name__ == "__main__":
    test_detection_yolov3()
|
Reference in New Issue
Block a user