Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-06 17:17:14 +08:00

[CVCUDA] PP-OCR detector preprocessor integrate CV-CUDA (#1382)

* move manager initialized_ flag to ppcls
* update dbdetector preprocess api
* declare processor op
* ppocr detector preprocessor support cvcuda
* move cvcuda op to class member
* ppcls use manager register api
* refactor det preprocessor init api
* add set preprocessor api
* add create processor macro
* new processor call api
* ppcls preprocessor init resize on cpu
* ppocr detector preprocessor set normalize api
* revert ppcls pybind
* remove dbdetector set preprocessor
* refine dbdetector preprocessor includes
* remove mean std in py constructor
* add comments
* update comment
* Update __init__.py
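For C++ callers of the OCR detector, the practical effect of this change is that DBDetectorPreprocessor now derives from ProcessorManager: the per-image resize info is no longer returned through a Run() argument but queried with GetBatchImgInfo() afterwards, and the normalization parameters can be configured in a single SetNormalize() call. A minimal before/after sketch of the call site, using only names that appear in the diff below; fd_images is assumed to be a std::vector<FDMat> already wrapped by the caller:

// #include "fastdeploy/vision/ocr/ppocr/det_preprocessor.h"

// Old call (removed by this PR):
//   std::vector<std::array<int, 4>> batch_det_img_info;
//   preprocessor.Run(&fd_images, &tensors, &batch_det_img_info);

// New call:
fastdeploy::vision::ocr::DBDetectorPreprocessor preprocessor;
preprocessor.SetNormalize({0.485f, 0.456f, 0.406f},  // mean
                          {0.229f, 0.224f, 0.225f},  // std
                          /*is_scale=*/true);
// Optional, inherited from ProcessorManager: run the pipeline with CUDA/CV-CUDA.
preprocessor.UseCuda(/*enable_cv_cuda=*/true, /*gpu_id=*/0);
std::vector<fastdeploy::FDTensor> tensors;
preprocessor.Run(&fd_images, &tensors);
const std::vector<std::array<int, 4>>* info = preprocessor.GetBatchImgInfo();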
@@ -14,7 +14,6 @@

#include "fastdeploy/vision/classification/ppcls/preprocessor.h"

#include "fastdeploy/function/concat.h"
#include "yaml-cpp/yaml.h"

namespace fastdeploy {

@@ -102,13 +101,17 @@ void PaddleClasPreprocessor::DisablePermute() {

bool PaddleClasPreprocessor::Apply(FDMatBatch* image_batch,
                                   std::vector<FDTensor>* outputs) {
  if (!initialized_) {
    FDERROR << "The preprocessor is not initialized." << std::endl;
    return false;
  }
  for (size_t j = 0; j < processors_.size(); ++j) {
    ProcLib lib = ProcLib::DEFAULT;
    image_batch->proc_lib = proc_lib_;
    if (initial_resize_on_cpu_ && j == 0 &&
        processors_[j]->Name().find("Resize") == 0) {
      lib = ProcLib::OPENCV;
      image_batch->proc_lib = ProcLib::OPENCV;
    }
    if (!(*(processors_[j].get()))(image_batch, lib)) {
    if (!(*(processors_[j].get()))(image_batch)) {
      FDERROR << "Failed to processs image in " << processors_[j]->Name() << "."
              << std::endl;
      return false;

@@ -55,6 +55,7 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {

 private:
  bool BuildPreprocessPipelineFromConfig();
  bool initialized_ = false;
  std::vector<std::shared_ptr<Processor>> processors_;
  // for recording the switch of hwc2chw
  bool disable_permute_ = false;

@@ -20,9 +20,9 @@

namespace fastdeploy {
namespace vision {

bool Processor::operator()(FDMat* mat, ProcLib lib) {
  ProcLib target = lib;
  if (lib == ProcLib::DEFAULT) {
bool Processor::operator()(FDMat* mat) {
  ProcLib target = mat->proc_lib;
  if (mat->proc_lib == ProcLib::DEFAULT) {
    target = DefaultProcLib::default_lib;
  }
  if (target == ProcLib::FLYCV) {

@@ -52,9 +52,14 @@ bool Processor::operator()(FDMat* mat, ProcLib lib) {
  return ImplByOpenCV(mat);
}

bool Processor::operator()(FDMatBatch* mat_batch, ProcLib lib) {
  ProcLib target = lib;
  if (lib == ProcLib::DEFAULT) {
bool Processor::operator()(FDMat* mat, ProcLib lib) {
  mat->proc_lib = lib;
  return operator()(mat);
}

bool Processor::operator()(FDMatBatch* mat_batch) {
  ProcLib target = mat_batch->proc_lib;
  if (mat_batch->proc_lib == ProcLib::DEFAULT) {
    target = DefaultProcLib::default_lib;
  }
  if (target == ProcLib::FLYCV) {

@@ -100,10 +100,13 @@ class FASTDEPLOY_DECL Processor {
    return true;
  }

  virtual bool operator()(FDMat* mat, ProcLib lib = ProcLib::DEFAULT);
  virtual bool operator()(FDMat* mat);

  virtual bool operator()(FDMatBatch* mat_batch,
                          ProcLib lib = ProcLib::DEFAULT);
  // This function is for backward compatibility, will be removed in the near
  // future, please use operator()(FDMat* mat) instead and set proc_lib in mat.
  virtual bool operator()(FDMat* mat, ProcLib lib);

  virtual bool operator()(FDMatBatch* mat_batch);
};

} // namespace vision
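The comment above flags operator()(FDMat*, ProcLib) as a temporary compatibility overload; the intended pattern after this change is to record the processing backend on the FDMat (or FDMatBatch) itself and call the single-argument overload. A small hedged sketch of the two call styles; RunWithBackend is a hypothetical helper, not part of the repo:

#include "fastdeploy/vision/common/processors/base.h"

// Hypothetical helper showing the new convention: the backend travels with the
// data object instead of being passed on every call.
bool RunWithBackend(fastdeploy::vision::Processor* op,
                    fastdeploy::vision::FDMat* mat,
                    fastdeploy::vision::ProcLib lib) {
  mat->proc_lib = lib;   // e.g. ProcLib::CVCUDA, ProcLib::CUDA, ProcLib::OPENCV
  return (*op)(mat);     // dispatches on mat->proc_lib; DEFAULT falls back to
                         // DefaultProcLib::default_lib, as in base.cc above
}

// Deprecated-but-kept form; per base.cc it simply sets mat->proc_lib and then
// forwards to the one-argument overload:
//   (*op)(mat, fastdeploy::vision::ProcLib::CVCUDA);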
@@ -14,12 +14,6 @@
|
||||
|
||||
#include "fastdeploy/vision/common/processors/center_crop.h"
|
||||
|
||||
#ifdef ENABLE_CVCUDA
|
||||
#include <cvcuda/OpCustomCrop.hpp>
|
||||
|
||||
#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
@@ -75,9 +69,8 @@ bool CenterCrop::ImplByCvCuda(FDMat* mat) {
|
||||
|
||||
int offset_x = static_cast<int>((mat->Width() - width_) / 2);
|
||||
int offset_y = static_cast<int>((mat->Height() - height_) / 2);
|
||||
cvcuda::CustomCrop crop_op;
|
||||
NVCVRectI crop_roi = {offset_x, offset_y, width_, height_};
|
||||
crop_op(mat->Stream(), src_tensor, dst_tensor, crop_roi);
|
||||
cvcuda_crop_op_(mat->Stream(), src_tensor, dst_tensor, crop_roi);
|
||||
|
||||
mat->SetTensor(mat->output_cache);
|
||||
mat->SetWidth(width_);
|
||||
|
@@ -15,6 +15,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "fastdeploy/vision/common/processors/base.h"
|
||||
#ifdef ENABLE_CVCUDA
|
||||
#include <cvcuda/OpCustomCrop.hpp>
|
||||
|
||||
#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
@@ -38,6 +43,9 @@ class FASTDEPLOY_DECL CenterCrop : public Processor {
|
||||
private:
|
||||
int height_;
|
||||
int width_;
|
||||
#ifdef ENABLE_CVCUDA
|
||||
cvcuda::CustomCrop cvcuda_crop_op_;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace vision
|
||||
|
@@ -31,14 +31,14 @@ void ProcessorManager::UseCuda(bool enable_cv_cuda, int gpu_id) {
  }
  FDASSERT(cudaStreamCreate(&stream_) == cudaSuccess,
           "[ERROR] Error occurs while creating cuda stream.");
  DefaultProcLib::default_lib = ProcLib::CUDA;
  proc_lib_ = ProcLib::CUDA;
#else
  FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
#endif

  if (enable_cv_cuda) {
#ifdef ENABLE_CVCUDA
    DefaultProcLib::default_lib = ProcLib::CVCUDA;
    proc_lib_ = ProcLib::CVCUDA;
#else
    FDASSERT(false, "FastDeploy didn't compile with CV-CUDA.");
#endif

@@ -46,16 +46,11 @@ void ProcessorManager::UseCuda(bool enable_cv_cuda, int gpu_id) {
}

bool ProcessorManager::CudaUsed() {
  return (DefaultProcLib::default_lib == ProcLib::CUDA ||
          DefaultProcLib::default_lib == ProcLib::CVCUDA);
  return (proc_lib_ == ProcLib::CUDA || proc_lib_ == ProcLib::CVCUDA);
}

bool ProcessorManager::Run(std::vector<FDMat>* images,
                           std::vector<FDTensor>* outputs) {
  if (!initialized_) {
    FDERROR << "The preprocessor is not initialized." << std::endl;
    return false;
  }
  if (images->size() == 0) {
    FDERROR << "The size of input images should be greater than 0."
            << std::endl;

@@ -70,6 +65,7 @@ bool ProcessorManager::Run(std::vector<FDMat>* images,
  FDMatBatch image_batch(images);
  image_batch.input_cache = &batch_input_cache_;
  image_batch.output_cache = &batch_output_cache_;
  image_batch.proc_lib = proc_lib_;

  for (size_t i = 0; i < images->size(); ++i) {
    if (CudaUsed()) {
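With proc_lib_ stored on the manager and CudaUsed() checking that member, the CUDA/CV-CUDA switch becomes a per-preprocessor setting rather than a process-wide one. A hedged usage sketch with the DBDetectorPreprocessor defined later in this PR; images_a and images_b are assumed to be std::vector<FDMat> prepared by the caller:

fastdeploy::vision::ocr::DBDetectorPreprocessor cpu_pre;   // left on its default backend
fastdeploy::vision::ocr::DBDetectorPreprocessor gpu_pre;
gpu_pre.UseCuda(/*enable_cv_cuda=*/true, /*gpu_id=*/0);    // CUDA stream + ProcLib::CVCUDA

std::vector<fastdeploy::FDTensor> cpu_out, gpu_out;
cpu_pre.Run(&images_a, &cpu_out);   // keeps running the default (CPU) library
gpu_pre.Run(&images_b, &gpu_out);   // proc_lib_ == CVCUDA is copied into the
                                    // FDMatBatch, so each op picks its CUDA path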
@@ -17,6 +17,7 @@
#include "fastdeploy/utils/utils.h"
#include "fastdeploy/vision/common/processors/mat.h"
#include "fastdeploy/vision/common/processors/mat_batch.h"
#include "fastdeploy/vision/common/processors/base.h"

namespace fastdeploy {
namespace vision {

@@ -78,7 +79,7 @@ class FASTDEPLOY_DECL ProcessorManager {
                     std::vector<FDTensor>* outputs) = 0;

 protected:
  bool initialized_ = false;
  ProcLib proc_lib_ = ProcLib::DEFAULT;

 private:
#ifdef WITH_GPU
@@ -145,6 +145,7 @@ struct FASTDEPLOY_DECL Mat {
|
||||
ProcLib mat_type = ProcLib::OPENCV;
|
||||
Layout layout = Layout::HWC;
|
||||
Device device = Device::CPU;
|
||||
ProcLib proc_lib = ProcLib::DEFAULT;
|
||||
|
||||
// Create FD Mat from FD Tensor. This method only create a
|
||||
// new FD Mat with zero copy and it's data pointer is reference
|
||||
|
@@ -67,6 +67,7 @@ FDTensor* CreateCachedGpuInputTensor(FDMatBatch* mat_batch) {
|
||||
FDTensor* tensor = CreateCachedGpuInputTensor(&(*mats)[i]);
|
||||
(*mats)[i].SetTensor(tensor);
|
||||
}
|
||||
mat_batch->device = Device::GPU;
|
||||
return mat_batch->Tensor();
|
||||
} else {
|
||||
FDASSERT(false, "FDMat is on unsupported device: %d", src->device);
|
||||
|
@@ -60,6 +60,7 @@ struct FASTDEPLOY_DECL FDMatBatch {
|
||||
ProcLib mat_type = ProcLib::OPENCV;
|
||||
FDMatBatchLayout layout = FDMatBatchLayout::NHWC;
|
||||
Device device = Device::CPU;
|
||||
ProcLib proc_lib = ProcLib::DEFAULT;
|
||||
|
||||
// False: the data is stored in the mats separately
|
||||
// True: the data is stored in the fd_tensor continuously in 4 dimensions
|
||||
|
@@ -85,6 +85,8 @@ bool NormalizeAndPermute::ImplByCuda(FDMatBatch* mat_batch) {
  // NHWC -> NCHW
  std::swap(mat_batch->output_cache->shape[1],
            mat_batch->output_cache->shape[3]);
  std::swap(mat_batch->output_cache->shape[2],
            mat_batch->output_cache->shape[3]);

  // Copy alpha and beta to GPU
  gpu_alpha_.Resize({1, 1, static_cast<int>(alpha_.size())}, FDDataType::FP32,
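For readers checking the two std::swap calls above: as far as this hunk shows, they relabel the cached output shape so that it describes the NCHW layout the CUDA kernel produces. A quick index trace with a non-square example batch:

// shape starts as {N, H, W, C}, e.g. {8, 640, 360, 3}
// swap(shape[1], shape[3]) -> {8, 3, 360, 640}   i.e. {N, C, W, H}
// swap(shape[2], shape[3]) -> {8, 3, 640, 360}   i.e. {N, C, H, W}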
@@ -91,6 +91,60 @@ bool Pad::ImplByFlyCV(Mat* mat) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_CVCUDA
|
||||
bool Pad::ImplByCvCuda(FDMat* mat) {
|
||||
if (mat->layout != Layout::HWC) {
|
||||
FDERROR << "Pad: The input data must be Layout::HWC format!" << std::endl;
|
||||
return false;
|
||||
}
|
||||
if (mat->Channels() > 4) {
|
||||
FDERROR << "Pad: Only support channels <= 4." << std::endl;
|
||||
return false;
|
||||
}
|
||||
if (mat->Channels() != value_.size()) {
|
||||
FDERROR << "Pad: Require input channels equals to size of padding value, "
|
||||
"but now channels = "
|
||||
<< mat->Channels()
|
||||
<< ", the size of padding values = " << value_.size() << "."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
float4 value;
|
||||
if (value_.size() == 1) {
|
||||
value = make_float4(value_[0], 0.0f, 0.0f, 0.0f);
|
||||
} else if (value_.size() == 2) {
|
||||
value = make_float4(value_[0], value_[1], 0.0f, 0.0f);
|
||||
} else if (value_.size() == 3) {
|
||||
value = make_float4(value_[0], value_[1], value_[2], 0.0f);
|
||||
} else {
|
||||
value = make_float4(value_[0], value_[1], value_[2], value_[3]);
|
||||
}
|
||||
|
||||
// Prepare input tensor
|
||||
FDTensor* src = CreateCachedGpuInputTensor(mat);
|
||||
auto src_tensor = CreateCvCudaTensorWrapData(*src);
|
||||
|
||||
  int height = mat->Height() + top_ + bottom_;
  int width = mat->Width() + left_ + right_;
|
||||
|
||||
// Prepare output tensor
|
||||
mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(),
|
||||
"output_cache", Device::GPU);
|
||||
auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
|
||||
|
||||
cvcuda_pad_op_(mat->Stream(), src_tensor, dst_tensor, top_, left_,
|
||||
NVCV_BORDER_CONSTANT, value);
|
||||
|
||||
mat->SetTensor(mat->output_cache);
|
||||
mat->SetWidth(width);
|
||||
mat->SetHeight(height);
|
||||
mat->device = Device::GPU;
|
||||
mat->mat_type = ProcLib::CVCUDA;
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool Pad::Run(Mat* mat, const int& top, const int& bottom, const int& left,
|
||||
const int& right, const std::vector<float>& value, ProcLib lib) {
|
||||
auto p = Pad(top, bottom, left, right, value);
|
||||
|
@@ -15,6 +15,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "fastdeploy/vision/common/processors/base.h"
|
||||
#ifdef ENABLE_CVCUDA
|
||||
#include <cvcuda/OpCopyMakeBorder.hpp>
|
||||
|
||||
#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
@@ -32,6 +37,9 @@ class FASTDEPLOY_DECL Pad : public Processor {
|
||||
bool ImplByOpenCV(Mat* mat);
|
||||
#ifdef ENABLE_FLYCV
|
||||
bool ImplByFlyCV(Mat* mat);
|
||||
#endif
|
||||
#ifdef ENABLE_CVCUDA
|
||||
bool ImplByCvCuda(FDMat* mat);
|
||||
#endif
|
||||
std::string Name() { return "Pad"; }
|
||||
|
||||
@@ -39,12 +47,23 @@ class FASTDEPLOY_DECL Pad : public Processor {
|
||||
const int& right, const std::vector<float>& value,
|
||||
ProcLib lib = ProcLib::DEFAULT);
|
||||
|
||||
bool SetPaddingSize(int top, int bottom, int left, int right) {
|
||||
top_ = top;
|
||||
bottom_ = bottom;
|
||||
left_ = left;
|
||||
right_ = right;
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
int top_;
|
||||
int bottom_;
|
||||
int left_;
|
||||
int right_;
|
||||
std::vector<float> value_;
|
||||
#ifdef ENABLE_CVCUDA
|
||||
cvcuda::CopyMakeBorder cvcuda_pad_op_;
|
||||
#endif
|
||||
};
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
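The new SetPaddingSize() makes the Pad op reusable: one instance (with its cvcuda::CopyMakeBorder member when CV-CUDA is enabled) can pad every image of a batch to a common size, which is how the OCR detector preprocessor uses it later in this PR. A hedged sketch, assuming mats is a std::vector<FDMat> that has already been resized and max_h/max_w are the batch maxima:

// One reusable Pad op; the padding value list must match the channel count (<= 4).
auto pad_op = std::make_shared<fastdeploy::vision::Pad>(
    0, 0, 0, 0, std::vector<float>{0.0f, 0.0f, 0.0f});
for (auto& mat : mats) {
  pad_op->SetPaddingSize(0, max_h - mat.Height(), 0, max_w - mat.Width());
  (*pad_op)(&mat);   // OpenCV, FlyCV or CV-CUDA depending on mat.proc_lib
}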
@@ -14,12 +14,6 @@
|
||||
|
||||
#include "fastdeploy/vision/common/processors/resize.h"
|
||||
|
||||
#ifdef ENABLE_CVCUDA
|
||||
#include <cvcuda/OpResize.hpp>
|
||||
|
||||
#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
@@ -152,8 +146,7 @@ bool Resize::ImplByCvCuda(FDMat* mat) {
|
||||
auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
|
||||
|
||||
// CV-CUDA Interp value is compatible with OpenCV
|
||||
cvcuda::Resize resize_op;
|
||||
resize_op(mat->Stream(), src_tensor, dst_tensor,
|
||||
cvcuda_resize_op_(mat->Stream(), src_tensor, dst_tensor,
|
||||
NVCVInterpolationType(interp_));
|
||||
|
||||
mat->SetTensor(mat->output_cache);
|
||||
|
@@ -15,6 +15,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "fastdeploy/vision/common/processors/base.h"
|
||||
#ifdef ENABLE_CVCUDA
|
||||
#include <cvcuda/OpResize.hpp>
|
||||
|
||||
#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
@@ -61,6 +66,9 @@ class FASTDEPLOY_DECL Resize : public Processor {
|
||||
float scale_h_ = -1.0;
|
||||
int interp_ = 1;
|
||||
bool use_scale_ = false;
|
||||
#ifdef ENABLE_CVCUDA
|
||||
cvcuda::Resize cvcuda_resize_op_;
|
||||
#endif
|
||||
};
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
||||
|
@@ -14,12 +14,6 @@
|
||||
|
||||
#include "fastdeploy/vision/common/processors/resize_by_short.h"
|
||||
|
||||
#ifdef ENABLE_CVCUDA
|
||||
#include <cvcuda/OpResize.hpp>
|
||||
|
||||
#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
@@ -102,8 +96,7 @@ bool ResizeByShort::ImplByCvCuda(FDMat* mat) {
|
||||
auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));
|
||||
|
||||
// CV-CUDA Interp value is compatible with OpenCV
|
||||
cvcuda::Resize resize_op;
|
||||
resize_op(mat->Stream(), src_tensor, dst_tensor,
|
||||
cvcuda_resize_op_(mat->Stream(), src_tensor, dst_tensor,
|
||||
NVCVInterpolationType(interp_));
|
||||
|
||||
mat->SetTensor(mat->output_cache);
|
||||
@@ -144,8 +137,7 @@ bool ResizeByShort::ImplByCvCuda(FDMatBatch* mat_batch) {
|
||||
CreateCvCudaImageBatchVarShape(dst_tensors, dst_batch);
|
||||
|
||||
// CV-CUDA Interp value is compatible with OpenCV
|
||||
cvcuda::Resize resize_op;
|
||||
resize_op(mat_batch->Stream(), src_batch, dst_batch,
|
||||
cvcuda_resize_op_(mat_batch->Stream(), src_batch, dst_batch,
|
||||
NVCVInterpolationType(interp_));
|
||||
|
||||
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
|
||||
|
@@ -15,6 +15,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "fastdeploy/vision/common/processors/base.h"
|
||||
#ifdef ENABLE_CVCUDA
|
||||
#include <cvcuda/OpResize.hpp>
|
||||
|
||||
#include "fastdeploy/vision/common/processors/cvcuda_utils.h"
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
@@ -49,6 +54,9 @@ class FASTDEPLOY_DECL ResizeByShort : public Processor {
|
||||
std::vector<int> max_hw_;
|
||||
int interp_;
|
||||
bool use_scale_;
|
||||
#ifdef ENABLE_CVCUDA
|
||||
cvcuda::Resize cvcuda_resize_op_;
|
||||
#endif
|
||||
};
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
||||
|
fastdeploy/vision/ocr/ppocr/dbdetector.cc (25 changed lines, Executable file → Normal file)
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
|
||||
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
|
||||
|
||||
@@ -26,11 +27,11 @@ DBDetector::DBDetector(const std::string& model_file,
|
||||
const RuntimeOption& custom_option,
|
||||
const ModelFormat& model_format) {
|
||||
if (model_format == ModelFormat::ONNX) {
|
||||
valid_cpu_backends = {Backend::ORT,
|
||||
Backend::OPENVINO};
|
||||
valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
|
||||
valid_gpu_backends = {Backend::ORT, Backend::TRT};
|
||||
} else {
|
||||
valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO, Backend::LITE};
|
||||
valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO,
|
||||
Backend::LITE};
|
||||
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
|
||||
valid_kunlunxin_backends = {Backend::LITE};
|
||||
valid_ascend_backends = {Backend::LITE};
|
||||
@@ -54,7 +55,8 @@ bool DBDetector::Initialize() {
|
||||
}
|
||||
|
||||
std::unique_ptr<DBDetector> DBDetector::Clone() const {
|
||||
std::unique_ptr<DBDetector> clone_model = utils::make_unique<DBDetector>(DBDetector(*this));
|
||||
std::unique_ptr<DBDetector> clone_model =
|
||||
utils::make_unique<DBDetector>(DBDetector(*this));
|
||||
clone_model->SetRuntime(clone_model->CloneRuntime());
|
||||
return clone_model;
|
||||
}
|
||||
@@ -69,14 +71,15 @@ bool DBDetector::Predict(const cv::Mat& img,
  return true;
}

bool DBDetector::BatchPredict(const std::vector<cv::Mat>& images,
bool DBDetector::BatchPredict(
    const std::vector<cv::Mat>& images,
    std::vector<std::vector<std::array<int, 8>>>* det_results) {
  std::vector<FDMat> fd_images = WrapMat(images);
  std::vector<std::array<int, 4>> batch_det_img_info;
  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &batch_det_img_info)) {
  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
    FDERROR << "Failed to preprocess input image." << std::endl;
    return false;
  }
  auto batch_det_img_info = preprocessor_.GetBatchImgInfo();

  reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
  if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {

@@ -84,13 +87,15 @@ bool DBDetector::BatchPredict(const std::vector<cv::Mat>& images,
    return false;
  }

  if (!postprocessor_.Run(reused_output_tensors_, det_results, batch_det_img_info)) {
    FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
  if (!postprocessor_.Run(reused_output_tensors_, det_results,
                          *batch_det_img_info)) {
    FDERROR << "Failed to postprocess the inference cls_results by runtime."
            << std::endl;
    return false;
  }
  return true;
}

} // namesapce ocr
} // namespace ocr
} // namespace vision
} // namespace fastdeploy
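An end-to-end sketch of the updated detector call path from the application side; the model and image file names are placeholders, and the GetPreprocessor() line assumes it returns a mutable reference, as the pybind11 binding below suggests:

fastdeploy::RuntimeOption option;
option.UseGpu(0);   // inference device; preprocessing is configured separately
fastdeploy::vision::ocr::DBDetector det("det_infer/inference.pdmodel",
                                        "det_infer/inference.pdiparams",
                                        option);
// Optional: also run the preprocessing pipeline on GPU with CV-CUDA.
// det.GetPreprocessor().UseCuda(/*enable_cv_cuda=*/true, /*gpu_id=*/0);

std::vector<cv::Mat> images = {cv::imread("img1.jpg"), cv::imread("img2.jpg")};
std::vector<std::vector<std::array<int, 8>>> det_results;  // 4 corner points per box
det.BatchPredict(images, &det_results);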
@@ -13,9 +13,8 @@
// limitations under the License.

#include "fastdeploy/vision/ocr/ppocr/det_preprocessor.h"
#include "fastdeploy/utils/perf.h"

#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
#include "fastdeploy/function/concat.h"

namespace fastdeploy {
namespace vision {

@@ -45,58 +44,55 @@ std::array<int, 4> OcrDetectorGetInfo(FDMat* img, int max_size_len) {
  *ratio_w = float(resize_w) / float(w);
  */
}

bool OcrDetectorResizeImage(FDMat* img,
                            int resize_w,
                            int resize_h,
                            int max_resize_w,
                            int max_resize_h) {
  Resize::Run(img, resize_w, resize_h);

DBDetectorPreprocessor::DBDetectorPreprocessor() {
  resize_op_ = std::make_shared<Resize>(-1, -1);

  std::vector<float> value = {0, 0, 0};
  Pad::Run(img, 0, max_resize_h-resize_h, 0, max_resize_w - resize_w, value);
  pad_op_ = std::make_shared<Pad>(0, 0, 0, 0, value);

  std::vector<float> mean = {0.485f, 0.456f, 0.406f};
  std::vector<float> std = {0.229f, 0.224f, 0.225f};
  bool is_scale = true;
  normalize_permute_op_ =
      std::make_shared<NormalizeAndPermute>(mean, std, is_scale);
}

bool DBDetectorPreprocessor::ResizeImage(FDMat* img, int resize_w, int resize_h,
                                         int max_resize_w, int max_resize_h) {
  resize_op_->SetWidthAndHeight(resize_w, resize_h);
  (*resize_op_)(img);

  pad_op_->SetPaddingSize(0, max_resize_h - resize_h, 0,
                          max_resize_w - resize_w);
  (*pad_op_)(img);
  return true;
}

bool DBDetectorPreprocessor::Run(std::vector<FDMat>* images,
                                 std::vector<FDTensor>* outputs,
                                 std::vector<std::array<int, 4>>* batch_det_img_info_ptr) {
  if (images->size() == 0) {
    FDERROR << "The size of input images should be greater than 0." << std::endl;
    return false;
  }
bool DBDetectorPreprocessor::Apply(FDMatBatch* image_batch,
                                   std::vector<FDTensor>* outputs) {
  int max_resize_w = 0;
  int max_resize_h = 0;
  std::vector<std::array<int, 4>>& batch_det_img_info = *batch_det_img_info_ptr;
  batch_det_img_info.clear();
  batch_det_img_info.resize(images->size());
  for (size_t i = 0; i < images->size(); ++i) {
    FDMat* mat = &(images->at(i));
    batch_det_img_info[i] = OcrDetectorGetInfo(mat,max_side_len_);
    max_resize_w = std::max(max_resize_w,batch_det_img_info[i][2]);
    max_resize_h = std::max(max_resize_h,batch_det_img_info[i][3]);
  batch_det_img_info_.clear();
  batch_det_img_info_.resize(image_batch->mats->size());
  for (size_t i = 0; i < image_batch->mats->size(); ++i) {
    FDMat* mat = &(image_batch->mats->at(i));
    batch_det_img_info_[i] = OcrDetectorGetInfo(mat, max_side_len_);
    max_resize_w = std::max(max_resize_w, batch_det_img_info_[i][2]);
    max_resize_h = std::max(max_resize_h, batch_det_img_info_[i][3]);
  }
  for (size_t i = 0; i < images->size(); ++i) {
    FDMat* mat = &(images->at(i));
    OcrDetectorResizeImage(mat, batch_det_img_info[i][2],batch_det_img_info[i][3],max_resize_w,max_resize_h);
    NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
    /*
    Normalize::Run(mat, mean_, scale_, is_scale_);
    HWC2CHW::Run(mat);
    Cast::Run(mat, "float");
    */
  for (size_t i = 0; i < image_batch->mats->size(); ++i) {
    FDMat* mat = &(image_batch->mats->at(i));
    ResizeImage(mat, batch_det_img_info_[i][2], batch_det_img_info_[i][3],
                max_resize_w, max_resize_h);
  }
  // Only have 1 output Tensor.
  (*normalize_permute_op_)(image_batch);

  outputs->resize(1);
  // Concat all the preprocessed data to a batch tensor
  std::vector<FDTensor> tensors(images->size());
  for (size_t i = 0; i < images->size(); ++i) {
    (*images)[i].ShareWithTensor(&(tensors[i]));
    tensors[i].ExpandDim(0);
  }
  if (tensors.size() == 1) {
    (*outputs)[0] = std::move(tensors[0]);
  } else {
    function::Concat(tensors, &((*outputs)[0]), 0);
  }
  FDTensor* tensor = image_batch->Tensor();
  (*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
                                tensor->Data(), tensor->device,
                                tensor->device_id);
  return true;
}
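Most of OcrDetectorGetInfo() falls outside this hunk; only its commented-out tail is visible above. For context, the DB resize rule it is expected to implement, reconstructed from that tail and from the standard PP-OCR detection preprocess, is roughly the following. Treat this as an assumption about the elided body, not a copy of this file:

#include <algorithm>
#include <array>
#include <cmath>

// Assumed shape of the computation behind batch_det_img_info_[i]:
// returns {original width, original height, resize width, resize height}.
std::array<int, 4> SketchGetResizeInfo(int w, int h, int max_size_len) {
  float ratio = 1.0f;
  if (std::max(w, h) > max_size_len) {
    ratio = static_cast<float>(max_size_len) / static_cast<float>(std::max(w, h));
  }
  int resize_h = static_cast<int>(h * ratio);
  int resize_w = static_cast<int>(w * ratio);
  // DB backbones downsample by 32, so both sides are rounded to multiples of 32.
  resize_h = std::max(static_cast<int>(std::round(resize_h / 32.0f)) * 32, 32);
  resize_w = std::max(static_cast<int>(std::round(resize_w / 32.0f)) * 32, 32);
  return {w, h, resize_w, resize_h};
}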
@@ -13,7 +13,10 @@
// limitations under the License.

#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/processors/manager.h"
#include "fastdeploy/vision/common/processors/resize.h"
#include "fastdeploy/vision/common/processors/pad.h"
#include "fastdeploy/vision/common/processors/normalize_and_permute.h"
#include "fastdeploy/vision/common/result.h"

namespace fastdeploy {

@@ -22,43 +25,48 @@ namespace vision {
namespace ocr {
/*! @brief Preprocessor object for DBDetector serials model.
 */
class FASTDEPLOY_DECL DBDetectorPreprocessor {
class FASTDEPLOY_DECL DBDetectorPreprocessor : public ProcessorManager {
 public:
  DBDetectorPreprocessor();

  /** \brief Process the input image and prepare input tensors for runtime
   *
   * \param[in] images The input data list, all the elements are FDMat
   * \param[in] image_batch The input image batch
   * \param[in] outputs The output tensors which will feed in runtime
   * \param[in] batch_det_img_info_ptr The output of preprocess
   * \return true if the preprocess successed, otherwise false
   */
  bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
           std::vector<std::array<int, 4>>* batch_det_img_info_ptr);
  virtual bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs);

  /// Set max_side_len for the detection preprocess, default is 960
  void SetMaxSideLen(int max_side_len) { max_side_len_ = max_side_len; }

  /// Get max_side_len of the detection preprocess
  int GetMaxSideLen() const { return max_side_len_; }

  /// Set mean value for the image normalization in detection preprocess
  void SetMean(const std::vector<float>& mean) { mean_ = mean; }
  /// Get mean value of the image normalization in detection preprocess
  std::vector<float> GetMean() const { return mean_; }
  /// Set preprocess normalize parameters, please call this API to customize
  /// the normalize parameters, otherwise it will use the default normalize
  /// parameters.
  void SetNormalize(const std::vector<float>& mean = {0.485f, 0.456f, 0.406f},
                    const std::vector<float>& std = {0.229f, 0.224f, 0.225f},
                    bool is_scale = true) {
    normalize_permute_op_ =
        std::make_shared<NormalizeAndPermute>(mean, std, is_scale);
  }

  /// Set scale value for the image normalization in detection preprocess
  void SetScale(const std::vector<float>& scale) { scale_ = scale; }
  /// Get scale value of the image normalization in detection preprocess
  std::vector<float> GetScale() const { return scale_; }

  /// Set is_scale for the image normalization in detection preprocess
  void SetIsScale(bool is_scale) { is_scale_ = is_scale; }
  /// Get is_scale of the image normalization in detection preprocess
  bool GetIsScale() const { return is_scale_; }
  /// Get the image info of the last batch, return a list of array
  /// {image width, image height, resize width, resize height}
  const std::vector<std::array<int, 4>>* GetBatchImgInfo() {
    return &batch_det_img_info_;
  }

 private:
  bool ResizeImage(FDMat* img, int resize_w, int resize_h, int max_resize_w,
                   int max_resize_h);
  int max_side_len_ = 960;
  std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
  std::vector<float> scale_ = {0.229f, 0.224f, 0.225f};
  bool is_scale_ = true;
  std::vector<std::array<int, 4>> batch_det_img_info_;
  std::shared_ptr<Resize> resize_op_;
  std::shared_ptr<Pad> pad_op_;
  std::shared_ptr<NormalizeAndPermute> normalize_permute_op_;
};

} // namespace ocr
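One consequence of Apply() exposing the batch tensor through SetExternalData(): outputs[0] does not own its memory, it aliases caches managed by the preprocessor (and that memory sits on the GPU when UseCuda() was called). That is why the pybind11 wrapper below calls StopSharing() on every output before handing it to Python. A hedged sketch of the same precaution in C++, with fd_images assumed to be a std::vector<FDMat> from the caller:

std::vector<fastdeploy::FDTensor> outputs;
{
  fastdeploy::vision::ocr::DBDetectorPreprocessor pre;
  pre.Run(&fd_images, &outputs);
  // outputs[0] currently points into pre's internal batch cache.
  outputs[0].StopSharing();   // detach (copy) before the preprocessor goes away
}
// Only safe to keep using outputs[0] here because it was detached above.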
fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc (182 changed lines, Executable file → Normal file)
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "fastdeploy/pybind/main.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
@@ -22,52 +23,74 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
});
|
||||
|
||||
// DBDetector
|
||||
pybind11::class_<vision::ocr::DBDetectorPreprocessor>(m, "DBDetectorPreprocessor")
|
||||
pybind11::class_<vision::ocr::DBDetectorPreprocessor>(
|
||||
m, "DBDetectorPreprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def_property("max_side_len", &vision::ocr::DBDetectorPreprocessor::GetMaxSideLen, &vision::ocr::DBDetectorPreprocessor::SetMaxSideLen)
|
||||
.def_property("mean", &vision::ocr::DBDetectorPreprocessor::GetMean, &vision::ocr::DBDetectorPreprocessor::SetMean)
|
||||
.def_property("scale", &vision::ocr::DBDetectorPreprocessor::GetScale, &vision::ocr::DBDetectorPreprocessor::SetScale)
|
||||
.def_property("is_scale", &vision::ocr::DBDetectorPreprocessor::GetIsScale, &vision::ocr::DBDetectorPreprocessor::SetIsScale)
|
||||
.def("run", [](vision::ocr::DBDetectorPreprocessor& self, std::vector<pybind11::array>& im_list) {
|
||||
.def_property("max_side_len",
|
||||
&vision::ocr::DBDetectorPreprocessor::GetMaxSideLen,
|
||||
&vision::ocr::DBDetectorPreprocessor::SetMaxSideLen)
|
||||
.def("set_normalize",
|
||||
[](vision::ocr::DBDetectorPreprocessor& self,
|
||||
const std::vector<float>& mean, const std::vector<float>& std,
|
||||
bool is_scale) { self.SetNormalize(mean, std, is_scale); })
|
||||
.def("run", [](vision::ocr::DBDetectorPreprocessor& self,
|
||||
std::vector<pybind11::array>& im_list) {
|
||||
std::vector<vision::FDMat> images;
|
||||
for (size_t i = 0; i < im_list.size(); ++i) {
|
||||
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
|
||||
}
|
||||
std::vector<FDTensor> outputs;
|
||||
std::vector<std::array<int, 4>> batch_det_img_info;
|
||||
self.Run(&images, &outputs, &batch_det_img_info);
|
||||
self.Run(&images, &outputs);
|
||||
auto batch_det_img_info = self.GetBatchImgInfo();
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
outputs[i].StopSharing();
|
||||
}
|
||||
return std::make_pair(outputs, batch_det_img_info);
|
||||
return std::make_pair(outputs, *batch_det_img_info);
|
||||
});
|
||||
|
||||
pybind11::class_<vision::ocr::DBDetectorPostprocessor>(m, "DBDetectorPostprocessor")
|
||||
pybind11::class_<vision::ocr::DBDetectorPostprocessor>(
|
||||
m, "DBDetectorPostprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def_property("det_db_thresh", &vision::ocr::DBDetectorPostprocessor::GetDetDBThresh, &vision::ocr::DBDetectorPostprocessor::SetDetDBThresh)
|
||||
.def_property("det_db_box_thresh", &vision::ocr::DBDetectorPostprocessor::GetDetDBBoxThresh, &vision::ocr::DBDetectorPostprocessor::SetDetDBBoxThresh)
|
||||
.def_property("det_db_unclip_ratio", &vision::ocr::DBDetectorPostprocessor::GetDetDBUnclipRatio, &vision::ocr::DBDetectorPostprocessor::SetDetDBUnclipRatio)
|
||||
.def_property("det_db_score_mode", &vision::ocr::DBDetectorPostprocessor::GetDetDBScoreMode, &vision::ocr::DBDetectorPostprocessor::SetDetDBScoreMode)
|
||||
.def_property("use_dilation", &vision::ocr::DBDetectorPostprocessor::GetUseDilation, &vision::ocr::DBDetectorPostprocessor::SetUseDilation)
|
||||
.def_property("det_db_thresh",
|
||||
&vision::ocr::DBDetectorPostprocessor::GetDetDBThresh,
|
||||
&vision::ocr::DBDetectorPostprocessor::SetDetDBThresh)
|
||||
.def_property("det_db_box_thresh",
|
||||
&vision::ocr::DBDetectorPostprocessor::GetDetDBBoxThresh,
|
||||
&vision::ocr::DBDetectorPostprocessor::SetDetDBBoxThresh)
|
||||
.def_property("det_db_unclip_ratio",
|
||||
&vision::ocr::DBDetectorPostprocessor::GetDetDBUnclipRatio,
|
||||
&vision::ocr::DBDetectorPostprocessor::SetDetDBUnclipRatio)
|
||||
.def_property("det_db_score_mode",
|
||||
&vision::ocr::DBDetectorPostprocessor::GetDetDBScoreMode,
|
||||
&vision::ocr::DBDetectorPostprocessor::SetDetDBScoreMode)
|
||||
.def_property("use_dilation",
|
||||
&vision::ocr::DBDetectorPostprocessor::GetUseDilation,
|
||||
&vision::ocr::DBDetectorPostprocessor::SetUseDilation)
|
||||
|
||||
.def("run", [](vision::ocr::DBDetectorPostprocessor& self,
|
||||
.def("run",
|
||||
[](vision::ocr::DBDetectorPostprocessor& self,
|
||||
std::vector<FDTensor>& inputs,
|
||||
const std::vector<std::array<int, 4>>& batch_det_img_info) {
|
||||
std::vector<std::vector<std::array<int, 8>>> results;
|
||||
|
||||
if (!self.Run(inputs, &results, batch_det_img_info)) {
|
||||
throw std::runtime_error("Failed to preprocess the input data in DBDetectorPostprocessor.");
|
||||
throw std::runtime_error(
|
||||
"Failed to preprocess the input data in "
|
||||
"DBDetectorPostprocessor.");
|
||||
}
|
||||
return results;
|
||||
})
|
||||
.def("run", [](vision::ocr::DBDetectorPostprocessor& self,
|
||||
.def("run",
|
||||
[](vision::ocr::DBDetectorPostprocessor& self,
|
||||
std::vector<pybind11::array>& input_array,
|
||||
const std::vector<std::array<int, 4>>& batch_det_img_info) {
|
||||
std::vector<std::vector<std::array<int, 8>>> results;
|
||||
std::vector<FDTensor> inputs;
|
||||
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
|
||||
if (!self.Run(inputs, &results, batch_det_img_info)) {
|
||||
throw std::runtime_error("Failed to preprocess the input data in DBDetectorPostprocessor.");
|
||||
throw std::runtime_error(
|
||||
"Failed to preprocess the input data in "
|
||||
"DBDetectorPostprocessor.");
|
||||
}
|
||||
return results;
|
||||
});
|
||||
@@ -76,16 +99,19 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
.def(pybind11::init<std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def(pybind11::init<>())
|
||||
.def_property_readonly("preprocessor", &vision::ocr::DBDetector::GetPreprocessor)
|
||||
.def_property_readonly("postprocessor", &vision::ocr::DBDetector::GetPostprocessor)
|
||||
.def("predict", [](vision::ocr::DBDetector& self,
|
||||
pybind11::array& data) {
|
||||
.def_property_readonly("preprocessor",
|
||||
&vision::ocr::DBDetector::GetPreprocessor)
|
||||
.def_property_readonly("postprocessor",
|
||||
&vision::ocr::DBDetector::GetPostprocessor)
|
||||
.def("predict",
|
||||
[](vision::ocr::DBDetector& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
std::vector<std::array<int, 8>> boxes_result;
|
||||
self.Predict(mat, &boxes_result);
|
||||
return boxes_result;
|
||||
})
|
||||
.def("batch_predict", [](vision::ocr::DBDetector& self, std::vector<pybind11::array>& data) {
|
||||
.def("batch_predict", [](vision::ocr::DBDetector& self,
|
||||
std::vector<pybind11::array>& data) {
|
||||
std::vector<cv::Mat> images;
|
||||
std::vector<std::vector<std::array<int, 8>>> det_results;
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
@@ -96,20 +122,29 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
});
|
||||
|
||||
// Classifier
|
||||
pybind11::class_<vision::ocr::ClassifierPreprocessor>(m, "ClassifierPreprocessor")
|
||||
pybind11::class_<vision::ocr::ClassifierPreprocessor>(
|
||||
m, "ClassifierPreprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def_property("cls_image_shape", &vision::ocr::ClassifierPreprocessor::GetClsImageShape, &vision::ocr::ClassifierPreprocessor::SetClsImageShape)
|
||||
.def_property("mean", &vision::ocr::ClassifierPreprocessor::GetMean, &vision::ocr::ClassifierPreprocessor::SetMean)
|
||||
.def_property("scale", &vision::ocr::ClassifierPreprocessor::GetScale, &vision::ocr::ClassifierPreprocessor::SetScale)
|
||||
.def_property("is_scale", &vision::ocr::ClassifierPreprocessor::GetIsScale, &vision::ocr::ClassifierPreprocessor::SetIsScale)
|
||||
.def("run", [](vision::ocr::ClassifierPreprocessor& self, std::vector<pybind11::array>& im_list) {
|
||||
.def_property("cls_image_shape",
|
||||
&vision::ocr::ClassifierPreprocessor::GetClsImageShape,
|
||||
&vision::ocr::ClassifierPreprocessor::SetClsImageShape)
|
||||
.def_property("mean", &vision::ocr::ClassifierPreprocessor::GetMean,
|
||||
&vision::ocr::ClassifierPreprocessor::SetMean)
|
||||
.def_property("scale", &vision::ocr::ClassifierPreprocessor::GetScale,
|
||||
&vision::ocr::ClassifierPreprocessor::SetScale)
|
||||
.def_property("is_scale",
|
||||
&vision::ocr::ClassifierPreprocessor::GetIsScale,
|
||||
&vision::ocr::ClassifierPreprocessor::SetIsScale)
|
||||
.def("run", [](vision::ocr::ClassifierPreprocessor& self,
|
||||
std::vector<pybind11::array>& im_list) {
|
||||
std::vector<vision::FDMat> images;
|
||||
for (size_t i = 0; i < im_list.size(); ++i) {
|
||||
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
|
||||
}
|
||||
std::vector<FDTensor> outputs;
|
||||
if (!self.Run(&images, &outputs)) {
|
||||
throw std::runtime_error("Failed to preprocess the input data in ClassifierPreprocessor.");
|
||||
throw std::runtime_error(
|
||||
"Failed to preprocess the input data in ClassifierPreprocessor.");
|
||||
}
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
outputs[i].StopSharing();
|
||||
@@ -117,15 +152,21 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
return outputs;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::ocr::ClassifierPostprocessor>(m, "ClassifierPostprocessor")
|
||||
pybind11::class_<vision::ocr::ClassifierPostprocessor>(
|
||||
m, "ClassifierPostprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def_property("cls_thresh", &vision::ocr::ClassifierPostprocessor::GetClsThresh, &vision::ocr::ClassifierPostprocessor::SetClsThresh)
|
||||
.def("run", [](vision::ocr::ClassifierPostprocessor& self,
|
||||
.def_property("cls_thresh",
|
||||
&vision::ocr::ClassifierPostprocessor::GetClsThresh,
|
||||
&vision::ocr::ClassifierPostprocessor::SetClsThresh)
|
||||
.def("run",
|
||||
[](vision::ocr::ClassifierPostprocessor& self,
|
||||
std::vector<FDTensor>& inputs) {
|
||||
std::vector<int> cls_labels;
|
||||
std::vector<float> cls_scores;
|
||||
if (!self.Run(inputs, &cls_labels, &cls_scores)) {
|
||||
throw std::runtime_error("Failed to preprocess the input data in ClassifierPostprocessor.");
|
||||
throw std::runtime_error(
|
||||
"Failed to preprocess the input data in "
|
||||
"ClassifierPostprocessor.");
|
||||
}
|
||||
return std::make_pair(cls_labels, cls_scores);
|
||||
})
|
||||
@@ -136,7 +177,9 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
std::vector<int> cls_labels;
|
||||
std::vector<float> cls_scores;
|
||||
if (!self.Run(inputs, &cls_labels, &cls_scores)) {
|
||||
throw std::runtime_error("Failed to preprocess the input data in ClassifierPostprocessor.");
|
||||
throw std::runtime_error(
|
||||
"Failed to preprocess the input data in "
|
||||
"ClassifierPostprocessor.");
|
||||
}
|
||||
return std::make_pair(cls_labels, cls_scores);
|
||||
});
|
||||
@@ -145,17 +188,20 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
.def(pybind11::init<std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def(pybind11::init<>())
|
||||
.def_property_readonly("preprocessor", &vision::ocr::Classifier::GetPreprocessor)
|
||||
.def_property_readonly("postprocessor", &vision::ocr::Classifier::GetPostprocessor)
|
||||
.def("predict", [](vision::ocr::Classifier& self,
|
||||
pybind11::array& data) {
|
||||
.def_property_readonly("preprocessor",
|
||||
&vision::ocr::Classifier::GetPreprocessor)
|
||||
.def_property_readonly("postprocessor",
|
||||
&vision::ocr::Classifier::GetPostprocessor)
|
||||
.def("predict",
|
||||
[](vision::ocr::Classifier& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
int32_t cls_label;
|
||||
float cls_score;
|
||||
self.Predict(mat, &cls_label, &cls_score);
|
||||
return std::make_pair(cls_label, cls_score);
|
||||
})
|
||||
.def("batch_predict", [](vision::ocr::Classifier& self, std::vector<pybind11::array>& data) {
|
||||
.def("batch_predict", [](vision::ocr::Classifier& self,
|
||||
std::vector<pybind11::array>& data) {
|
||||
std::vector<cv::Mat> images;
|
||||
std::vector<int32_t> cls_labels;
|
||||
std::vector<float> cls_scores;
|
||||
@@ -167,21 +213,32 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
});
|
||||
|
||||
// Recognizer
|
||||
pybind11::class_<vision::ocr::RecognizerPreprocessor>(m, "RecognizerPreprocessor")
|
||||
pybind11::class_<vision::ocr::RecognizerPreprocessor>(
|
||||
m, "RecognizerPreprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def_property("static_shape_infer", &vision::ocr::RecognizerPreprocessor::GetStaticShapeInfer, &vision::ocr::RecognizerPreprocessor::SetStaticShapeInfer)
|
||||
.def_property("rec_image_shape", &vision::ocr::RecognizerPreprocessor::GetRecImageShape, &vision::ocr::RecognizerPreprocessor::SetRecImageShape)
|
||||
.def_property("mean", &vision::ocr::RecognizerPreprocessor::GetMean, &vision::ocr::RecognizerPreprocessor::SetMean)
|
||||
.def_property("scale", &vision::ocr::RecognizerPreprocessor::GetScale, &vision::ocr::RecognizerPreprocessor::SetScale)
|
||||
.def_property("is_scale", &vision::ocr::RecognizerPreprocessor::GetIsScale, &vision::ocr::RecognizerPreprocessor::SetIsScale)
|
||||
.def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector<pybind11::array>& im_list) {
|
||||
.def_property("static_shape_infer",
|
||||
&vision::ocr::RecognizerPreprocessor::GetStaticShapeInfer,
|
||||
&vision::ocr::RecognizerPreprocessor::SetStaticShapeInfer)
|
||||
.def_property("rec_image_shape",
|
||||
&vision::ocr::RecognizerPreprocessor::GetRecImageShape,
|
||||
&vision::ocr::RecognizerPreprocessor::SetRecImageShape)
|
||||
.def_property("mean", &vision::ocr::RecognizerPreprocessor::GetMean,
|
||||
&vision::ocr::RecognizerPreprocessor::SetMean)
|
||||
.def_property("scale", &vision::ocr::RecognizerPreprocessor::GetScale,
|
||||
&vision::ocr::RecognizerPreprocessor::SetScale)
|
||||
.def_property("is_scale",
|
||||
&vision::ocr::RecognizerPreprocessor::GetIsScale,
|
||||
&vision::ocr::RecognizerPreprocessor::SetIsScale)
|
||||
.def("run", [](vision::ocr::RecognizerPreprocessor& self,
|
||||
std::vector<pybind11::array>& im_list) {
|
||||
std::vector<vision::FDMat> images;
|
||||
for (size_t i = 0; i < im_list.size(); ++i) {
|
||||
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
|
||||
}
|
||||
std::vector<FDTensor> outputs;
|
||||
if (!self.Run(&images, &outputs)) {
|
||||
throw std::runtime_error("Failed to preprocess the input data in RecognizerPreprocessor.");
|
||||
throw std::runtime_error(
|
||||
"Failed to preprocess the input data in RecognizerPreprocessor.");
|
||||
}
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
outputs[i].StopSharing();
|
||||
@@ -189,14 +246,18 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
return outputs;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::ocr::RecognizerPostprocessor>(m, "RecognizerPostprocessor")
|
||||
pybind11::class_<vision::ocr::RecognizerPostprocessor>(
|
||||
m, "RecognizerPostprocessor")
|
||||
.def(pybind11::init<std::string>())
|
||||
.def("run", [](vision::ocr::RecognizerPostprocessor& self,
|
||||
.def("run",
|
||||
[](vision::ocr::RecognizerPostprocessor& self,
|
||||
std::vector<FDTensor>& inputs) {
|
||||
std::vector<std::string> texts;
|
||||
std::vector<float> rec_scores;
|
||||
if (!self.Run(inputs, &texts, &rec_scores)) {
|
||||
throw std::runtime_error("Failed to preprocess the input data in RecognizerPostprocessor.");
|
||||
throw std::runtime_error(
|
||||
"Failed to preprocess the input data in "
|
||||
"RecognizerPostprocessor.");
|
||||
}
|
||||
return std::make_pair(texts, rec_scores);
|
||||
})
|
||||
@@ -207,7 +268,9 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
std::vector<std::string> texts;
|
||||
std::vector<float> rec_scores;
|
||||
if (!self.Run(inputs, &texts, &rec_scores)) {
|
||||
throw std::runtime_error("Failed to preprocess the input data in RecognizerPostprocessor.");
|
||||
throw std::runtime_error(
|
||||
"Failed to preprocess the input data in "
|
||||
"RecognizerPostprocessor.");
|
||||
}
|
||||
return std::make_pair(texts, rec_scores);
|
||||
});
|
||||
@@ -216,17 +279,20 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def(pybind11::init<>())
|
||||
.def_property_readonly("preprocessor", &vision::ocr::Recognizer::GetPreprocessor)
|
||||
.def_property_readonly("postprocessor", &vision::ocr::Recognizer::GetPostprocessor)
|
||||
.def("predict", [](vision::ocr::Recognizer& self,
|
||||
pybind11::array& data) {
|
||||
.def_property_readonly("preprocessor",
|
||||
&vision::ocr::Recognizer::GetPreprocessor)
|
||||
.def_property_readonly("postprocessor",
|
||||
&vision::ocr::Recognizer::GetPostprocessor)
|
||||
.def("predict",
|
||||
[](vision::ocr::Recognizer& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
std::string text;
|
||||
float rec_score;
|
||||
self.Predict(mat, &text, &rec_score);
|
||||
return std::make_pair(text, rec_score);
|
||||
})
|
||||
.def("batch_predict", [](vision::ocr::Recognizer& self, std::vector<pybind11::array>& data) {
|
||||
.def("batch_predict", [](vision::ocr::Recognizer& self,
|
||||
std::vector<pybind11::array>& data) {
|
||||
std::vector<cv::Mat> images;
|
||||
std::vector<std::string> texts;
|
||||
std::vector<float> rec_scores;
|
||||
|
@@ -46,7 +46,6 @@ class PaddleClasPreprocessor(ProcessorManager):
|
||||
When the initial operator is Resize, and input image size is large,
|
||||
maybe it's better to run resize on CPU, because the HostToDevice memcpy
|
||||
is time consuming. Set this True to run the initial resize on CPU.
|
||||
|
||||
:param: v: True or False
|
||||
"""
|
||||
self._manager.initial_resize_on_cpu(v)
|
||||
|
@@ -37,43 +37,31 @@ class DBDetectorPreprocessor:
|
||||
|
||||
@property
|
||||
def max_side_len(self):
|
||||
"""Get max_side_len value.
|
||||
"""
|
||||
return self._preprocessor.max_side_len
|
||||
|
||||
@max_side_len.setter
|
||||
def max_side_len(self, value):
|
||||
"""Set max_side_len value.
|
||||
:param: value: (int) max_side_len value
|
||||
"""
|
||||
assert isinstance(
|
||||
value, int), "The value to set `max_side_len` must be type of int."
|
||||
self._preprocessor.max_side_len = value
|
||||
|
||||
@property
|
||||
def is_scale(self):
|
||||
return self._preprocessor.is_scale
|
||||
|
||||
@is_scale.setter
|
||||
def is_scale(self, value):
|
||||
assert isinstance(
|
||||
value, bool), "The value to set `is_scale` must be type of bool."
|
||||
self._preprocessor.is_scale = value
|
||||
|
||||
@property
|
||||
def scale(self):
|
||||
return self._preprocessor.scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, value):
|
||||
assert isinstance(
|
||||
value, list), "The value to set `scale` must be type of list."
|
||||
self._preprocessor.scale = value
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return self._preprocessor.mean
|
||||
|
||||
@mean.setter
|
||||
def mean(self, value):
|
||||
assert isinstance(
|
||||
value, list), "The value to set `mean` must be type of list."
|
||||
self._preprocessor.mean = value
|
||||
def set_normalize(self,
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
is_scale=True):
|
||||
"""Set preprocess normalize parameters, please call this API to
|
||||
customize the normalize parameters, otherwise it will use the default
|
||||
normalize parameters.
|
||||
:param: mean: (list of float) mean values
|
||||
:param: std: (list of float) std values
|
||||
:param: is_scale: (boolean) whether to scale
|
||||
"""
|
||||
self._preprocessor.set_normalize(mean, std, is_scale)
|
||||
|
||||
|
||||
class DBDetectorPostprocessor:
|
||||
@@ -174,6 +162,7 @@ class DBDetector(FastDeployModel):
|
||||
"""Clone OCR detection model object
|
||||
:return: a new OCR detection model object
|
||||
"""
|
||||
|
||||
class DBDetectorClone(DBDetector):
|
||||
def __init__(self, model):
|
||||
self._model = model
|
||||
@@ -203,18 +192,10 @@ class DBDetector(FastDeployModel):
|
||||
def preprocessor(self):
|
||||
return self._model.preprocessor
|
||||
|
||||
@preprocessor.setter
|
||||
def preprocessor(self, value):
|
||||
self._model.preprocessor = value
|
||||
|
||||
@property
|
||||
def postprocessor(self):
|
||||
return self._model.postprocessor
|
||||
|
||||
@postprocessor.setter
|
||||
def postprocessor(self, value):
|
||||
self._model.postprocessor = value
|
||||
|
||||
# Det Preprocessor Property
|
||||
@property
|
||||
def max_side_len(self):
|
||||
@@ -226,36 +207,6 @@ class DBDetector(FastDeployModel):
|
||||
value, int), "The value to set `max_side_len` must be type of int."
|
||||
self._model.preprocessor.max_side_len = value
|
||||
|
||||
@property
|
||||
def is_scale(self):
|
||||
return self._model.preprocessor.is_scale
|
||||
|
||||
@is_scale.setter
|
||||
def is_scale(self, value):
|
||||
assert isinstance(
|
||||
value, bool), "The value to set `is_scale` must be type of bool."
|
||||
self._model.preprocessor.is_scale = value
|
||||
|
||||
@property
|
||||
def scale(self):
|
||||
return self._model.preprocessor.scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, value):
|
||||
assert isinstance(
|
||||
value, list), "The value to set `scale` must be type of list."
|
||||
self._model.preprocessor.scale = value
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return self._model.preprocessor.mean
|
||||
|
||||
@mean.setter
|
||||
def mean(self, value):
|
||||
assert isinstance(
|
||||
value, list), "The value to set `mean` must be type of list."
|
||||
self._model.preprocessor.mean = value
|
||||
|
||||
# Det Ppstprocessor Property
|
||||
@property
|
||||
def det_db_thresh(self):
|
||||
@@ -421,6 +372,7 @@ class Classifier(FastDeployModel):
|
||||
"""Clone OCR classification model object
|
||||
:return: a new OCR classification model object
|
||||
"""
|
||||
|
||||
class ClassifierClone(Classifier):
|
||||
def __init__(self, model):
|
||||
self._model = model
|
||||
@@ -629,6 +581,7 @@ class Recognizer(FastDeployModel):
|
||||
"""Clone OCR recognition model object
|
||||
:return: a new OCR recognition model object
|
||||
"""
|
||||
|
||||
class RecognizerClone(Recognizer):
|
||||
def __init__(self, model):
|
||||
self._model = model
|
||||
@@ -743,6 +696,7 @@ class PPOCRv3(FastDeployModel):
|
||||
"""Clone PPOCRv3 pipeline object
|
||||
:return: a new PPOCRv3 pipeline object
|
||||
"""
|
||||
|
||||
class PPOCRv3Clone(PPOCRv3):
|
||||
def __init__(self, system):
|
||||
self.system_ = system
|
||||
@@ -818,6 +772,7 @@ class PPOCRv2(FastDeployModel):
|
||||
"""Clone PPOCRv3 pipeline object
|
||||
:return: a new PPOCRv3 pipeline object
|
||||
"""
|
||||
|
||||
class PPOCRv2Clone(PPOCRv2):
|
||||
def __init__(self, system):
|
||||
self.system_ = system
|
||||
|