[CVCUDA] PP-OCR detector preprocessor integrate CV-CUDA (#1382)

* move manager initialized_ flag to ppcls

* update dbdetector preprocess api

* declare processor op

* ppocr detector preprocessor support cvcuda

* move cvcuda op to class member

* ppcls use manager register api

* refactor det preprocessor init api

* add set preprocessor api

* add create processor macro

* new processor call api

* ppcls preprocessor init resize on cpu

* ppocr detector preprocessor set normalize api

* revert ppcls pybind

* remove dbdetector set preprocessor

* refine dbdetector preprocessor includes

* remove mean std in py constructor

* add comments

* update comment

* Update __init__.py
Author: Wang Xinyu
Date: 2023-02-22 19:39:11 +08:00
Committed by: GitHub
Parent: 2f8d9c9a57
Commit: 91a1c72f98
24 changed files with 448 additions and 330 deletions
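The end state of the diff below: DBDetectorPreprocessor now derives from ProcessorManager, so the CV-CUDA path is switched on through the manager rather than through per-call ProcLib arguments. A minimal usage sketch, assuming the umbrella header fastdeploy/vision.h, a GetPreprocessor() accessor on DBDetector (it is exposed in the pybind hunk further down), and illustrative model and image paths:

    #include <array>
    #include <vector>
    #include "fastdeploy/vision.h"  // assumed umbrella header

    int main() {
      fastdeploy::RuntimeOption option;
      option.UseGpu(0);
      fastdeploy::vision::ocr::DBDetector detector(
          "ch_PP-OCRv3_det_infer/inference.pdmodel",
          "ch_PP-OCRv3_det_infer/inference.pdiparams", option);

      // ProcessorManager::UseCuda(enable_cv_cuda, gpu_id): with enable_cv_cuda
      // set, proc_lib_ becomes ProcLib::CVCUDA (see the manager.cc hunk below).
      detector.GetPreprocessor().UseCuda(/*enable_cv_cuda=*/true, /*gpu_id=*/0);

      cv::Mat img = cv::imread("12.jpg");
      std::vector<std::array<int, 8>> boxes;
      detector.Predict(img, &boxes);  // preprocessing now runs through CV-CUDA ops
      return 0;
    }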

View File

@@ -14,7 +14,6 @@
 #include "fastdeploy/vision/classification/ppcls/preprocessor.h"
-#include "fastdeploy/function/concat.h"
 #include "yaml-cpp/yaml.h"
 namespace fastdeploy {
@@ -102,13 +101,17 @@ void PaddleClasPreprocessor::DisablePermute() {
 bool PaddleClasPreprocessor::Apply(FDMatBatch* image_batch,
                                    std::vector<FDTensor>* outputs) {
+  if (!initialized_) {
+    FDERROR << "The preprocessor is not initialized." << std::endl;
+    return false;
+  }
   for (size_t j = 0; j < processors_.size(); ++j) {
-    ProcLib lib = ProcLib::DEFAULT;
+    image_batch->proc_lib = proc_lib_;
     if (initial_resize_on_cpu_ && j == 0 &&
         processors_[j]->Name().find("Resize") == 0) {
-      lib = ProcLib::OPENCV;
+      image_batch->proc_lib = ProcLib::OPENCV;
     }
-    if (!(*(processors_[j].get()))(image_batch, lib)) {
+    if (!(*(processors_[j].get()))(image_batch)) {
       FDERROR << "Failed to processs image in " << processors_[j]->Name() << "."
               << std::endl;
       return false;

View File

@@ -55,6 +55,7 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {
  private:
   bool BuildPreprocessPipelineFromConfig();
+  bool initialized_ = false;
   std::vector<std::shared_ptr<Processor>> processors_;
   // for recording the switch of hwc2chw
   bool disable_permute_ = false;

View File

@@ -20,9 +20,9 @@
 namespace fastdeploy {
 namespace vision {
-bool Processor::operator()(FDMat* mat, ProcLib lib) {
-  ProcLib target = lib;
-  if (lib == ProcLib::DEFAULT) {
+bool Processor::operator()(FDMat* mat) {
+  ProcLib target = mat->proc_lib;
+  if (mat->proc_lib == ProcLib::DEFAULT) {
     target = DefaultProcLib::default_lib;
   }
   if (target == ProcLib::FLYCV) {
@@ -52,9 +52,14 @@ bool Processor::operator()(FDMat* mat, ProcLib lib) {
   return ImplByOpenCV(mat);
 }
-bool Processor::operator()(FDMatBatch* mat_batch, ProcLib lib) {
-  ProcLib target = lib;
-  if (lib == ProcLib::DEFAULT) {
+bool Processor::operator()(FDMat* mat, ProcLib lib) {
+  mat->proc_lib = lib;
+  return operator()(mat);
+}
+
+bool Processor::operator()(FDMatBatch* mat_batch) {
+  ProcLib target = mat_batch->proc_lib;
+  if (mat_batch->proc_lib == ProcLib::DEFAULT) {
     target = DefaultProcLib::default_lib;
   }
   if (target == ProcLib::FLYCV) {

View File

@@ -100,10 +100,13 @@ class FASTDEPLOY_DECL Processor {
     return true;
   }
-  virtual bool operator()(FDMat* mat, ProcLib lib = ProcLib::DEFAULT);
-  virtual bool operator()(FDMatBatch* mat_batch,
-                          ProcLib lib = ProcLib::DEFAULT);
+  virtual bool operator()(FDMat* mat);
+  // This function is for backward compatibility, will be removed in the near
+  // future, please use operator()(FDMat* mat) instead and set proc_lib in mat.
+  virtual bool operator()(FDMat* mat, ProcLib lib);
+  virtual bool operator()(FDMatBatch* mat_batch);
 };
 }  // namespace vision
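The two hunks above change the Processor calling convention: the target ProcLib now travels on the FDMat/FDMatBatch through the new proc_lib field, and the old (mat, lib) overload is kept only for backward compatibility. A rough sketch of both styles, assuming the Resize(width, height) constructor and the WrapMat() helper:

    namespace fdv = fastdeploy::vision;

    cv::Mat cv_image = cv::imread("test.jpg");
    fdv::FDMat mat = fdv::WrapMat(cv_image);
    fdv::Resize resize(320, 320);

    // New style: pick the backend on the data, then invoke the processor.
    mat.proc_lib = fdv::ProcLib::CVCUDA;  // DEFAULT falls back to DefaultProcLib
    resize(&mat);

    // Old style, still available but slated for removal: it simply writes
    // mat.proc_lib and forwards to operator()(FDMat*).
    resize(&mat, fdv::ProcLib::OPENCV);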

View File

@@ -14,12 +14,6 @@
 #include "fastdeploy/vision/common/processors/center_crop.h"
-#ifdef ENABLE_CVCUDA
-#include <cvcuda/OpCustomCrop.hpp>
-#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
-#endif
 namespace fastdeploy {
 namespace vision {
@@ -75,9 +69,8 @@ bool CenterCrop::ImplByCvCuda(FDMat* mat) {
   int offset_x = static_cast<int>((mat->Width() - width_) / 2);
   int offset_y = static_cast<int>((mat->Height() - height_) / 2);
-  cvcuda::CustomCrop crop_op;
   NVCVRectI crop_roi = {offset_x, offset_y, width_, height_};
-  crop_op(mat->Stream(), src_tensor, dst_tensor, crop_roi);
+  cvcuda_crop_op_(mat->Stream(), src_tensor, dst_tensor, crop_roi);
   mat->SetTensor(mat->output_cache);
   mat->SetWidth(width_);

View File

@@ -15,6 +15,11 @@
 #pragma once
 #include "fastdeploy/vision/common/processors/base.h"
+#ifdef ENABLE_CVCUDA
+#include <cvcuda/OpCustomCrop.hpp>
+#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
+#endif
 namespace fastdeploy {
 namespace vision {
@@ -38,6 +43,9 @@ class FASTDEPLOY_DECL CenterCrop : public Processor {
  private:
   int height_;
   int width_;
+#ifdef ENABLE_CVCUDA
+  cvcuda::CustomCrop cvcuda_crop_op_;
+#endif
 };
 }  // namespace vision

View File

@@ -31,14 +31,14 @@ void ProcessorManager::UseCuda(bool enable_cv_cuda, int gpu_id) {
   }
   FDASSERT(cudaStreamCreate(&stream_) == cudaSuccess,
            "[ERROR] Error occurs while creating cuda stream.");
-  DefaultProcLib::default_lib = ProcLib::CUDA;
+  proc_lib_ = ProcLib::CUDA;
 #else
   FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
 #endif
   if (enable_cv_cuda) {
 #ifdef ENABLE_CVCUDA
-    DefaultProcLib::default_lib = ProcLib::CVCUDA;
+    proc_lib_ = ProcLib::CVCUDA;
 #else
     FDASSERT(false, "FastDeploy didn't compile with CV-CUDA.");
 #endif
@@ -46,16 +46,11 @@ void ProcessorManager::UseCuda(bool enable_cv_cuda, int gpu_id) {
 }
 bool ProcessorManager::CudaUsed() {
-  return (DefaultProcLib::default_lib == ProcLib::CUDA ||
-          DefaultProcLib::default_lib == ProcLib::CVCUDA);
+  return (proc_lib_ == ProcLib::CUDA || proc_lib_ == ProcLib::CVCUDA);
 }
 bool ProcessorManager::Run(std::vector<FDMat>* images,
                            std::vector<FDTensor>* outputs) {
-  if (!initialized_) {
-    FDERROR << "The preprocessor is not initialized." << std::endl;
-    return false;
-  }
   if (images->size() == 0) {
     FDERROR << "The size of input images should be greater than 0."
             << std::endl;
@@ -70,6 +65,7 @@ bool ProcessorManager::Run(std::vector<FDMat>* images,
   FDMatBatch image_batch(images);
   image_batch.input_cache = &batch_input_cache_;
   image_batch.output_cache = &batch_output_cache_;
+  image_batch.proc_lib = proc_lib_;
   for (size_t i = 0; i < images->size(); ++i) {
     if (CudaUsed()) {

View File

@@ -17,6 +17,7 @@
 #include "fastdeploy/utils/utils.h"
 #include "fastdeploy/vision/common/processors/mat.h"
 #include "fastdeploy/vision/common/processors/mat_batch.h"
+#include "fastdeploy/vision/common/processors/base.h"
 namespace fastdeploy {
 namespace vision {
@@ -78,7 +79,7 @@ class FASTDEPLOY_DECL ProcessorManager {
                      std::vector<FDTensor>* outputs) = 0;
  protected:
-  bool initialized_ = false;
+  ProcLib proc_lib_ = ProcLib::DEFAULT;
 private:
 #ifdef WITH_GPU
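Taken together, the manager hunks replace the process-wide DefaultProcLib switch with a per-manager proc_lib_ that Run() copies onto the FDMatBatch, while a subclass Apply() can still override the library per operator. A condensed sketch of that pattern, mirroring the PaddleClasPreprocessor hunk at the top of this commit (only the comments are new):

    bool PaddleClasPreprocessor::Apply(FDMatBatch* image_batch,
                                       std::vector<FDTensor>* outputs) {
      for (size_t j = 0; j < processors_.size(); ++j) {
        // Reset to the manager-level choice on every iteration so a per-op
        // override from the previous processor does not leak forward.
        image_batch->proc_lib = proc_lib_;
        if (initial_resize_on_cpu_ && j == 0 &&
            processors_[j]->Name().find("Resize") == 0) {
          image_batch->proc_lib = ProcLib::OPENCV;  // keep the first resize on CPU
        }
        if (!(*(processors_[j].get()))(image_batch)) {
          return false;
        }
      }
      // ... output packing elided ...
      return true;
    }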

View File

@@ -145,6 +145,7 @@ struct FASTDEPLOY_DECL Mat {
   ProcLib mat_type = ProcLib::OPENCV;
   Layout layout = Layout::HWC;
   Device device = Device::CPU;
+  ProcLib proc_lib = ProcLib::DEFAULT;
   // Create FD Mat from FD Tensor. This method only create a
   // new FD Mat with zero copy and it's data pointer is reference

View File

@@ -67,6 +67,7 @@ FDTensor* CreateCachedGpuInputTensor(FDMatBatch* mat_batch) {
       FDTensor* tensor = CreateCachedGpuInputTensor(&(*mats)[i]);
       (*mats)[i].SetTensor(tensor);
     }
+    mat_batch->device = Device::GPU;
     return mat_batch->Tensor();
   } else {
     FDASSERT(false, "FDMat is on unsupported device: %d", src->device);

View File

@@ -60,6 +60,7 @@ struct FASTDEPLOY_DECL FDMatBatch {
   ProcLib mat_type = ProcLib::OPENCV;
   FDMatBatchLayout layout = FDMatBatchLayout::NHWC;
   Device device = Device::CPU;
+  ProcLib proc_lib = ProcLib::DEFAULT;
   // False: the data is stored in the mats separately
   // True: the data is stored in the fd_tensor continuously in 4 dimensions

View File

@@ -85,6 +85,8 @@ bool NormalizeAndPermute::ImplByCuda(FDMatBatch* mat_batch) {
   // NHWC -> NCHW
   std::swap(mat_batch->output_cache->shape[1],
             mat_batch->output_cache->shape[3]);
+  std::swap(mat_batch->output_cache->shape[2],
+            mat_batch->output_cache->shape[3]);
   // Copy alpha and beta to GPU
   gpu_alpha_.Resize({1, 1, static_cast<int>(alpha_.size())}, FDDataType::FP32,

View File

@@ -91,6 +91,60 @@ bool Pad::ImplByFlyCV(Mat* mat) {
 }
 #endif
+#ifdef ENABLE_CVCUDA
+bool Pad::ImplByCvCuda(FDMat* mat) {
+  if (mat->layout != Layout::HWC) {
+    FDERROR << "Pad: The input data must be Layout::HWC format!" << std::endl;
+    return false;
+  }
+  if (mat->Channels() > 4) {
+    FDERROR << "Pad: Only support channels <= 4." << std::endl;
+    return false;
+  }
+  if (mat->Channels() != value_.size()) {
+    FDERROR << "Pad: Require input channels equals to size of padding value, "
+               "but now channels = "
+            << mat->Channels()
+            << ", the size of padding values = " << value_.size() << "."
+            << std::endl;
+    return false;
+  }
+
+  float4 value;
+  if (value_.size() == 1) {
+    value = make_float4(value_[0], 0.0f, 0.0f, 0.0f);
+  } else if (value_.size() == 2) {
+    value = make_float4(value_[0], value_[1], 0.0f, 0.0f);
+  } else if (value_.size() == 3) {
+    value = make_float4(value_[0], value_[1], value_[2], 0.0f);
+  } else {
+    value = make_float4(value_[0], value_[1], value_[2], value_[3]);
+  }
+
+  // Prepare input tensor
+  FDTensor* src = CreateCachedGpuInputTensor(mat);
+  auto src_tensor = CreateCvCudaTensorWrapData(*src);
+
+  int height = mat->Height() + top_ + bottom_;
+  int width = mat->Height() + left_ + right_;
+
+  // Prepare output tensor
+  mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(),
+                            "output_cache", Device::GPU);
+  auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
+
+  cvcuda_pad_op_(mat->Stream(), src_tensor, dst_tensor, top_, left_,
+                 NVCV_BORDER_CONSTANT, value);
+
+  mat->SetTensor(mat->output_cache);
+  mat->SetWidth(width);
+  mat->SetHeight(height);
+  mat->device = Device::GPU;
+  mat->mat_type = ProcLib::CVCUDA;
+  return true;
+}
+#endif
 bool Pad::Run(Mat* mat, const int& top, const int& bottom, const int& left,
               const int& right, const std::vector<float>& value, ProcLib lib) {
   auto p = Pad(top, bottom, left, right, value);

View File

@@ -15,6 +15,11 @@
 #pragma once
 #include "fastdeploy/vision/common/processors/base.h"
+#ifdef ENABLE_CVCUDA
+#include <cvcuda/OpCopyMakeBorder.hpp>
+#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
+#endif
 namespace fastdeploy {
 namespace vision {
@@ -32,6 +37,9 @@ class FASTDEPLOY_DECL Pad : public Processor {
   bool ImplByOpenCV(Mat* mat);
 #ifdef ENABLE_FLYCV
   bool ImplByFlyCV(Mat* mat);
+#endif
+#ifdef ENABLE_CVCUDA
+  bool ImplByCvCuda(FDMat* mat);
 #endif
   std::string Name() { return "Pad"; }
@@ -39,12 +47,23 @@ class FASTDEPLOY_DECL Pad : public Processor {
                   const int& right, const std::vector<float>& value,
                   ProcLib lib = ProcLib::DEFAULT);
+  bool SetPaddingSize(int top, int bottom, int left, int right) {
+    top_ = top;
+    bottom_ = bottom;
+    left_ = left;
+    right_ = right;
+    return true;
+  }
+
  private:
   int top_;
   int bottom_;
   int left_;
   int right_;
   std::vector<float> value_;
+#ifdef ENABLE_CVCUDA
+  cvcuda::CopyMakeBorder cvcuda_pad_op_;
+#endif
 };
 }  // namespace vision
 }  // namespace fastdeploy
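Two changes to Pad matter for the detector preprocessor: the cvcuda::CopyMakeBorder operator becomes a class member, so it is constructed once and reused across calls, and SetPaddingSize() lets a single Pad instance be reused with per-image border sizes. A hedged sketch of that reuse pattern (the size variables are illustrative):

    // One Pad op, reused for every image in the batch, as the new
    // DBDetectorPreprocessor does further down in this commit.
    std::vector<float> value = {0, 0, 0};
    auto pad_op = std::make_shared<fastdeploy::vision::Pad>(0, 0, 0, 0, value);

    // Per image: only the border sizes change; the CV-CUDA operator held
    // inside the Pad instance is not re-created.
    pad_op->SetPaddingSize(0, max_resize_h - resize_h, 0, max_resize_w - resize_w);
    (*pad_op)(&mat);  // mat is an FDMat prepared earlier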

View File

@@ -14,12 +14,6 @@
 #include "fastdeploy/vision/common/processors/resize.h"
-#ifdef ENABLE_CVCUDA
-#include <cvcuda/OpResize.hpp>
-#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
-#endif
 namespace fastdeploy {
 namespace vision {
@@ -152,8 +146,7 @@ bool Resize::ImplByCvCuda(FDMat* mat) {
   auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
   // CV-CUDA Interp value is compatible with OpenCV
-  cvcuda::Resize resize_op;
-  resize_op(mat->Stream(), src_tensor, dst_tensor,
+  cvcuda_resize_op_(mat->Stream(), src_tensor, dst_tensor,
             NVCVInterpolationType(interp_));
   mat->SetTensor(mat->output_cache);

View File

@@ -15,6 +15,11 @@
 #pragma once
 #include "fastdeploy/vision/common/processors/base.h"
+#ifdef ENABLE_CVCUDA
+#include <cvcuda/OpResize.hpp>
+#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
+#endif
 namespace fastdeploy {
 namespace vision {
@@ -61,6 +66,9 @@ class FASTDEPLOY_DECL Resize : public Processor {
   float scale_h_ = -1.0;
   int interp_ = 1;
   bool use_scale_ = false;
+#ifdef ENABLE_CVCUDA
+  cvcuda::Resize cvcuda_resize_op_;
+#endif
 };
 }  // namespace vision
 }  // namespace fastdeploy
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -14,12 +14,6 @@
 #include "fastdeploy/vision/common/processors/resize_by_short.h"
-#ifdef ENABLE_CVCUDA
-#include <cvcuda/OpResize.hpp>
-#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
-#endif
 namespace fastdeploy {
 namespace vision {
@@ -102,8 +96,7 @@ bool ResizeByShort::ImplByCvCuda(FDMat* mat) {
   auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
   // CV-CUDA Interp value is compatible with OpenCV
-  cvcuda::Resize resize_op;
-  resize_op(mat->Stream(), src_tensor, dst_tensor,
+  cvcuda_resize_op_(mat->Stream(), src_tensor, dst_tensor,
             NVCVInterpolationType(interp_));
   mat->SetTensor(mat->output_cache);
@@ -144,8 +137,7 @@ bool ResizeByShort::ImplByCvCuda(FDMatBatch* mat_batch) {
   CreateCvCudaImageBatchVarShape(dst_tensors, dst_batch);
   // CV-CUDA Interp value is compatible with OpenCV
-  cvcuda::Resize resize_op;
-  resize_op(mat_batch->Stream(), src_batch, dst_batch,
+  cvcuda_resize_op_(mat_batch->Stream(), src_batch, dst_batch,
             NVCVInterpolationType(interp_));
   for (size_t i = 0; i < mat_batch->mats->size(); ++i) {

View File

@@ -15,6 +15,11 @@
 #pragma once
 #include "fastdeploy/vision/common/processors/base.h"
+#ifdef ENABLE_CVCUDA
+#include <cvcuda/OpResize.hpp>
+#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
+#endif
 namespace fastdeploy {
 namespace vision {
@@ -49,6 +54,9 @@ class FASTDEPLOY_DECL ResizeByShort : public Processor {
   std::vector<int> max_hw_;
   int interp_;
   bool use_scale_;
+#ifdef ENABLE_CVCUDA
+  cvcuda::Resize cvcuda_resize_op_;
+#endif
 };
 }  // namespace vision
 }  // namespace fastdeploy

fastdeploy/vision/ocr/ppocr/dbdetector.cc (25 changed lines, Executable file → Normal file)
View File

@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
+
 #include "fastdeploy/utils/perf.h"
 #include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
@@ -26,11 +27,11 @@ DBDetector::DBDetector(const std::string& model_file,
                        const RuntimeOption& custom_option,
                        const ModelFormat& model_format) {
   if (model_format == ModelFormat::ONNX) {
-    valid_cpu_backends = {Backend::ORT,
-                          Backend::OPENVINO};
+    valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
     valid_gpu_backends = {Backend::ORT, Backend::TRT};
   } else {
-    valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO, Backend::LITE};
+    valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO,
+                          Backend::LITE};
     valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
     valid_kunlunxin_backends = {Backend::LITE};
     valid_ascend_backends = {Backend::LITE};
@@ -54,7 +55,8 @@ bool DBDetector::Initialize() {
 }
 std::unique_ptr<DBDetector> DBDetector::Clone() const {
-  std::unique_ptr<DBDetector> clone_model = utils::make_unique<DBDetector>(DBDetector(*this));
+  std::unique_ptr<DBDetector> clone_model =
+      utils::make_unique<DBDetector>(DBDetector(*this));
   clone_model->SetRuntime(clone_model->CloneRuntime());
   return clone_model;
 }
@@ -69,14 +71,15 @@ bool DBDetector::Predict(const cv::Mat& img,
   return true;
 }
-bool DBDetector::BatchPredict(const std::vector<cv::Mat>& images,
+bool DBDetector::BatchPredict(
+    const std::vector<cv::Mat>& images,
     std::vector<std::vector<std::array<int, 8>>>* det_results) {
   std::vector<FDMat> fd_images = WrapMat(images);
-  std::vector<std::array<int, 4>> batch_det_img_info;
-  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &batch_det_img_info)) {
+  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
     FDERROR << "Failed to preprocess input image." << std::endl;
     return false;
   }
+  auto batch_det_img_info = preprocessor_.GetBatchImgInfo();
   reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
   if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
@@ -84,13 +87,15 @@ bool DBDetector::BatchPredict(const std::vector<cv::Mat>& images,
     return false;
   }
-  if (!postprocessor_.Run(reused_output_tensors_, det_results, batch_det_img_info)) {
-    FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
+  if (!postprocessor_.Run(reused_output_tensors_, det_results,
+                          *batch_det_img_info)) {
+    FDERROR << "Failed to postprocess the inference cls_results by runtime."
+            << std::endl;
     return false;
   }
   return true;
 }
-}  // namesapce ocr
+}  // namespace ocr
 }  // namespace vision
 }  // namespace fastdeploy

View File

@@ -13,9 +13,8 @@
 // limitations under the License.
 #include "fastdeploy/vision/ocr/ppocr/det_preprocessor.h"
-#include "fastdeploy/utils/perf.h"
 #include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
-#include "fastdeploy/function/concat.h"
 namespace fastdeploy {
 namespace vision {
@@ -45,58 +44,55 @@ std::array<int, 4> OcrDetectorGetInfo(FDMat* img, int max_size_len) {
   *ratio_w = float(resize_w) / float(w);
   */
 }
-bool OcrDetectorResizeImage(FDMat* img,
-                            int resize_w,
-                            int resize_h,
-                            int max_resize_w,
-                            int max_resize_h) {
-  Resize::Run(img, resize_w, resize_h);
+DBDetectorPreprocessor::DBDetectorPreprocessor() {
+  resize_op_ = std::make_shared<Resize>(-1, -1);
   std::vector<float> value = {0, 0, 0};
-  Pad::Run(img, 0, max_resize_h-resize_h, 0, max_resize_w - resize_w, value);
+  pad_op_ = std::make_shared<Pad>(0, 0, 0, 0, value);
+  std::vector<float> mean = {0.485f, 0.456f, 0.406f};
+  std::vector<float> std = {0.229f, 0.224f, 0.225f};
+  bool is_scale = true;
+  normalize_permute_op_ =
+      std::make_shared<NormalizeAndPermute>(mean, std, is_scale);
+}
+
+bool DBDetectorPreprocessor::ResizeImage(FDMat* img, int resize_w, int resize_h,
+                                         int max_resize_w, int max_resize_h) {
+  resize_op_->SetWidthAndHeight(resize_w, resize_h);
+  (*resize_op_)(img);
+  pad_op_->SetPaddingSize(0, max_resize_h - resize_h, 0,
+                          max_resize_w - resize_w);
+  (*pad_op_)(img);
   return true;
 }
-bool DBDetectorPreprocessor::Run(std::vector<FDMat>* images,
-                                 std::vector<FDTensor>* outputs,
-                                 std::vector<std::array<int, 4>>* batch_det_img_info_ptr) {
-  if (images->size() == 0) {
-    FDERROR << "The size of input images should be greater than 0." << std::endl;
-    return false;
-  }
+bool DBDetectorPreprocessor::Apply(FDMatBatch* image_batch,
+                                   std::vector<FDTensor>* outputs) {
   int max_resize_w = 0;
   int max_resize_h = 0;
-  std::vector<std::array<int, 4>>& batch_det_img_info = *batch_det_img_info_ptr;
-  batch_det_img_info.clear();
-  batch_det_img_info.resize(images->size());
-  for (size_t i = 0; i < images->size(); ++i) {
-    FDMat* mat = &(images->at(i));
-    batch_det_img_info[i] = OcrDetectorGetInfo(mat,max_side_len_);
-    max_resize_w = std::max(max_resize_w,batch_det_img_info[i][2]);
-    max_resize_h = std::max(max_resize_h,batch_det_img_info[i][3]);
+  batch_det_img_info_.clear();
+  batch_det_img_info_.resize(image_batch->mats->size());
+  for (size_t i = 0; i < image_batch->mats->size(); ++i) {
+    FDMat* mat = &(image_batch->mats->at(i));
+    batch_det_img_info_[i] = OcrDetectorGetInfo(mat, max_side_len_);
+    max_resize_w = std::max(max_resize_w, batch_det_img_info_[i][2]);
+    max_resize_h = std::max(max_resize_h, batch_det_img_info_[i][3]);
   }
-  for (size_t i = 0; i < images->size(); ++i) {
-    FDMat* mat = &(images->at(i));
-    OcrDetectorResizeImage(mat, batch_det_img_info[i][2],batch_det_img_info[i][3],max_resize_w,max_resize_h);
-    NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
-    /*
-    Normalize::Run(mat, mean_, scale_, is_scale_);
-    HWC2CHW::Run(mat);
-    Cast::Run(mat, "float");
-    */
+  for (size_t i = 0; i < image_batch->mats->size(); ++i) {
+    FDMat* mat = &(image_batch->mats->at(i));
+    ResizeImage(mat, batch_det_img_info_[i][2], batch_det_img_info_[i][3],
+                max_resize_w, max_resize_h);
   }
-  // Only have 1 output Tensor.
+  (*normalize_permute_op_)(image_batch);
   outputs->resize(1);
-  // Concat all the preprocessed data to a batch tensor
-  std::vector<FDTensor> tensors(images->size());
-  for (size_t i = 0; i < images->size(); ++i) {
-    (*images)[i].ShareWithTensor(&(tensors[i]));
-    tensors[i].ExpandDim(0);
-  }
-  if (tensors.size() == 1) {
-    (*outputs)[0] = std::move(tensors[0]);
-  } else {
-    function::Concat(tensors, &((*outputs)[0]), 0);
-  }
+  FDTensor* tensor = image_batch->Tensor();
+  (*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
+                                tensor->Data(), tensor->device,
+                                tensor->device_id);
   return true;
 }

View File

@@ -13,7 +13,10 @@
 // limitations under the License.
 #pragma once
-#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/processors/manager.h"
+#include "fastdeploy/vision/common/processors/resize.h"
+#include "fastdeploy/vision/common/processors/pad.h"
+#include "fastdeploy/vision/common/processors/normalize_and_permute.h"
 #include "fastdeploy/vision/common/result.h"
 namespace fastdeploy {
@@ -22,43 +25,48 @@ namespace vision {
 namespace ocr {
 /*! @brief Preprocessor object for DBDetector serials model.
  */
-class FASTDEPLOY_DECL DBDetectorPreprocessor {
+class FASTDEPLOY_DECL DBDetectorPreprocessor : public ProcessorManager {
  public:
+  DBDetectorPreprocessor();
+
   /** \brief Process the input image and prepare input tensors for runtime
    *
-   * \param[in] images The input data list, all the elements are FDMat
+   * \param[in] image_batch The input image batch
    * \param[in] outputs The output tensors which will feed in runtime
-   * \param[in] batch_det_img_info_ptr The output of preprocess
   * \return true if the preprocess successed, otherwise false
   */
-  bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
-           std::vector<std::array<int, 4>>* batch_det_img_info_ptr);
+  virtual bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs);

   /// Set max_side_len for the detection preprocess, default is 960
   void SetMaxSideLen(int max_side_len) { max_side_len_ = max_side_len; }

   /// Get max_side_len of the detection preprocess
   int GetMaxSideLen() const { return max_side_len_; }

-  /// Set mean value for the image normalization in detection preprocess
-  void SetMean(const std::vector<float>& mean) { mean_ = mean; }
-  /// Get mean value of the image normalization in detection preprocess
-  std::vector<float> GetMean() const { return mean_; }
-  /// Set scale value for the image normalization in detection preprocess
-  void SetScale(const std::vector<float>& scale) { scale_ = scale; }
-  /// Get scale value of the image normalization in detection preprocess
-  std::vector<float> GetScale() const { return scale_; }
-  /// Set is_scale for the image normalization in detection preprocess
-  void SetIsScale(bool is_scale) { is_scale_ = is_scale; }
-  /// Get is_scale of the image normalization in detection preprocess
-  bool GetIsScale() const { return is_scale_; }
+  /// Set preprocess normalize parameters, please call this API to customize
+  /// the normalize parameters, otherwise it will use the default normalize
+  /// parameters.
+  void SetNormalize(const std::vector<float>& mean = {0.485f, 0.456f, 0.406f},
+                    const std::vector<float>& std = {0.229f, 0.224f, 0.225f},
+                    bool is_scale = true) {
+    normalize_permute_op_ =
+        std::make_shared<NormalizeAndPermute>(mean, std, is_scale);
+  }
+
+  /// Get the image info of the last batch, return a list of array
+  /// {image width, image height, resize width, resize height}
+  const std::vector<std::array<int, 4>>* GetBatchImgInfo() {
+    return &batch_det_img_info_;
+  }

  private:
+  bool ResizeImage(FDMat* img, int resize_w, int resize_h, int max_resize_w,
+                   int max_resize_h);
   int max_side_len_ = 960;
-  std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
-  std::vector<float> scale_ = {0.229f, 0.224f, 0.225f};
-  bool is_scale_ = true;
+  std::vector<std::array<int, 4>> batch_det_img_info_;
+  std::shared_ptr<Resize> resize_op_;
+  std::shared_ptr<Pad> pad_op_;
+  std::shared_ptr<NormalizeAndPermute> normalize_permute_op_;
 };
 }  // namespace ocr
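With this header, the detector preprocessor behaves like any other ProcessorManager: optionally customize normalization once via SetNormalize(), call the inherited Run() (which dispatches to Apply()), then fetch the per-image resize info through GetBatchImgInfo() for the postprocessor. A hedged sketch, assuming images is a std::vector<cv::Mat> loaded elsewhere:

    namespace fdv = fastdeploy::vision;

    fdv::ocr::DBDetectorPreprocessor preprocessor;
    preprocessor.SetMaxSideLen(960);
    // Optional: override the default ImageNet-style normalization.
    preprocessor.SetNormalize({0.485f, 0.456f, 0.406f},
                              {0.229f, 0.224f, 0.225f}, /*is_scale=*/true);

    std::vector<fdv::FDMat> mats = fdv::WrapMat(images);
    std::vector<fastdeploy::FDTensor> inputs;
    preprocessor.Run(&mats, &inputs);  // ProcessorManager::Run -> Apply()

    // Per image: {image width, image height, resize width, resize height},
    // consumed by DBDetectorPostprocessor (see the dbdetector.cc hunk above).
    const std::vector<std::array<int, 4>>* info = preprocessor.GetBatchImgInfo();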

fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc (182 changed lines, Executable file → Normal file)
View File

@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include <pybind11/stl.h>
 #include "fastdeploy/pybind/main.h"
 namespace fastdeploy {
@@ -22,52 +23,74 @@ void BindPPOCRModel(pybind11::module& m) {
       });
   // DBDetector
-  pybind11::class_<vision::ocr::DBDetectorPreprocessor>(m, "DBDetectorPreprocessor")
+  pybind11::class_<vision::ocr::DBDetectorPreprocessor>(
+      m, "DBDetectorPreprocessor")
       .def(pybind11::init<>())
-      .def_property("max_side_len", &vision::ocr::DBDetectorPreprocessor::GetMaxSideLen, &vision::ocr::DBDetectorPreprocessor::SetMaxSideLen)
-      .def_property("mean", &vision::ocr::DBDetectorPreprocessor::GetMean, &vision::ocr::DBDetectorPreprocessor::SetMean)
-      .def_property("scale", &vision::ocr::DBDetectorPreprocessor::GetScale, &vision::ocr::DBDetectorPreprocessor::SetScale)
-      .def_property("is_scale", &vision::ocr::DBDetectorPreprocessor::GetIsScale, &vision::ocr::DBDetectorPreprocessor::SetIsScale)
-      .def("run", [](vision::ocr::DBDetectorPreprocessor& self, std::vector<pybind11::array>& im_list) {
+      .def_property("max_side_len",
+                    &vision::ocr::DBDetectorPreprocessor::GetMaxSideLen,
+                    &vision::ocr::DBDetectorPreprocessor::SetMaxSideLen)
+      .def("set_normalize",
+           [](vision::ocr::DBDetectorPreprocessor& self,
+              const std::vector<float>& mean, const std::vector<float>& std,
+              bool is_scale) { self.SetNormalize(mean, std, is_scale); })
+      .def("run", [](vision::ocr::DBDetectorPreprocessor& self,
+                     std::vector<pybind11::array>& im_list) {
        std::vector<vision::FDMat> images;
        for (size_t i = 0; i < im_list.size(); ++i) {
          images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
        }
        std::vector<FDTensor> outputs;
-       std::vector<std::array<int, 4>> batch_det_img_info;
-       self.Run(&images, &outputs, &batch_det_img_info);
+       self.Run(&images, &outputs);
+       auto batch_det_img_info = self.GetBatchImgInfo();
        for (size_t i = 0; i < outputs.size(); ++i) {
          outputs[i].StopSharing();
        }
-       return std::make_pair(outputs, batch_det_img_info);
+       return std::make_pair(outputs, *batch_det_img_info);
      });
-  pybind11::class_<vision::ocr::DBDetectorPostprocessor>(m, "DBDetectorPostprocessor")
+  pybind11::class_<vision::ocr::DBDetectorPostprocessor>(
+      m, "DBDetectorPostprocessor")
       .def(pybind11::init<>())
-      .def_property("det_db_thresh", &vision::ocr::DBDetectorPostprocessor::GetDetDBThresh, &vision::ocr::DBDetectorPostprocessor::SetDetDBThresh)
-      .def_property("det_db_box_thresh", &vision::ocr::DBDetectorPostprocessor::GetDetDBBoxThresh, &vision::ocr::DBDetectorPostprocessor::SetDetDBBoxThresh)
-      .def_property("det_db_unclip_ratio", &vision::ocr::DBDetectorPostprocessor::GetDetDBUnclipRatio, &vision::ocr::DBDetectorPostprocessor::SetDetDBUnclipRatio)
-      .def_property("det_db_score_mode", &vision::ocr::DBDetectorPostprocessor::GetDetDBScoreMode, &vision::ocr::DBDetectorPostprocessor::SetDetDBScoreMode)
-      .def_property("use_dilation", &vision::ocr::DBDetectorPostprocessor::GetUseDilation, &vision::ocr::DBDetectorPostprocessor::SetUseDilation)
-      .def("run", [](vision::ocr::DBDetectorPostprocessor& self,
+      .def_property("det_db_thresh",
+                    &vision::ocr::DBDetectorPostprocessor::GetDetDBThresh,
+                    &vision::ocr::DBDetectorPostprocessor::SetDetDBThresh)
+      .def_property("det_db_box_thresh",
+                    &vision::ocr::DBDetectorPostprocessor::GetDetDBBoxThresh,
+                    &vision::ocr::DBDetectorPostprocessor::SetDetDBBoxThresh)
+      .def_property("det_db_unclip_ratio",
+                    &vision::ocr::DBDetectorPostprocessor::GetDetDBUnclipRatio,
+                    &vision::ocr::DBDetectorPostprocessor::SetDetDBUnclipRatio)
+      .def_property("det_db_score_mode",
+                    &vision::ocr::DBDetectorPostprocessor::GetDetDBScoreMode,
+                    &vision::ocr::DBDetectorPostprocessor::SetDetDBScoreMode)
+      .def_property("use_dilation",
+                    &vision::ocr::DBDetectorPostprocessor::GetUseDilation,
+                    &vision::ocr::DBDetectorPostprocessor::SetUseDilation)
+      .def("run",
+           [](vision::ocr::DBDetectorPostprocessor& self,
              std::vector<FDTensor>& inputs,
              const std::vector<std::array<int, 4>>& batch_det_img_info) {
            std::vector<std::vector<std::array<int, 8>>> results;
            if (!self.Run(inputs, &results, batch_det_img_info)) {
-             throw std::runtime_error("Failed to preprocess the input data in DBDetectorPostprocessor.");
+             throw std::runtime_error(
+                 "Failed to preprocess the input data in "
+                 "DBDetectorPostprocessor.");
            }
            return results;
          })
-      .def("run", [](vision::ocr::DBDetectorPostprocessor& self,
+      .def("run",
+           [](vision::ocr::DBDetectorPostprocessor& self,
              std::vector<pybind11::array>& input_array,
              const std::vector<std::array<int, 4>>& batch_det_img_info) {
            std::vector<std::vector<std::array<int, 8>>> results;
            std::vector<FDTensor> inputs;
            PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
            if (!self.Run(inputs, &results, batch_det_img_info)) {
-             throw std::runtime_error("Failed to preprocess the input data in DBDetectorPostprocessor.");
+             throw std::runtime_error(
+                 "Failed to preprocess the input data in "
+                 "DBDetectorPostprocessor.");
            }
            return results;
          });
@@ -76,16 +99,19 @@ void BindPPOCRModel(pybind11::module& m) {
       .def(pybind11::init<std::string, std::string, RuntimeOption,
                           ModelFormat>())
       .def(pybind11::init<>())
-      .def_property_readonly("preprocessor", &vision::ocr::DBDetector::GetPreprocessor)
-      .def_property_readonly("postprocessor", &vision::ocr::DBDetector::GetPostprocessor)
-      .def("predict", [](vision::ocr::DBDetector& self,
-                         pybind11::array& data) {
+      .def_property_readonly("preprocessor",
+                             &vision::ocr::DBDetector::GetPreprocessor)
+      .def_property_readonly("postprocessor",
+                             &vision::ocr::DBDetector::GetPostprocessor)
+      .def("predict",
+           [](vision::ocr::DBDetector& self, pybind11::array& data) {
            auto mat = PyArrayToCvMat(data);
            std::vector<std::array<int, 8>> boxes_result;
            self.Predict(mat, &boxes_result);
            return boxes_result;
          })
-      .def("batch_predict", [](vision::ocr::DBDetector& self, std::vector<pybind11::array>& data) {
+      .def("batch_predict", [](vision::ocr::DBDetector& self,
+                               std::vector<pybind11::array>& data) {
        std::vector<cv::Mat> images;
        std::vector<std::vector<std::array<int, 8>>> det_results;
        for (size_t i = 0; i < data.size(); ++i) {
@@ -96,20 +122,29 @@ void BindPPOCRModel(pybind11::module& m) {
      });
   // Classifier
-  pybind11::class_<vision::ocr::ClassifierPreprocessor>(m, "ClassifierPreprocessor")
+  pybind11::class_<vision::ocr::ClassifierPreprocessor>(
+      m, "ClassifierPreprocessor")
       .def(pybind11::init<>())
-      .def_property("cls_image_shape", &vision::ocr::ClassifierPreprocessor::GetClsImageShape, &vision::ocr::ClassifierPreprocessor::SetClsImageShape)
-      .def_property("mean", &vision::ocr::ClassifierPreprocessor::GetMean, &vision::ocr::ClassifierPreprocessor::SetMean)
-      .def_property("scale", &vision::ocr::ClassifierPreprocessor::GetScale, &vision::ocr::ClassifierPreprocessor::SetScale)
-      .def_property("is_scale", &vision::ocr::ClassifierPreprocessor::GetIsScale, &vision::ocr::ClassifierPreprocessor::SetIsScale)
-      .def("run", [](vision::ocr::ClassifierPreprocessor& self, std::vector<pybind11::array>& im_list) {
+      .def_property("cls_image_shape",
+                    &vision::ocr::ClassifierPreprocessor::GetClsImageShape,
+                    &vision::ocr::ClassifierPreprocessor::SetClsImageShape)
+      .def_property("mean", &vision::ocr::ClassifierPreprocessor::GetMean,
+                    &vision::ocr::ClassifierPreprocessor::SetMean)
+      .def_property("scale", &vision::ocr::ClassifierPreprocessor::GetScale,
+                    &vision::ocr::ClassifierPreprocessor::SetScale)
+      .def_property("is_scale",
+                    &vision::ocr::ClassifierPreprocessor::GetIsScale,
+                    &vision::ocr::ClassifierPreprocessor::SetIsScale)
+      .def("run", [](vision::ocr::ClassifierPreprocessor& self,
+                     std::vector<pybind11::array>& im_list) {
        std::vector<vision::FDMat> images;
        for (size_t i = 0; i < im_list.size(); ++i) {
          images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
        }
        std::vector<FDTensor> outputs;
        if (!self.Run(&images, &outputs)) {
-         throw std::runtime_error("Failed to preprocess the input data in ClassifierPreprocessor.");
+         throw std::runtime_error(
+             "Failed to preprocess the input data in ClassifierPreprocessor.");
        }
        for (size_t i = 0; i < outputs.size(); ++i) {
          outputs[i].StopSharing();
@@ -117,15 +152,21 @@ void BindPPOCRModel(pybind11::module& m) {
        return outputs;
      });
-  pybind11::class_<vision::ocr::ClassifierPostprocessor>(m, "ClassifierPostprocessor")
+  pybind11::class_<vision::ocr::ClassifierPostprocessor>(
+      m, "ClassifierPostprocessor")
       .def(pybind11::init<>())
-      .def_property("cls_thresh", &vision::ocr::ClassifierPostprocessor::GetClsThresh, &vision::ocr::ClassifierPostprocessor::SetClsThresh)
-      .def("run", [](vision::ocr::ClassifierPostprocessor& self,
+      .def_property("cls_thresh",
+                    &vision::ocr::ClassifierPostprocessor::GetClsThresh,
+                    &vision::ocr::ClassifierPostprocessor::SetClsThresh)
+      .def("run",
+           [](vision::ocr::ClassifierPostprocessor& self,
              std::vector<FDTensor>& inputs) {
            std::vector<int> cls_labels;
            std::vector<float> cls_scores;
            if (!self.Run(inputs, &cls_labels, &cls_scores)) {
-             throw std::runtime_error("Failed to preprocess the input data in ClassifierPostprocessor.");
+             throw std::runtime_error(
+                 "Failed to preprocess the input data in "
+                 "ClassifierPostprocessor.");
            }
            return std::make_pair(cls_labels, cls_scores);
          })
@@ -136,7 +177,9 @@ void BindPPOCRModel(pybind11::module& m) {
            std::vector<int> cls_labels;
            std::vector<float> cls_scores;
            if (!self.Run(inputs, &cls_labels, &cls_scores)) {
-             throw std::runtime_error("Failed to preprocess the input data in ClassifierPostprocessor.");
+             throw std::runtime_error(
+                 "Failed to preprocess the input data in "
+                 "ClassifierPostprocessor.");
            }
            return std::make_pair(cls_labels, cls_scores);
          });
@@ -145,17 +188,20 @@ void BindPPOCRModel(pybind11::module& m) {
       .def(pybind11::init<std::string, std::string, RuntimeOption,
                           ModelFormat>())
       .def(pybind11::init<>())
-      .def_property_readonly("preprocessor", &vision::ocr::Classifier::GetPreprocessor)
-      .def_property_readonly("postprocessor", &vision::ocr::Classifier::GetPostprocessor)
-      .def("predict", [](vision::ocr::Classifier& self,
-                         pybind11::array& data) {
+      .def_property_readonly("preprocessor",
+                             &vision::ocr::Classifier::GetPreprocessor)
+      .def_property_readonly("postprocessor",
+                             &vision::ocr::Classifier::GetPostprocessor)
+      .def("predict",
+           [](vision::ocr::Classifier& self, pybind11::array& data) {
            auto mat = PyArrayToCvMat(data);
            int32_t cls_label;
            float cls_score;
            self.Predict(mat, &cls_label, &cls_score);
            return std::make_pair(cls_label, cls_score);
          })
-      .def("batch_predict", [](vision::ocr::Classifier& self, std::vector<pybind11::array>& data) {
+      .def("batch_predict", [](vision::ocr::Classifier& self,
+                               std::vector<pybind11::array>& data) {
        std::vector<cv::Mat> images;
        std::vector<int32_t> cls_labels;
        std::vector<float> cls_scores;
@@ -167,21 +213,32 @@ void BindPPOCRModel(pybind11::module& m) {
      });
   // Recognizer
-  pybind11::class_<vision::ocr::RecognizerPreprocessor>(m, "RecognizerPreprocessor")
+  pybind11::class_<vision::ocr::RecognizerPreprocessor>(
+      m, "RecognizerPreprocessor")
       .def(pybind11::init<>())
-      .def_property("static_shape_infer", &vision::ocr::RecognizerPreprocessor::GetStaticShapeInfer, &vision::ocr::RecognizerPreprocessor::SetStaticShapeInfer)
-      .def_property("rec_image_shape", &vision::ocr::RecognizerPreprocessor::GetRecImageShape, &vision::ocr::RecognizerPreprocessor::SetRecImageShape)
-      .def_property("mean", &vision::ocr::RecognizerPreprocessor::GetMean, &vision::ocr::RecognizerPreprocessor::SetMean)
-      .def_property("scale", &vision::ocr::RecognizerPreprocessor::GetScale, &vision::ocr::RecognizerPreprocessor::SetScale)
-      .def_property("is_scale", &vision::ocr::RecognizerPreprocessor::GetIsScale, &vision::ocr::RecognizerPreprocessor::SetIsScale)
-      .def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector<pybind11::array>& im_list) {
+      .def_property("static_shape_infer",
+                    &vision::ocr::RecognizerPreprocessor::GetStaticShapeInfer,
+                    &vision::ocr::RecognizerPreprocessor::SetStaticShapeInfer)
+      .def_property("rec_image_shape",
+                    &vision::ocr::RecognizerPreprocessor::GetRecImageShape,
+                    &vision::ocr::RecognizerPreprocessor::SetRecImageShape)
+      .def_property("mean", &vision::ocr::RecognizerPreprocessor::GetMean,
+                    &vision::ocr::RecognizerPreprocessor::SetMean)
+      .def_property("scale", &vision::ocr::RecognizerPreprocessor::GetScale,
+                    &vision::ocr::RecognizerPreprocessor::SetScale)
+      .def_property("is_scale",
+                    &vision::ocr::RecognizerPreprocessor::GetIsScale,
+                    &vision::ocr::RecognizerPreprocessor::SetIsScale)
+      .def("run", [](vision::ocr::RecognizerPreprocessor& self,
+                     std::vector<pybind11::array>& im_list) {
        std::vector<vision::FDMat> images;
        for (size_t i = 0; i < im_list.size(); ++i) {
          images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
        }
        std::vector<FDTensor> outputs;
        if (!self.Run(&images, &outputs)) {
-         throw std::runtime_error("Failed to preprocess the input data in RecognizerPreprocessor.");
+         throw std::runtime_error(
+             "Failed to preprocess the input data in RecognizerPreprocessor.");
        }
        for (size_t i = 0; i < outputs.size(); ++i) {
          outputs[i].StopSharing();
@@ -189,14 +246,18 @@ void BindPPOCRModel(pybind11::module& m) {
        return outputs;
      });
-  pybind11::class_<vision::ocr::RecognizerPostprocessor>(m, "RecognizerPostprocessor")
+  pybind11::class_<vision::ocr::RecognizerPostprocessor>(
+      m, "RecognizerPostprocessor")
       .def(pybind11::init<std::string>())
-      .def("run", [](vision::ocr::RecognizerPostprocessor& self,
+      .def("run",
+           [](vision::ocr::RecognizerPostprocessor& self,
             std::vector<FDTensor>& inputs) {
            std::vector<std::string> texts;
            std::vector<float> rec_scores;
            if (!self.Run(inputs, &texts, &rec_scores)) {
-             throw std::runtime_error("Failed to preprocess the input data in RecognizerPostprocessor.");
+             throw std::runtime_error(
+                 "Failed to preprocess the input data in "
+                 "RecognizerPostprocessor.");
            }
            return std::make_pair(texts, rec_scores);
          })
@@ -207,7 +268,9 @@ void BindPPOCRModel(pybind11::module& m) {
            std::vector<std::string> texts;
            std::vector<float> rec_scores;
            if (!self.Run(inputs, &texts, &rec_scores)) {
-             throw std::runtime_error("Failed to preprocess the input data in RecognizerPostprocessor.");
+             throw std::runtime_error(
+                 "Failed to preprocess the input data in "
+                 "RecognizerPostprocessor.");
            }
            return std::make_pair(texts, rec_scores);
          });
@@ -216,17 +279,20 @@ void BindPPOCRModel(pybind11::module& m) {
       .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
                           ModelFormat>())
       .def(pybind11::init<>())
-      .def_property_readonly("preprocessor", &vision::ocr::Recognizer::GetPreprocessor)
-      .def_property_readonly("postprocessor", &vision::ocr::Recognizer::GetPostprocessor)
-      .def("predict", [](vision::ocr::Recognizer& self,
-                         pybind11::array& data) {
+      .def_property_readonly("preprocessor",
+                             &vision::ocr::Recognizer::GetPreprocessor)
+      .def_property_readonly("postprocessor",
+                             &vision::ocr::Recognizer::GetPostprocessor)
+      .def("predict",
+           [](vision::ocr::Recognizer& self, pybind11::array& data) {
            auto mat = PyArrayToCvMat(data);
            std::string text;
            float rec_score;
            self.Predict(mat, &text, &rec_score);
            return std::make_pair(text, rec_score);
          })
-      .def("batch_predict", [](vision::ocr::Recognizer& self, std::vector<pybind11::array>& data) {
+      .def("batch_predict", [](vision::ocr::Recognizer& self,
+                               std::vector<pybind11::array>& data) {
        std::vector<cv::Mat> images;
        std::vector<std::string> texts;
        std::vector<float> rec_scores;

View File

@@ -46,7 +46,6 @@ class PaddleClasPreprocessor(ProcessorManager):
         When the initial operator is Resize, and input image size is large,
         maybe it's better to run resize on CPU, because the HostToDevice memcpy
         is time consuming. Set this True to run the initial resize on CPU.
         :param: v: True or False
         """
         self._manager.initial_resize_on_cpu(v)

View File

@@ -37,43 +37,31 @@ class DBDetectorPreprocessor:
     @property
     def max_side_len(self):
         """Get max_side_len value.
         """
         return self._preprocessor.max_side_len

     @max_side_len.setter
     def max_side_len(self, value):
         """Set max_side_len value.
         :param: value: (int) max_side_len value
         """
         assert isinstance(
             value, int), "The value to set `max_side_len` must be type of int."
         self._preprocessor.max_side_len = value

-    @property
-    def is_scale(self):
-        return self._preprocessor.is_scale
-
-    @is_scale.setter
-    def is_scale(self, value):
-        assert isinstance(
-            value, bool), "The value to set `is_scale` must be type of bool."
-        self._preprocessor.is_scale = value
-
-    @property
-    def scale(self):
-        return self._preprocessor.scale
-
-    @scale.setter
-    def scale(self, value):
-        assert isinstance(
-            value, list), "The value to set `scale` must be type of list."
-        self._preprocessor.scale = value
-
-    @property
-    def mean(self):
-        return self._preprocessor.mean
-
-    @mean.setter
-    def mean(self, value):
-        assert isinstance(
-            value, list), "The value to set `mean` must be type of list."
-        self._preprocessor.mean = value
+    def set_normalize(self,
+                      mean=[0.485, 0.456, 0.406],
+                      std=[0.229, 0.224, 0.225],
+                      is_scale=True):
+        """Set preprocess normalize parameters, please call this API to
+           customize the normalize parameters, otherwise it will use the default
+           normalize parameters.
+        :param: mean: (list of float) mean values
+        :param: std: (list of float) std values
+        :param: is_scale: (boolean) whether to scale
+        """
+        self._preprocessor.set_normalize(mean, std, is_scale)

 class DBDetectorPostprocessor:
@@ -174,6 +162,7 @@ class DBDetector(FastDeployModel):
         """Clone OCR detection model object
         :return: a new OCR detection model object
         """
+
         class DBDetectorClone(DBDetector):
             def __init__(self, model):
                 self._model = model
@@ -203,18 +192,10 @@ class DBDetector(FastDeployModel):
     def preprocessor(self):
         return self._model.preprocessor

-    @preprocessor.setter
-    def preprocessor(self, value):
-        self._model.preprocessor = value
-
     @property
     def postprocessor(self):
         return self._model.postprocessor

-    @postprocessor.setter
-    def postprocessor(self, value):
-        self._model.postprocessor = value
-
     # Det Preprocessor Property
     @property
     def max_side_len(self):
@@ -226,36 +207,6 @@ class DBDetector(FastDeployModel):
             value, int), "The value to set `max_side_len` must be type of int."
         self._model.preprocessor.max_side_len = value

-    @property
-    def is_scale(self):
-        return self._model.preprocessor.is_scale
-
-    @is_scale.setter
-    def is_scale(self, value):
-        assert isinstance(
-            value, bool), "The value to set `is_scale` must be type of bool."
-        self._model.preprocessor.is_scale = value
-
-    @property
-    def scale(self):
-        return self._model.preprocessor.scale
-
-    @scale.setter
-    def scale(self, value):
-        assert isinstance(
-            value, list), "The value to set `scale` must be type of list."
-        self._model.preprocessor.scale = value
-
-    @property
-    def mean(self):
-        return self._model.preprocessor.mean
-
-    @mean.setter
-    def mean(self, value):
-        assert isinstance(
-            value, list), "The value to set `mean` must be type of list."
-        self._model.preprocessor.mean = value
-
     # Det Ppstprocessor Property
     @property
     def det_db_thresh(self):
@@ -421,6 +372,7 @@ class Classifier(FastDeployModel):
         """Clone OCR classification model object
         :return: a new OCR classification model object
         """
+
         class ClassifierClone(Classifier):
             def __init__(self, model):
                 self._model = model
@@ -629,6 +581,7 @@ class Recognizer(FastDeployModel):
         """Clone OCR recognition model object
         :return: a new OCR recognition model object
         """
+
         class RecognizerClone(Recognizer):
             def __init__(self, model):
                 self._model = model
@@ -743,6 +696,7 @@ class PPOCRv3(FastDeployModel):
         """Clone PPOCRv3 pipeline object
         :return: a new PPOCRv3 pipeline object
         """
+
         class PPOCRv3Clone(PPOCRv3):
             def __init__(self, system):
                 self.system_ = system
@@ -818,6 +772,7 @@ class PPOCRv2(FastDeployModel):
         """Clone PPOCRv3 pipeline object
         :return: a new PPOCRv3 pipeline object
         """
+
         class PPOCRv2Clone(PPOCRv2):
             def __init__(self, system):
                 self.system_ = system