diff --git a/fastdeploy/vision/classification/ppcls/preprocessor.cc b/fastdeploy/vision/classification/ppcls/preprocessor.cc index ef0da9ce5..619ba87fd 100644 --- a/fastdeploy/vision/classification/ppcls/preprocessor.cc +++ b/fastdeploy/vision/classification/ppcls/preprocessor.cc @@ -14,7 +14,6 @@ #include "fastdeploy/vision/classification/ppcls/preprocessor.h" -#include "fastdeploy/function/concat.h" #include "yaml-cpp/yaml.h" namespace fastdeploy { @@ -102,13 +101,17 @@ void PaddleClasPreprocessor::DisablePermute() { bool PaddleClasPreprocessor::Apply(FDMatBatch* image_batch, std::vector* outputs) { + if (!initialized_) { + FDERROR << "The preprocessor is not initialized." << std::endl; + return false; + } for (size_t j = 0; j < processors_.size(); ++j) { - ProcLib lib = ProcLib::DEFAULT; + image_batch->proc_lib = proc_lib_; if (initial_resize_on_cpu_ && j == 0 && processors_[j]->Name().find("Resize") == 0) { - lib = ProcLib::OPENCV; + image_batch->proc_lib = ProcLib::OPENCV; } - if (!(*(processors_[j].get()))(image_batch, lib)) { + if (!(*(processors_[j].get()))(image_batch)) { FDERROR << "Failed to processs image in " << processors_[j]->Name() << "." << std::endl; return false; diff --git a/fastdeploy/vision/classification/ppcls/preprocessor.h b/fastdeploy/vision/classification/ppcls/preprocessor.h index fc347fc3d..ac2e82ef1 100644 --- a/fastdeploy/vision/classification/ppcls/preprocessor.h +++ b/fastdeploy/vision/classification/ppcls/preprocessor.h @@ -55,6 +55,7 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager { private: bool BuildPreprocessPipelineFromConfig(); + bool initialized_ = false; std::vector> processors_; // for recording the switch of hwc2chw bool disable_permute_ = false; diff --git a/fastdeploy/vision/common/processors/base.cc b/fastdeploy/vision/common/processors/base.cc index 9c4a0177e..7e34d07bf 100644 --- a/fastdeploy/vision/common/processors/base.cc +++ b/fastdeploy/vision/common/processors/base.cc @@ -20,9 +20,9 @@ namespace fastdeploy { namespace vision { -bool Processor::operator()(FDMat* mat, ProcLib lib) { - ProcLib target = lib; - if (lib == ProcLib::DEFAULT) { +bool Processor::operator()(FDMat* mat) { + ProcLib target = mat->proc_lib; + if (mat->proc_lib == ProcLib::DEFAULT) { target = DefaultProcLib::default_lib; } if (target == ProcLib::FLYCV) { @@ -52,9 +52,14 @@ bool Processor::operator()(FDMat* mat, ProcLib lib) { return ImplByOpenCV(mat); } -bool Processor::operator()(FDMatBatch* mat_batch, ProcLib lib) { - ProcLib target = lib; - if (lib == ProcLib::DEFAULT) { +bool Processor::operator()(FDMat* mat, ProcLib lib) { + mat->proc_lib = lib; + return operator()(mat); +} + +bool Processor::operator()(FDMatBatch* mat_batch) { + ProcLib target = mat_batch->proc_lib; + if (mat_batch->proc_lib == ProcLib::DEFAULT) { target = DefaultProcLib::default_lib; } if (target == ProcLib::FLYCV) { diff --git a/fastdeploy/vision/common/processors/base.h b/fastdeploy/vision/common/processors/base.h index 786e88672..a1c64a2c1 100644 --- a/fastdeploy/vision/common/processors/base.h +++ b/fastdeploy/vision/common/processors/base.h @@ -100,10 +100,13 @@ class FASTDEPLOY_DECL Processor { return true; } - virtual bool operator()(FDMat* mat, ProcLib lib = ProcLib::DEFAULT); + virtual bool operator()(FDMat* mat); - virtual bool operator()(FDMatBatch* mat_batch, - ProcLib lib = ProcLib::DEFAULT); + // This function is for backward compatibility, will be removed in the near + // future, please use operator()(FDMat* mat) instead and set proc_lib in mat. 
+ virtual bool operator()(FDMat* mat, ProcLib lib); + + virtual bool operator()(FDMatBatch* mat_batch); }; } // namespace vision diff --git a/fastdeploy/vision/common/processors/center_crop.cc b/fastdeploy/vision/common/processors/center_crop.cc index 1857f7a81..f220ac376 100644 --- a/fastdeploy/vision/common/processors/center_crop.cc +++ b/fastdeploy/vision/common/processors/center_crop.cc @@ -14,12 +14,6 @@ #include "fastdeploy/vision/common/processors/center_crop.h" -#ifdef ENABLE_CVCUDA -#include - -#include "fastdeploy/vision/common/processors/cvcuda_utils.h" -#endif - namespace fastdeploy { namespace vision { @@ -75,9 +69,8 @@ bool CenterCrop::ImplByCvCuda(FDMat* mat) { int offset_x = static_cast((mat->Width() - width_) / 2); int offset_y = static_cast((mat->Height() - height_) / 2); - cvcuda::CustomCrop crop_op; NVCVRectI crop_roi = {offset_x, offset_y, width_, height_}; - crop_op(mat->Stream(), src_tensor, dst_tensor, crop_roi); + cvcuda_crop_op_(mat->Stream(), src_tensor, dst_tensor, crop_roi); mat->SetTensor(mat->output_cache); mat->SetWidth(width_); diff --git a/fastdeploy/vision/common/processors/center_crop.h b/fastdeploy/vision/common/processors/center_crop.h index 3ca3a7391..0eddde0ed 100644 --- a/fastdeploy/vision/common/processors/center_crop.h +++ b/fastdeploy/vision/common/processors/center_crop.h @@ -15,6 +15,11 @@ #pragma once #include "fastdeploy/vision/common/processors/base.h" +#ifdef ENABLE_CVCUDA +#include + +#include "fastdeploy/vision/common/processors/cvcuda_utils.h" +#endif namespace fastdeploy { namespace vision { @@ -38,6 +43,9 @@ class FASTDEPLOY_DECL CenterCrop : public Processor { private: int height_; int width_; +#ifdef ENABLE_CVCUDA + cvcuda::CustomCrop cvcuda_crop_op_; +#endif }; } // namespace vision diff --git a/fastdeploy/vision/common/processors/manager.cc b/fastdeploy/vision/common/processors/manager.cc index 070354da1..2f751ab80 100644 --- a/fastdeploy/vision/common/processors/manager.cc +++ b/fastdeploy/vision/common/processors/manager.cc @@ -31,14 +31,14 @@ void ProcessorManager::UseCuda(bool enable_cv_cuda, int gpu_id) { } FDASSERT(cudaStreamCreate(&stream_) == cudaSuccess, "[ERROR] Error occurs while creating cuda stream."); - DefaultProcLib::default_lib = ProcLib::CUDA; + proc_lib_ = ProcLib::CUDA; #else FDASSERT(false, "FastDeploy didn't compile with WITH_GPU."); #endif if (enable_cv_cuda) { #ifdef ENABLE_CVCUDA - DefaultProcLib::default_lib = ProcLib::CVCUDA; + proc_lib_ = ProcLib::CVCUDA; #else FDASSERT(false, "FastDeploy didn't compile with CV-CUDA."); #endif @@ -46,16 +46,11 @@ void ProcessorManager::UseCuda(bool enable_cv_cuda, int gpu_id) { } bool ProcessorManager::CudaUsed() { - return (DefaultProcLib::default_lib == ProcLib::CUDA || - DefaultProcLib::default_lib == ProcLib::CVCUDA); + return (proc_lib_ == ProcLib::CUDA || proc_lib_ == ProcLib::CVCUDA); } bool ProcessorManager::Run(std::vector* images, std::vector* outputs) { - if (!initialized_) { - FDERROR << "The preprocessor is not initialized." << std::endl; - return false; - } if (images->size() == 0) { FDERROR << "The size of input images should be greater than 0." 
<< std::endl; @@ -70,6 +65,7 @@ bool ProcessorManager::Run(std::vector* images, FDMatBatch image_batch(images); image_batch.input_cache = &batch_input_cache_; image_batch.output_cache = &batch_output_cache_; + image_batch.proc_lib = proc_lib_; for (size_t i = 0; i < images->size(); ++i) { if (CudaUsed()) { diff --git a/fastdeploy/vision/common/processors/manager.h b/fastdeploy/vision/common/processors/manager.h index 48b5575c4..aa6dde56a 100644 --- a/fastdeploy/vision/common/processors/manager.h +++ b/fastdeploy/vision/common/processors/manager.h @@ -17,6 +17,7 @@ #include "fastdeploy/utils/utils.h" #include "fastdeploy/vision/common/processors/mat.h" #include "fastdeploy/vision/common/processors/mat_batch.h" +#include "fastdeploy/vision/common/processors/base.h" namespace fastdeploy { namespace vision { @@ -78,7 +79,7 @@ class FASTDEPLOY_DECL ProcessorManager { std::vector* outputs) = 0; protected: - bool initialized_ = false; + ProcLib proc_lib_ = ProcLib::DEFAULT; private: #ifdef WITH_GPU diff --git a/fastdeploy/vision/common/processors/mat.h b/fastdeploy/vision/common/processors/mat.h index 13ae76abd..85f121b90 100644 --- a/fastdeploy/vision/common/processors/mat.h +++ b/fastdeploy/vision/common/processors/mat.h @@ -145,6 +145,7 @@ struct FASTDEPLOY_DECL Mat { ProcLib mat_type = ProcLib::OPENCV; Layout layout = Layout::HWC; Device device = Device::CPU; + ProcLib proc_lib = ProcLib::DEFAULT; // Create FD Mat from FD Tensor. This method only create a // new FD Mat with zero copy and it's data pointer is reference diff --git a/fastdeploy/vision/common/processors/mat_batch.cc b/fastdeploy/vision/common/processors/mat_batch.cc index f625d6d4d..aa154f334 100644 --- a/fastdeploy/vision/common/processors/mat_batch.cc +++ b/fastdeploy/vision/common/processors/mat_batch.cc @@ -67,6 +67,7 @@ FDTensor* CreateCachedGpuInputTensor(FDMatBatch* mat_batch) { FDTensor* tensor = CreateCachedGpuInputTensor(&(*mats)[i]); (*mats)[i].SetTensor(tensor); } + mat_batch->device = Device::GPU; return mat_batch->Tensor(); } else { FDASSERT(false, "FDMat is on unsupported device: %d", src->device); diff --git a/fastdeploy/vision/common/processors/mat_batch.h b/fastdeploy/vision/common/processors/mat_batch.h index 090d8bb59..9d876a911 100644 --- a/fastdeploy/vision/common/processors/mat_batch.h +++ b/fastdeploy/vision/common/processors/mat_batch.h @@ -60,6 +60,7 @@ struct FASTDEPLOY_DECL FDMatBatch { ProcLib mat_type = ProcLib::OPENCV; FDMatBatchLayout layout = FDMatBatchLayout::NHWC; Device device = Device::CPU; + ProcLib proc_lib = ProcLib::DEFAULT; // False: the data is stored in the mats separately // True: the data is stored in the fd_tensor continuously in 4 dimensions diff --git a/fastdeploy/vision/common/processors/normalize_and_permute.cu b/fastdeploy/vision/common/processors/normalize_and_permute.cu index 7f6320ba4..da3f4ffb1 100644 --- a/fastdeploy/vision/common/processors/normalize_and_permute.cu +++ b/fastdeploy/vision/common/processors/normalize_and_permute.cu @@ -85,6 +85,8 @@ bool NormalizeAndPermute::ImplByCuda(FDMatBatch* mat_batch) { // NHWC -> NCHW std::swap(mat_batch->output_cache->shape[1], mat_batch->output_cache->shape[3]); + std::swap(mat_batch->output_cache->shape[2], + mat_batch->output_cache->shape[3]); // Copy alpha and beta to GPU gpu_alpha_.Resize({1, 1, static_cast(alpha_.size())}, FDDataType::FP32, diff --git a/fastdeploy/vision/common/processors/pad.cc b/fastdeploy/vision/common/processors/pad.cc index 278e8d4b7..2db1fba20 100644 --- a/fastdeploy/vision/common/processors/pad.cc +++ 
b/fastdeploy/vision/common/processors/pad.cc @@ -91,6 +91,60 @@ bool Pad::ImplByFlyCV(Mat* mat) { } #endif +#ifdef ENABLE_CVCUDA +bool Pad::ImplByCvCuda(FDMat* mat) { + if (mat->layout != Layout::HWC) { + FDERROR << "Pad: The input data must be Layout::HWC format!" << std::endl; + return false; + } + if (mat->Channels() > 4) { + FDERROR << "Pad: Only support channels <= 4." << std::endl; + return false; + } + if (mat->Channels() != value_.size()) { + FDERROR << "Pad: Require input channels equals to size of padding value, " + "but now channels = " + << mat->Channels() + << ", the size of padding values = " << value_.size() << "." + << std::endl; + return false; + } + + float4 value; + if (value_.size() == 1) { + value = make_float4(value_[0], 0.0f, 0.0f, 0.0f); + } else if (value_.size() == 2) { + value = make_float4(value_[0], value_[1], 0.0f, 0.0f); + } else if (value_.size() == 3) { + value = make_float4(value_[0], value_[1], value_[2], 0.0f); + } else { + value = make_float4(value_[0], value_[1], value_[2], value_[3]); + } + + // Prepare input tensor + FDTensor* src = CreateCachedGpuInputTensor(mat); + auto src_tensor = CreateCvCudaTensorWrapData(*src); + + int height = mat->Height() + top_ + bottom_; + int width = mat->Width() + left_ + right_; + + // Prepare output tensor + mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(), + "output_cache", Device::GPU); + auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache)); + + cvcuda_pad_op_(mat->Stream(), src_tensor, dst_tensor, top_, left_, + NVCV_BORDER_CONSTANT, value); + + mat->SetTensor(mat->output_cache); + mat->SetWidth(width); + mat->SetHeight(height); + mat->device = Device::GPU; + mat->mat_type = ProcLib::CVCUDA; + return true; +} +#endif + bool Pad::Run(Mat* mat, const int& top, const int& bottom, const int& left, const int& right, const std::vector<float>& value, ProcLib lib) { auto p = Pad(top, bottom, left, right, value); diff --git a/fastdeploy/vision/common/processors/pad.h b/fastdeploy/vision/common/processors/pad.h index 661632e77..5d025c720 100644 --- a/fastdeploy/vision/common/processors/pad.h +++ b/fastdeploy/vision/common/processors/pad.h @@ -15,6 +15,11 @@ #pragma once #include "fastdeploy/vision/common/processors/base.h" +#ifdef ENABLE_CVCUDA +#include <cvcuda/OpCopyMakeBorder.hpp> + +#include "fastdeploy/vision/common/processors/cvcuda_utils.h" +#endif namespace fastdeploy { namespace vision { @@ -32,6 +37,9 @@ class FASTDEPLOY_DECL Pad : public Processor { bool ImplByOpenCV(Mat* mat); #ifdef ENABLE_FLYCV bool ImplByFlyCV(Mat* mat); +#endif +#ifdef ENABLE_CVCUDA + bool ImplByCvCuda(FDMat* mat); #endif std::string Name() { return "Pad"; } @@ -39,12 +47,23 @@ class FASTDEPLOY_DECL Pad : public Processor { const int& right, const std::vector<float>& value, ProcLib lib = ProcLib::DEFAULT); + bool SetPaddingSize(int top, int bottom, int left, int right) { + top_ = top; + bottom_ = bottom; + left_ = left; + right_ = right; + return true; + } + private: int top_; int bottom_; int left_; int right_; std::vector<float> value_; +#ifdef ENABLE_CVCUDA + cvcuda::CopyMakeBorder cvcuda_pad_op_; +#endif }; } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/resize.cc b/fastdeploy/vision/common/processors/resize.cc index 0de6ddfc7..806eab643 100644 --- a/fastdeploy/vision/common/processors/resize.cc +++ b/fastdeploy/vision/common/processors/resize.cc @@ -14,12 +14,6 @@ #include "fastdeploy/vision/common/processors/resize.h" -#ifdef ENABLE_CVCUDA -#include <cvcuda/OpResize.hpp> - -#include
"fastdeploy/vision/common/processors/cvcuda_utils.h" -#endif - namespace fastdeploy { namespace vision { @@ -152,9 +146,8 @@ bool Resize::ImplByCvCuda(FDMat* mat) { auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache)); // CV-CUDA Interp value is compatible with OpenCV - cvcuda::Resize resize_op; - resize_op(mat->Stream(), src_tensor, dst_tensor, - NVCVInterpolationType(interp_)); + cvcuda_resize_op_(mat->Stream(), src_tensor, dst_tensor, + NVCVInterpolationType(interp_)); mat->SetTensor(mat->output_cache); mat->SetWidth(width_); diff --git a/fastdeploy/vision/common/processors/resize.h b/fastdeploy/vision/common/processors/resize.h index 2b4f88a35..607287d80 100644 --- a/fastdeploy/vision/common/processors/resize.h +++ b/fastdeploy/vision/common/processors/resize.h @@ -15,6 +15,11 @@ #pragma once #include "fastdeploy/vision/common/processors/base.h" +#ifdef ENABLE_CVCUDA +#include + +#include "fastdeploy/vision/common/processors/cvcuda_utils.h" +#endif namespace fastdeploy { namespace vision { @@ -61,6 +66,9 @@ class FASTDEPLOY_DECL Resize : public Processor { float scale_h_ = -1.0; int interp_ = 1; bool use_scale_ = false; +#ifdef ENABLE_CVCUDA + cvcuda::Resize cvcuda_resize_op_; +#endif }; } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/resize_by_short.cc b/fastdeploy/vision/common/processors/resize_by_short.cc index 535652fc7..7fe644e0d 100644 --- a/fastdeploy/vision/common/processors/resize_by_short.cc +++ b/fastdeploy/vision/common/processors/resize_by_short.cc @@ -14,12 +14,6 @@ #include "fastdeploy/vision/common/processors/resize_by_short.h" -#ifdef ENABLE_CVCUDA -#include - -#include "fastdeploy/vision/common/processors/cvcuda_utils.h" -#endif - namespace fastdeploy { namespace vision { @@ -102,9 +96,8 @@ bool ResizeByShort::ImplByCvCuda(FDMat* mat) { auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache)); // CV-CUDA Interp value is compatible with OpenCV - cvcuda::Resize resize_op; - resize_op(mat->Stream(), src_tensor, dst_tensor, - NVCVInterpolationType(interp_)); + cvcuda_resize_op_(mat->Stream(), src_tensor, dst_tensor, + NVCVInterpolationType(interp_)); mat->SetTensor(mat->output_cache); mat->SetWidth(width); @@ -144,9 +137,8 @@ bool ResizeByShort::ImplByCvCuda(FDMatBatch* mat_batch) { CreateCvCudaImageBatchVarShape(dst_tensors, dst_batch); // CV-CUDA Interp value is compatible with OpenCV - cvcuda::Resize resize_op; - resize_op(mat_batch->Stream(), src_batch, dst_batch, - NVCVInterpolationType(interp_)); + cvcuda_resize_op_(mat_batch->Stream(), src_batch, dst_batch, + NVCVInterpolationType(interp_)); for (size_t i = 0; i < mat_batch->mats->size(); ++i) { FDMat* mat = &(*(mat_batch->mats))[i]; diff --git a/fastdeploy/vision/common/processors/resize_by_short.h b/fastdeploy/vision/common/processors/resize_by_short.h index 99078c708..08bec6438 100644 --- a/fastdeploy/vision/common/processors/resize_by_short.h +++ b/fastdeploy/vision/common/processors/resize_by_short.h @@ -15,6 +15,11 @@ #pragma once #include "fastdeploy/vision/common/processors/base.h" +#ifdef ENABLE_CVCUDA +#include + +#include "fastdeploy/vision/common/processors/cvcuda_utils.h" +#endif namespace fastdeploy { namespace vision { @@ -49,6 +54,9 @@ class FASTDEPLOY_DECL ResizeByShort : public Processor { std::vector max_hw_; int interp_; bool use_scale_; +#ifdef ENABLE_CVCUDA + cvcuda::Resize cvcuda_resize_op_; +#endif }; } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/dbdetector.cc 
b/fastdeploy/vision/ocr/ppocr/dbdetector.cc old mode 100755 new mode 100644 index cd07cc262..7dd0ac84a --- a/fastdeploy/vision/ocr/ppocr/dbdetector.cc +++ b/fastdeploy/vision/ocr/ppocr/dbdetector.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "fastdeploy/vision/ocr/ppocr/dbdetector.h" + #include "fastdeploy/utils/perf.h" #include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h" @@ -26,11 +27,11 @@ DBDetector::DBDetector(const std::string& model_file, const RuntimeOption& custom_option, const ModelFormat& model_format) { if (model_format == ModelFormat::ONNX) { - valid_cpu_backends = {Backend::ORT, - Backend::OPENVINO}; - valid_gpu_backends = {Backend::ORT, Backend::TRT}; + valid_cpu_backends = {Backend::ORT, Backend::OPENVINO}; + valid_gpu_backends = {Backend::ORT, Backend::TRT}; } else { - valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO, Backend::LITE}; + valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO, + Backend::LITE}; valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT}; valid_kunlunxin_backends = {Backend::LITE}; valid_ascend_backends = {Backend::LITE}; @@ -54,7 +55,8 @@ bool DBDetector::Initialize() { } std::unique_ptr DBDetector::Clone() const { - std::unique_ptr clone_model = utils::make_unique(DBDetector(*this)); + std::unique_ptr clone_model = + utils::make_unique(DBDetector(*this)); clone_model->SetRuntime(clone_model->CloneRuntime()); return clone_model; } @@ -69,14 +71,15 @@ bool DBDetector::Predict(const cv::Mat& img, return true; } -bool DBDetector::BatchPredict(const std::vector& images, - std::vector>>* det_results) { +bool DBDetector::BatchPredict( + const std::vector& images, + std::vector>>* det_results) { std::vector fd_images = WrapMat(images); - std::vector> batch_det_img_info; - if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &batch_det_img_info)) { + if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) { FDERROR << "Failed to preprocess input image." << std::endl; return false; } + auto batch_det_img_info = preprocessor_.GetBatchImgInfo(); reused_input_tensors_[0].name = InputInfoOfRuntime(0).name; if (!Infer(reused_input_tensors_, &reused_output_tensors_)) { @@ -84,13 +87,15 @@ bool DBDetector::BatchPredict(const std::vector& images, return false; } - if (!postprocessor_.Run(reused_output_tensors_, det_results, batch_det_img_info)) { - FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl; + if (!postprocessor_.Run(reused_output_tensors_, det_results, + *batch_det_img_info)) { + FDERROR << "Failed to postprocess the inference cls_results by runtime." + << std::endl; return false; } return true; } -} // namesapce ocr +} // namespace ocr } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc b/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc index 28b7e47af..69687d5cd 100644 --- a/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc +++ b/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc @@ -13,9 +13,8 @@ // limitations under the License. 
#include "fastdeploy/vision/ocr/ppocr/det_preprocessor.h" -#include "fastdeploy/utils/perf.h" + #include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h" -#include "fastdeploy/function/concat.h" namespace fastdeploy { namespace vision { @@ -39,64 +38,61 @@ std::array OcrDetectorGetInfo(FDMat* img, int max_size_len) { resize_h = std::max(int(std::round(float(resize_h) / 32) * 32), 32); resize_w = std::max(int(std::round(float(resize_w) / 32) * 32), 32); - return {w,h,resize_w,resize_h}; + return {w, h, resize_w, resize_h}; /* - *ratio_h = float(resize_h) / float(h); - *ratio_w = float(resize_w) / float(w); - */ + *ratio_h = float(resize_h) / float(h); + *ratio_w = float(resize_w) / float(w); + */ } -bool OcrDetectorResizeImage(FDMat* img, - int resize_w, - int resize_h, - int max_resize_w, - int max_resize_h) { - Resize::Run(img, resize_w, resize_h); + +DBDetectorPreprocessor::DBDetectorPreprocessor() { + resize_op_ = std::make_shared(-1, -1); + std::vector value = {0, 0, 0}; - Pad::Run(img, 0, max_resize_h-resize_h, 0, max_resize_w - resize_w, value); + pad_op_ = std::make_shared(0, 0, 0, 0, value); + + std::vector mean = {0.485f, 0.456f, 0.406f}; + std::vector std = {0.229f, 0.224f, 0.225f}; + bool is_scale = true; + normalize_permute_op_ = + std::make_shared(mean, std, is_scale); +} + +bool DBDetectorPreprocessor::ResizeImage(FDMat* img, int resize_w, int resize_h, + int max_resize_w, int max_resize_h) { + resize_op_->SetWidthAndHeight(resize_w, resize_h); + (*resize_op_)(img); + + pad_op_->SetPaddingSize(0, max_resize_h - resize_h, 0, + max_resize_w - resize_w); + (*pad_op_)(img); return true; } -bool DBDetectorPreprocessor::Run(std::vector* images, - std::vector* outputs, - std::vector>* batch_det_img_info_ptr) { - if (images->size() == 0) { - FDERROR << "The size of input images should be greater than 0." << std::endl; - return false; - } +bool DBDetectorPreprocessor::Apply(FDMatBatch* image_batch, + std::vector* outputs) { int max_resize_w = 0; int max_resize_h = 0; - std::vector>& batch_det_img_info = *batch_det_img_info_ptr; - batch_det_img_info.clear(); - batch_det_img_info.resize(images->size()); - for (size_t i = 0; i < images->size(); ++i) { - FDMat* mat = &(images->at(i)); - batch_det_img_info[i] = OcrDetectorGetInfo(mat,max_side_len_); - max_resize_w = std::max(max_resize_w,batch_det_img_info[i][2]); - max_resize_h = std::max(max_resize_h,batch_det_img_info[i][3]); + batch_det_img_info_.clear(); + batch_det_img_info_.resize(image_batch->mats->size()); + for (size_t i = 0; i < image_batch->mats->size(); ++i) { + FDMat* mat = &(image_batch->mats->at(i)); + batch_det_img_info_[i] = OcrDetectorGetInfo(mat, max_side_len_); + max_resize_w = std::max(max_resize_w, batch_det_img_info_[i][2]); + max_resize_h = std::max(max_resize_h, batch_det_img_info_[i][3]); } - for (size_t i = 0; i < images->size(); ++i) { - FDMat* mat = &(images->at(i)); - OcrDetectorResizeImage(mat, batch_det_img_info[i][2],batch_det_img_info[i][3],max_resize_w,max_resize_h); - NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_); - /* - Normalize::Run(mat, mean_, scale_, is_scale_); - HWC2CHW::Run(mat); - Cast::Run(mat, "float"); - */ + for (size_t i = 0; i < image_batch->mats->size(); ++i) { + FDMat* mat = &(image_batch->mats->at(i)); + ResizeImage(mat, batch_det_img_info_[i][2], batch_det_img_info_[i][3], + max_resize_w, max_resize_h); } - // Only have 1 output Tensor. 
+ (*normalize_permute_op_)(image_batch); + outputs->resize(1); - // Concat all the preprocessed data to a batch tensor - std::vector tensors(images->size()); - for (size_t i = 0; i < images->size(); ++i) { - (*images)[i].ShareWithTensor(&(tensors[i])); - tensors[i].ExpandDim(0); - } - if (tensors.size() == 1) { - (*outputs)[0] = std::move(tensors[0]); - } else { - function::Concat(tensors, &((*outputs)[0]), 0); - } + FDTensor* tensor = image_batch->Tensor(); + (*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(), + tensor->Data(), tensor->device, + tensor->device_id); return true; } diff --git a/fastdeploy/vision/ocr/ppocr/det_preprocessor.h b/fastdeploy/vision/ocr/ppocr/det_preprocessor.h index 552d0628a..fd7b77de1 100644 --- a/fastdeploy/vision/ocr/ppocr/det_preprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/det_preprocessor.h @@ -13,7 +13,10 @@ // limitations under the License. #pragma once -#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/processors/manager.h" +#include "fastdeploy/vision/common/processors/resize.h" +#include "fastdeploy/vision/common/processors/pad.h" +#include "fastdeploy/vision/common/processors/normalize_and_permute.h" #include "fastdeploy/vision/common/result.h" namespace fastdeploy { @@ -22,43 +25,48 @@ namespace vision { namespace ocr { /*! @brief Preprocessor object for DBDetector serials model. */ -class FASTDEPLOY_DECL DBDetectorPreprocessor { +class FASTDEPLOY_DECL DBDetectorPreprocessor : public ProcessorManager { public: + DBDetectorPreprocessor(); + /** \brief Process the input image and prepare input tensors for runtime * - * \param[in] images The input data list, all the elements are FDMat + * \param[in] image_batch The input image batch * \param[in] outputs The output tensors which will feed in runtime - * \param[in] batch_det_img_info_ptr The output of preprocess * \return true if the preprocess successed, otherwise false */ - bool Run(std::vector* images, std::vector* outputs, - std::vector>* batch_det_img_info_ptr); + virtual bool Apply(FDMatBatch* image_batch, std::vector* outputs); /// Set max_side_len for the detection preprocess, default is 960 void SetMaxSideLen(int max_side_len) { max_side_len_ = max_side_len; } + /// Get max_side_len of the detection preprocess int GetMaxSideLen() const { return max_side_len_; } - /// Set mean value for the image normalization in detection preprocess - void SetMean(const std::vector& mean) { mean_ = mean; } - /// Get mean value of the image normalization in detection preprocess - std::vector GetMean() const { return mean_; } + /// Set preprocess normalize parameters, please call this API to customize + /// the normalize parameters, otherwise it will use the default normalize + /// parameters. 
+ void SetNormalize(const std::vector& mean = {0.485f, 0.456f, 0.406f}, + const std::vector& std = {0.229f, 0.224f, 0.225f}, + bool is_scale = true) { + normalize_permute_op_ = + std::make_shared(mean, std, is_scale); + } - /// Set scale value for the image normalization in detection preprocess - void SetScale(const std::vector& scale) { scale_ = scale; } - /// Get scale value of the image normalization in detection preprocess - std::vector GetScale() const { return scale_; } - - /// Set is_scale for the image normalization in detection preprocess - void SetIsScale(bool is_scale) { is_scale_ = is_scale; } - /// Get is_scale of the image normalization in detection preprocess - bool GetIsScale() const { return is_scale_; } + /// Get the image info of the last batch, return a list of array + /// {image width, image height, resize width, resize height} + const std::vector>* GetBatchImgInfo() { + return &batch_det_img_info_; + } private: + bool ResizeImage(FDMat* img, int resize_w, int resize_h, int max_resize_w, + int max_resize_h); int max_side_len_ = 960; - std::vector mean_ = {0.485f, 0.456f, 0.406f}; - std::vector scale_ = {0.229f, 0.224f, 0.225f}; - bool is_scale_ = true; + std::vector> batch_det_img_info_; + std::shared_ptr resize_op_; + std::shared_ptr pad_op_; + std::shared_ptr normalize_permute_op_; }; } // namespace ocr diff --git a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc old mode 100755 new mode 100644 index 2bcb697a8..aa77542af --- a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc +++ b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc @@ -12,80 +12,106 @@ // See the License for the specific language governing permissions and // limitations under the License. #include + #include "fastdeploy/pybind/main.h" namespace fastdeploy { void BindPPOCRModel(pybind11::module& m) { m.def("sort_boxes", [](std::vector>& boxes) { - vision::ocr::SortBoxes(&boxes); - return boxes; + vision::ocr::SortBoxes(&boxes); + return boxes; }); - + // DBDetector - pybind11::class_(m, "DBDetectorPreprocessor") + pybind11::class_( + m, "DBDetectorPreprocessor") .def(pybind11::init<>()) - .def_property("max_side_len", &vision::ocr::DBDetectorPreprocessor::GetMaxSideLen, &vision::ocr::DBDetectorPreprocessor::SetMaxSideLen) - .def_property("mean", &vision::ocr::DBDetectorPreprocessor::GetMean, &vision::ocr::DBDetectorPreprocessor::SetMean) - .def_property("scale", &vision::ocr::DBDetectorPreprocessor::GetScale, &vision::ocr::DBDetectorPreprocessor::SetScale) - .def_property("is_scale", &vision::ocr::DBDetectorPreprocessor::GetIsScale, &vision::ocr::DBDetectorPreprocessor::SetIsScale) - .def("run", [](vision::ocr::DBDetectorPreprocessor& self, std::vector& im_list) { + .def_property("max_side_len", + &vision::ocr::DBDetectorPreprocessor::GetMaxSideLen, + &vision::ocr::DBDetectorPreprocessor::SetMaxSideLen) + .def("set_normalize", + [](vision::ocr::DBDetectorPreprocessor& self, + const std::vector& mean, const std::vector& std, + bool is_scale) { self.SetNormalize(mean, std, is_scale); }) + .def("run", [](vision::ocr::DBDetectorPreprocessor& self, + std::vector& im_list) { std::vector images; for (size_t i = 0; i < im_list.size(); ++i) { images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); } std::vector outputs; - std::vector> batch_det_img_info; - self.Run(&images, &outputs, &batch_det_img_info); - for(size_t i = 0; i< outputs.size(); ++i){ + self.Run(&images, &outputs); + auto batch_det_img_info = self.GetBatchImgInfo(); + for (size_t i = 0; i < 
outputs.size(); ++i) { outputs[i].StopSharing(); } - return std::make_pair(outputs, batch_det_img_info); + return std::make_pair(outputs, *batch_det_img_info); }); - pybind11::class_(m, "DBDetectorPostprocessor") + pybind11::class_( + m, "DBDetectorPostprocessor") .def(pybind11::init<>()) - .def_property("det_db_thresh", &vision::ocr::DBDetectorPostprocessor::GetDetDBThresh, &vision::ocr::DBDetectorPostprocessor::SetDetDBThresh) - .def_property("det_db_box_thresh", &vision::ocr::DBDetectorPostprocessor::GetDetDBBoxThresh, &vision::ocr::DBDetectorPostprocessor::SetDetDBBoxThresh) - .def_property("det_db_unclip_ratio", &vision::ocr::DBDetectorPostprocessor::GetDetDBUnclipRatio, &vision::ocr::DBDetectorPostprocessor::SetDetDBUnclipRatio) - .def_property("det_db_score_mode", &vision::ocr::DBDetectorPostprocessor::GetDetDBScoreMode, &vision::ocr::DBDetectorPostprocessor::SetDetDBScoreMode) - .def_property("use_dilation", &vision::ocr::DBDetectorPostprocessor::GetUseDilation, &vision::ocr::DBDetectorPostprocessor::SetUseDilation) + .def_property("det_db_thresh", + &vision::ocr::DBDetectorPostprocessor::GetDetDBThresh, + &vision::ocr::DBDetectorPostprocessor::SetDetDBThresh) + .def_property("det_db_box_thresh", + &vision::ocr::DBDetectorPostprocessor::GetDetDBBoxThresh, + &vision::ocr::DBDetectorPostprocessor::SetDetDBBoxThresh) + .def_property("det_db_unclip_ratio", + &vision::ocr::DBDetectorPostprocessor::GetDetDBUnclipRatio, + &vision::ocr::DBDetectorPostprocessor::SetDetDBUnclipRatio) + .def_property("det_db_score_mode", + &vision::ocr::DBDetectorPostprocessor::GetDetDBScoreMode, + &vision::ocr::DBDetectorPostprocessor::SetDetDBScoreMode) + .def_property("use_dilation", + &vision::ocr::DBDetectorPostprocessor::GetUseDilation, + &vision::ocr::DBDetectorPostprocessor::SetUseDilation) - .def("run", [](vision::ocr::DBDetectorPostprocessor& self, - std::vector& inputs, - const std::vector>& batch_det_img_info) { - std::vector>> results; + .def("run", + [](vision::ocr::DBDetectorPostprocessor& self, + std::vector& inputs, + const std::vector>& batch_det_img_info) { + std::vector>> results; - if (!self.Run(inputs, &results, batch_det_img_info)) { - throw std::runtime_error("Failed to preprocess the input data in DBDetectorPostprocessor."); - } - return results; - }) - .def("run", [](vision::ocr::DBDetectorPostprocessor& self, - std::vector& input_array, - const std::vector>& batch_det_img_info) { - std::vector>> results; - std::vector inputs; - PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); - if (!self.Run(inputs, &results, batch_det_img_info)) { - throw std::runtime_error("Failed to preprocess the input data in DBDetectorPostprocessor."); - } - return results; - }); + if (!self.Run(inputs, &results, batch_det_img_info)) { + throw std::runtime_error( + "Failed to preprocess the input data in " + "DBDetectorPostprocessor."); + } + return results; + }) + .def("run", + [](vision::ocr::DBDetectorPostprocessor& self, + std::vector& input_array, + const std::vector>& batch_det_img_info) { + std::vector>> results; + std::vector inputs; + PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); + if (!self.Run(inputs, &results, batch_det_img_info)) { + throw std::runtime_error( + "Failed to preprocess the input data in " + "DBDetectorPostprocessor."); + } + return results; + }); pybind11::class_(m, "DBDetector") .def(pybind11::init()) .def(pybind11::init<>()) - .def_property_readonly("preprocessor", &vision::ocr::DBDetector::GetPreprocessor) - 
.def_property_readonly("postprocessor", &vision::ocr::DBDetector::GetPostprocessor) - .def("predict", [](vision::ocr::DBDetector& self, - pybind11::array& data) { - auto mat = PyArrayToCvMat(data); - std::vector> boxes_result; - self.Predict(mat, &boxes_result); - return boxes_result; - }) - .def("batch_predict", [](vision::ocr::DBDetector& self, std::vector& data) { + .def_property_readonly("preprocessor", + &vision::ocr::DBDetector::GetPreprocessor) + .def_property_readonly("postprocessor", + &vision::ocr::DBDetector::GetPostprocessor) + .def("predict", + [](vision::ocr::DBDetector& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + std::vector> boxes_result; + self.Predict(mat, &boxes_result); + return boxes_result; + }) + .def("batch_predict", [](vision::ocr::DBDetector& self, + std::vector& data) { std::vector images; std::vector>> det_results; for (size_t i = 0; i < data.size(); ++i) { @@ -96,39 +122,54 @@ void BindPPOCRModel(pybind11::module& m) { }); // Classifier - pybind11::class_(m, "ClassifierPreprocessor") + pybind11::class_( + m, "ClassifierPreprocessor") .def(pybind11::init<>()) - .def_property("cls_image_shape", &vision::ocr::ClassifierPreprocessor::GetClsImageShape, &vision::ocr::ClassifierPreprocessor::SetClsImageShape) - .def_property("mean", &vision::ocr::ClassifierPreprocessor::GetMean, &vision::ocr::ClassifierPreprocessor::SetMean) - .def_property("scale", &vision::ocr::ClassifierPreprocessor::GetScale, &vision::ocr::ClassifierPreprocessor::SetScale) - .def_property("is_scale", &vision::ocr::ClassifierPreprocessor::GetIsScale, &vision::ocr::ClassifierPreprocessor::SetIsScale) - .def("run", [](vision::ocr::ClassifierPreprocessor& self, std::vector& im_list) { + .def_property("cls_image_shape", + &vision::ocr::ClassifierPreprocessor::GetClsImageShape, + &vision::ocr::ClassifierPreprocessor::SetClsImageShape) + .def_property("mean", &vision::ocr::ClassifierPreprocessor::GetMean, + &vision::ocr::ClassifierPreprocessor::SetMean) + .def_property("scale", &vision::ocr::ClassifierPreprocessor::GetScale, + &vision::ocr::ClassifierPreprocessor::SetScale) + .def_property("is_scale", + &vision::ocr::ClassifierPreprocessor::GetIsScale, + &vision::ocr::ClassifierPreprocessor::SetIsScale) + .def("run", [](vision::ocr::ClassifierPreprocessor& self, + std::vector& im_list) { std::vector images; for (size_t i = 0; i < im_list.size(); ++i) { images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); } std::vector outputs; if (!self.Run(&images, &outputs)) { - throw std::runtime_error("Failed to preprocess the input data in ClassifierPreprocessor."); + throw std::runtime_error( + "Failed to preprocess the input data in ClassifierPreprocessor."); } - for(size_t i = 0; i< outputs.size(); ++i){ + for (size_t i = 0; i < outputs.size(); ++i) { outputs[i].StopSharing(); } return outputs; }); - pybind11::class_(m, "ClassifierPostprocessor") + pybind11::class_( + m, "ClassifierPostprocessor") .def(pybind11::init<>()) - .def_property("cls_thresh", &vision::ocr::ClassifierPostprocessor::GetClsThresh, &vision::ocr::ClassifierPostprocessor::SetClsThresh) - .def("run", [](vision::ocr::ClassifierPostprocessor& self, - std::vector& inputs) { - std::vector cls_labels; - std::vector cls_scores; - if (!self.Run(inputs, &cls_labels, &cls_scores)) { - throw std::runtime_error("Failed to preprocess the input data in ClassifierPostprocessor."); - } - return std::make_pair(cls_labels,cls_scores); - }) + .def_property("cls_thresh", + &vision::ocr::ClassifierPostprocessor::GetClsThresh, + 
&vision::ocr::ClassifierPostprocessor::SetClsThresh) + .def("run", + [](vision::ocr::ClassifierPostprocessor& self, + std::vector& inputs) { + std::vector cls_labels; + std::vector cls_scores; + if (!self.Run(inputs, &cls_labels, &cls_scores)) { + throw std::runtime_error( + "Failed to preprocess the input data in " + "ClassifierPostprocessor."); + } + return std::make_pair(cls_labels, cls_scores); + }) .def("run", [](vision::ocr::ClassifierPostprocessor& self, std::vector& input_array) { std::vector inputs; @@ -136,26 +177,31 @@ void BindPPOCRModel(pybind11::module& m) { std::vector cls_labels; std::vector cls_scores; if (!self.Run(inputs, &cls_labels, &cls_scores)) { - throw std::runtime_error("Failed to preprocess the input data in ClassifierPostprocessor."); + throw std::runtime_error( + "Failed to preprocess the input data in " + "ClassifierPostprocessor."); } - return std::make_pair(cls_labels,cls_scores); + return std::make_pair(cls_labels, cls_scores); }); - + pybind11::class_(m, "Classifier") .def(pybind11::init()) .def(pybind11::init<>()) - .def_property_readonly("preprocessor", &vision::ocr::Classifier::GetPreprocessor) - .def_property_readonly("postprocessor", &vision::ocr::Classifier::GetPostprocessor) - .def("predict", [](vision::ocr::Classifier& self, - pybind11::array& data) { - auto mat = PyArrayToCvMat(data); - int32_t cls_label; - float cls_score; - self.Predict(mat, &cls_label, &cls_score); - return std::make_pair(cls_label, cls_score); - }) - .def("batch_predict", [](vision::ocr::Classifier& self, std::vector& data) { + .def_property_readonly("preprocessor", + &vision::ocr::Classifier::GetPreprocessor) + .def_property_readonly("postprocessor", + &vision::ocr::Classifier::GetPostprocessor) + .def("predict", + [](vision::ocr::Classifier& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + int32_t cls_label; + float cls_score; + self.Predict(mat, &cls_label, &cls_score); + return std::make_pair(cls_label, cls_score); + }) + .def("batch_predict", [](vision::ocr::Classifier& self, + std::vector& data) { std::vector images; std::vector cls_labels; std::vector cls_scores; @@ -167,39 +213,54 @@ void BindPPOCRModel(pybind11::module& m) { }); // Recognizer - pybind11::class_(m, "RecognizerPreprocessor") - .def(pybind11::init<>()) - .def_property("static_shape_infer", &vision::ocr::RecognizerPreprocessor::GetStaticShapeInfer, &vision::ocr::RecognizerPreprocessor::SetStaticShapeInfer) - .def_property("rec_image_shape", &vision::ocr::RecognizerPreprocessor::GetRecImageShape, &vision::ocr::RecognizerPreprocessor::SetRecImageShape) - .def_property("mean", &vision::ocr::RecognizerPreprocessor::GetMean, &vision::ocr::RecognizerPreprocessor::SetMean) - .def_property("scale", &vision::ocr::RecognizerPreprocessor::GetScale, &vision::ocr::RecognizerPreprocessor::SetScale) - .def_property("is_scale", &vision::ocr::RecognizerPreprocessor::GetIsScale, &vision::ocr::RecognizerPreprocessor::SetIsScale) - .def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector& im_list) { - std::vector images; - for (size_t i = 0; i < im_list.size(); ++i) { - images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); - } - std::vector outputs; - if (!self.Run(&images, &outputs)) { - throw std::runtime_error("Failed to preprocess the input data in RecognizerPreprocessor."); - } - for(size_t i = 0; i< outputs.size(); ++i){ - outputs[i].StopSharing(); - } - return outputs; - }); - - pybind11::class_(m, "RecognizerPostprocessor") - .def(pybind11::init()) - .def("run", 
[](vision::ocr::RecognizerPostprocessor& self, - std::vector& inputs) { - std::vector texts; - std::vector rec_scores; - if (!self.Run(inputs, &texts, &rec_scores)) { - throw std::runtime_error("Failed to preprocess the input data in RecognizerPostprocessor."); + pybind11::class_( + m, "RecognizerPreprocessor") + .def(pybind11::init<>()) + .def_property("static_shape_infer", + &vision::ocr::RecognizerPreprocessor::GetStaticShapeInfer, + &vision::ocr::RecognizerPreprocessor::SetStaticShapeInfer) + .def_property("rec_image_shape", + &vision::ocr::RecognizerPreprocessor::GetRecImageShape, + &vision::ocr::RecognizerPreprocessor::SetRecImageShape) + .def_property("mean", &vision::ocr::RecognizerPreprocessor::GetMean, + &vision::ocr::RecognizerPreprocessor::SetMean) + .def_property("scale", &vision::ocr::RecognizerPreprocessor::GetScale, + &vision::ocr::RecognizerPreprocessor::SetScale) + .def_property("is_scale", + &vision::ocr::RecognizerPreprocessor::GetIsScale, + &vision::ocr::RecognizerPreprocessor::SetIsScale) + .def("run", [](vision::ocr::RecognizerPreprocessor& self, + std::vector& im_list) { + std::vector images; + for (size_t i = 0; i < im_list.size(); ++i) { + images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); } - return std::make_pair(texts, rec_scores); - }) + std::vector outputs; + if (!self.Run(&images, &outputs)) { + throw std::runtime_error( + "Failed to preprocess the input data in RecognizerPreprocessor."); + } + for (size_t i = 0; i < outputs.size(); ++i) { + outputs[i].StopSharing(); + } + return outputs; + }); + + pybind11::class_( + m, "RecognizerPostprocessor") + .def(pybind11::init()) + .def("run", + [](vision::ocr::RecognizerPostprocessor& self, + std::vector& inputs) { + std::vector texts; + std::vector rec_scores; + if (!self.Run(inputs, &texts, &rec_scores)) { + throw std::runtime_error( + "Failed to preprocess the input data in " + "RecognizerPostprocessor."); + } + return std::make_pair(texts, rec_scores); + }) .def("run", [](vision::ocr::RecognizerPostprocessor& self, std::vector& input_array) { std::vector inputs; @@ -207,7 +268,9 @@ void BindPPOCRModel(pybind11::module& m) { std::vector texts; std::vector rec_scores; if (!self.Run(inputs, &texts, &rec_scores)) { - throw std::runtime_error("Failed to preprocess the input data in RecognizerPostprocessor."); + throw std::runtime_error( + "Failed to preprocess the input data in " + "RecognizerPostprocessor."); } return std::make_pair(texts, rec_scores); }); @@ -216,17 +279,20 @@ void BindPPOCRModel(pybind11::module& m) { .def(pybind11::init()) .def(pybind11::init<>()) - .def_property_readonly("preprocessor", &vision::ocr::Recognizer::GetPreprocessor) - .def_property_readonly("postprocessor", &vision::ocr::Recognizer::GetPostprocessor) - .def("predict", [](vision::ocr::Recognizer& self, - pybind11::array& data) { - auto mat = PyArrayToCvMat(data); - std::string text; - float rec_score; - self.Predict(mat, &text, &rec_score); - return std::make_pair(text, rec_score); - }) - .def("batch_predict", [](vision::ocr::Recognizer& self, std::vector& data) { + .def_property_readonly("preprocessor", + &vision::ocr::Recognizer::GetPreprocessor) + .def_property_readonly("postprocessor", + &vision::ocr::Recognizer::GetPostprocessor) + .def("predict", + [](vision::ocr::Recognizer& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + std::string text; + float rec_score; + self.Predict(mat, &text, &rec_score); + return std::make_pair(text, rec_score); + }) + .def("batch_predict", [](vision::ocr::Recognizer& 
self, + std::vector& data) { std::vector images; std::vector texts; std::vector rec_scores; diff --git a/python/fastdeploy/vision/classification/ppcls/__init__.py b/python/fastdeploy/vision/classification/ppcls/__init__.py index 7215bcfbc..e873a5256 100644 --- a/python/fastdeploy/vision/classification/ppcls/__init__.py +++ b/python/fastdeploy/vision/classification/ppcls/__init__.py @@ -46,7 +46,6 @@ class PaddleClasPreprocessor(ProcessorManager): When the initial operator is Resize, and input image size is large, maybe it's better to run resize on CPU, because the HostToDevice memcpy is time consuming. Set this True to run the initial resize on CPU. - :param: v: True or False """ self._manager.initial_resize_on_cpu(v) diff --git a/python/fastdeploy/vision/ocr/ppocr/__init__.py b/python/fastdeploy/vision/ocr/ppocr/__init__.py index 842532301..e19fb686e 100755 --- a/python/fastdeploy/vision/ocr/ppocr/__init__.py +++ b/python/fastdeploy/vision/ocr/ppocr/__init__.py @@ -37,43 +37,31 @@ class DBDetectorPreprocessor: @property def max_side_len(self): + """Get max_side_len value. + """ return self._preprocessor.max_side_len @max_side_len.setter def max_side_len(self, value): + """Set max_side_len value. + :param: value: (int) max_side_len value + """ assert isinstance( value, int), "The value to set `max_side_len` must be type of int." self._preprocessor.max_side_len = value - @property - def is_scale(self): - return self._preprocessor.is_scale - - @is_scale.setter - def is_scale(self, value): - assert isinstance( - value, bool), "The value to set `is_scale` must be type of bool." - self._preprocessor.is_scale = value - - @property - def scale(self): - return self._preprocessor.scale - - @scale.setter - def scale(self, value): - assert isinstance( - value, list), "The value to set `scale` must be type of list." - self._preprocessor.scale = value - - @property - def mean(self): - return self._preprocessor.mean - - @mean.setter - def mean(self, value): - assert isinstance( - value, list), "The value to set `mean` must be type of list." - self._preprocessor.mean = value + def set_normalize(self, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + is_scale=True): + """Set preprocess normalize parameters, please call this API to + customize the normalize parameters, otherwise it will use the default + normalize parameters. + :param: mean: (list of float) mean values + :param: std: (list of float) std values + :param: is_scale: (boolean) whether to scale + """ + self._preprocessor.set_normalize(mean, std, is_scale) class DBDetectorPostprocessor: @@ -174,6 +162,7 @@ class DBDetector(FastDeployModel): """Clone OCR detection model object :return: a new OCR detection model object """ + class DBDetectorClone(DBDetector): def __init__(self, model): self._model = model @@ -203,18 +192,10 @@ class DBDetector(FastDeployModel): def preprocessor(self): return self._model.preprocessor - @preprocessor.setter - def preprocessor(self, value): - self._model.preprocessor = value - @property def postprocessor(self): return self._model.postprocessor - @postprocessor.setter - def postprocessor(self, value): - self._model.postprocessor = value - # Det Preprocessor Property @property def max_side_len(self): @@ -226,36 +207,6 @@ class DBDetector(FastDeployModel): value, int), "The value to set `max_side_len` must be type of int." 
self._model.preprocessor.max_side_len = value - @property - def is_scale(self): - return self._model.preprocessor.is_scale - - @is_scale.setter - def is_scale(self, value): - assert isinstance( - value, bool), "The value to set `is_scale` must be type of bool." - self._model.preprocessor.is_scale = value - - @property - def scale(self): - return self._model.preprocessor.scale - - @scale.setter - def scale(self, value): - assert isinstance( - value, list), "The value to set `scale` must be type of list." - self._model.preprocessor.scale = value - - @property - def mean(self): - return self._model.preprocessor.mean - - @mean.setter - def mean(self, value): - assert isinstance( - value, list), "The value to set `mean` must be type of list." - self._model.preprocessor.mean = value - # Det Ppstprocessor Property @property def det_db_thresh(self): @@ -421,6 +372,7 @@ class Classifier(FastDeployModel): """Clone OCR classification model object :return: a new OCR classification model object """ + class ClassifierClone(Classifier): def __init__(self, model): self._model = model @@ -629,6 +581,7 @@ class Recognizer(FastDeployModel): """Clone OCR recognition model object :return: a new OCR recognition model object """ + class RecognizerClone(Recognizer): def __init__(self, model): self._model = model @@ -734,7 +687,7 @@ class PPOCRv3(FastDeployModel): assert det_model is not None and rec_model is not None, "The det_model and rec_model cannot be None." if cls_model is None: self.system_ = C.vision.ocr.PPOCRv3(det_model._model, - rec_model._model) + rec_model._model) else: self.system_ = C.vision.ocr.PPOCRv3( det_model._model, cls_model._model, rec_model._model) @@ -743,6 +696,7 @@ class PPOCRv3(FastDeployModel): """Clone PPOCRv3 pipeline object :return: a new PPOCRv3 pipeline object """ + class PPOCRv3Clone(PPOCRv3): def __init__(self, system): self.system_ = system @@ -809,7 +763,7 @@ class PPOCRv2(FastDeployModel): assert det_model is not None and rec_model is not None, "The det_model and rec_model cannot be None." if cls_model is None: self.system_ = C.vision.ocr.PPOCRv2(det_model._model, - rec_model._model) + rec_model._model) else: self.system_ = C.vision.ocr.PPOCRv2( det_model._model, cls_model._model, rec_model._model) @@ -818,6 +772,7 @@ class PPOCRv2(FastDeployModel): """Clone PPOCRv3 pipeline object :return: a new PPOCRv3 pipeline object """ + class PPOCRv2Clone(PPOCRv2): def __init__(self, system): self.system_ = system