[Bug Fix] Add new OCR features and fix code style (#764)

* fix ocr bug and add new feature

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* add property

* add test

* fix code style

* fix bug

* fix bug

* fix bug

* fix port

* fix ocr

* fix_ocr

* fix ocr

* fix ocr

* fix ocr

* fix ocr

* Update paddle2onnx.cmake

* Update paddle2onnx.cmake

* Update paddle2onnx.cmake

Co-authored-by: Jason <jiangjiajun@baidu.com>
Co-authored-by: Jason <928090362@qq.com>
Author: Thomas Young
Date: 2022-12-07 19:31:54 +08:00
Committed by: GitHub
Parent: e6af8f2334
Commit: 5df62485c3
33 changed files with 1222 additions and 376 deletions


@@ -43,7 +43,7 @@ else()
 endif(WIN32)
 set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/fastdeploy/third_libs/")
-set(PADDLE2ONNX_VERSION "1.0.4")
+set(PADDLE2ONNX_VERSION "1.0.5")
 if(WIN32)
   set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip")
   if(NOT CMAKE_CL_64)

examples/vision/ocr/PP-OCRv3/serving/README.md (2 changes, Normal file → Executable file)

@@ -47,6 +47,8 @@ tar xvf ch_PP-OCRv3_rec_infer.tar && mv ch_PP-OCRv3_rec_infer 1
 mv 1/inference.pdiparams 1/model.pdiparams && mv 1/inference.pdmodel 1/model.pdmodel
 mv 1 models/rec_runtime/ && rm -rf ch_PP-OCRv3_rec_infer.tar
+mkdir models/pp_ocr/1 && mkdir models/rec_pp/1 && mkdir models/cls_pp/1
 wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
 mv ppocr_keys_v1.txt models/rec_postprocess/1/

examples/vision/ocr/PP-OCRv3/serving/client.py (2 changes, Normal file → Executable file)

@@ -91,7 +91,7 @@ class SyncGRPCTritonRunner:
 if __name__ == "__main__":
     model_name = "pp_ocr"
     model_version = "1"
-    url = "localhost:9001"
+    url = "localhost:8001"
     runner = SyncGRPCTritonRunner(url, model_name, model_version)
     im = cv2.imread("12.jpg")
     im = np.array([im, ])


@@ -35,3 +35,18 @@ instance_group [
     gpus: [0]
   }
 ]
+
+optimization {
+  execution_accelerators {
+    # GPU inference configuration, used together with KIND_GPU
+    gpu_execution_accelerator : [
+      {
+        name : "paddle"
+        # Set the number of parallel inference threads to 4
+        parameters { key: "cpu_threads" value: "4" }
+        # Enable MKLDNN acceleration; set to 0 to disable MKLDNN
+        parameters { key: "use_mkldnn" value: "1" }
+      }
+    ]
+  }
+}


@@ -35,3 +35,18 @@ instance_group [
     gpus: [0]
   }
 ]
+
+optimization {
+  execution_accelerators {
+    # GPU inference configuration, used together with KIND_GPU
+    gpu_execution_accelerator : [
+      {
+        name : "paddle"
+        # Set the number of parallel inference threads to 4
+        parameters { key: "cpu_threads" value: "4" }
+        # Enable MKLDNN acceleration; set to 0 to disable MKLDNN
+        parameters { key: "use_mkldnn" value: "1" }
+      }
+    ]
+  }
+}


@@ -47,8 +47,6 @@ class TritonPythonModel:
         * model_version: Model version
         * model_name: Model name
         """
-        sys.stdout = codecs.getwriter("utf-8")(sys.stdout.detach())
-        print(sys.getdefaultencoding())
         # You must parse model_config. JSON string is not parsed here
         self.model_config = json.loads(args['model_config'])
         print("model_config:", self.model_config)


@@ -35,3 +35,18 @@ instance_group [
     gpus: [0]
   }
 ]
+
+optimization {
+  execution_accelerators {
+    # GPU inference configuration, used together with KIND_GPU
+    gpu_execution_accelerator : [
+      {
+        name : "paddle"
+        # Set the number of parallel inference threads to 4
+        parameters { key: "cpu_threads" value: "4" }
+        # Enable MKLDNN acceleration; set to 0 to disable MKLDNN
+        parameters { key: "use_mkldnn" value: "1" }
+      }
+    ]
+  }
+}


@@ -50,10 +50,29 @@ bool Classifier::Initialize() {
   return true;
 }
 
+bool Classifier::Predict(cv::Mat& img, int32_t* cls_label, float* cls_score) {
+  std::vector<int32_t> cls_labels(1);
+  std::vector<float> cls_scores(1);
+  bool success = BatchPredict({img}, &cls_labels, &cls_scores);
+  if (!success) {
+    return success;
+  }
+  *cls_label = cls_labels[0];
+  *cls_score = cls_scores[0];
+  return true;
+}
+
 bool Classifier::BatchPredict(const std::vector<cv::Mat>& images,
                               std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores) {
+  return BatchPredict(images, cls_labels, cls_scores, 0, images.size());
+}
+
+bool Classifier::BatchPredict(const std::vector<cv::Mat>& images,
+                              std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores,
+                              size_t start_index, size_t end_index) {
+  size_t total_size = images.size();
   std::vector<FDMat> fd_images = WrapMat(images);
-  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
+  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, start_index, end_index)) {
     FDERROR << "Failed to preprocess the input image." << std::endl;
     return false;
   }
@@ -63,7 +82,7 @@ bool Classifier::BatchPredict(const std::vector<cv::Mat>& images,
     return false;
   }
 
-  if (!postprocessor_.Run(reused_output_tensors_, cls_labels, cls_scores)) {
+  if (!postprocessor_.Run(reused_output_tensors_, cls_labels, cls_scores, start_index, total_size)) {
     FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
     return false;
   }
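For reference, a minimal Python sketch of how the new single-image Predict and the ranged BatchPredict surface to users through the bindings added elsewhere in this commit (the fastdeploy import and model paths are placeholders, not part of this diff):

    import cv2
    import fastdeploy as fd

    # Hypothetical model files; any exported PP-OCR angle-classifier works here.
    cls = fd.vision.ocr.Classifier("cls_infer.pdmodel", "cls_infer.pdiparams")
    im = cv2.imread("12.jpg")
    label, score = cls.predict(im)            # new single-image entry point
    labels, scores = cls.batch_predict([im])  # existing batch entry point
    print(label, score, labels, scores)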


@@ -43,7 +43,7 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel {
                      const ModelFormat& model_format = ModelFormat::PADDLE);
   /// Get model's name
   std::string ModelName() const { return "ppocr/ocr_cls"; }
+  virtual bool Predict(cv::Mat& img, int32_t* cls_label, float* cls_score);
   /** \brief BatchPredict the input image and get OCR classification model cls_result.
    *
    * \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
@@ -53,6 +53,10 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel {
   virtual bool BatchPredict(const std::vector<cv::Mat>& images,
                             std::vector<int32_t>* cls_labels,
                             std::vector<float>* cls_scores);
+  virtual bool BatchPredict(const std::vector<cv::Mat>& images,
+                            std::vector<int32_t>* cls_labels,
+                            std::vector<float>* cls_scores,
+                            size_t start_index, size_t end_index);
 
   ClassifierPreprocessor preprocessor_;
   ClassifierPostprocessor postprocessor_;


@@ -20,10 +20,6 @@ namespace fastdeploy {
 namespace vision {
 namespace ocr {
 
-ClassifierPostprocessor::ClassifierPostprocessor() {
-  initialized_ = true;
-}
-
 bool SingleBatchPostprocessor(const float* out_data, const size_t& length, int* cls_label, float* cls_score) {
 
   *cls_label = std::distance(
@@ -37,10 +33,14 @@ bool SingleBatchPostprocessor(const float* out_data, const size_t& length, int*
 bool ClassifierPostprocessor::Run(const std::vector<FDTensor>& tensors,
                                   std::vector<int32_t>* cls_labels,
                                   std::vector<float>* cls_scores) {
-  if (!initialized_) {
-    FDERROR << "Postprocessor is not initialized." << std::endl;
-    return false;
-  }
+  size_t total_size = tensors[0].shape[0];
+  return Run(tensors, cls_labels, cls_scores, 0, total_size);
+}
+
+bool ClassifierPostprocessor::Run(const std::vector<FDTensor>& tensors,
+                                  std::vector<int32_t>* cls_labels,
+                                  std::vector<float>* cls_scores,
+                                  size_t start_index, size_t total_size) {
   // Classifier have only 1 output tensor.
   const FDTensor& tensor = tensors[0];
@@ -48,13 +48,29 @@ bool ClassifierPostprocessor::Run(const std::vector<FDTensor>& tensors,
   size_t batch = tensor.shape[0];
   size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies<int>());
 
-  cls_labels->resize(batch);
-  cls_scores->resize(batch);
+  if (batch <= 0) {
+    FDERROR << "The infer outputTensor.shape[0] <=0, wrong infer result." << std::endl;
+    return false;
+  }
+  if (start_index < 0 || total_size <= 0) {
+    FDERROR << "start_index or total_size error. Correct is: 0 <= start_index < total_size" << std::endl;
+    return false;
+  }
+  if ((start_index + batch) > total_size) {
+    FDERROR << "start_index or total_size error. Correct is: start_index + batch(outputTensor.shape[0]) <= total_size" << std::endl;
+    return false;
+  }
+  cls_labels->resize(total_size);
+  cls_scores->resize(total_size);
+
   const float* tensor_data = reinterpret_cast<const float*>(tensor.Data());
   for (int i_batch = 0; i_batch < batch; ++i_batch) {
-    if (!SingleBatchPostprocessor(tensor_data, length, &cls_labels->at(i_batch), &cls_scores->at(i_batch))) return false;
-    tensor_data = tensor_data + length;
+    if (!SingleBatchPostprocessor(tensor_data + i_batch * length,
+                                  length,
+                                  &cls_labels->at(i_batch + start_index),
+                                  &cls_scores->at(i_batch + start_index))) {
+      return false;
+    }
   }
 
   return true;


@@ -25,11 +25,6 @@ namespace ocr {
  */
 class FASTDEPLOY_DECL ClassifierPostprocessor {
  public:
-  /** \brief Create a postprocessor instance for Classifier serials model
-   *
-   */
-  ClassifierPostprocessor();
-
   /** \brief Process the result of runtime and fill to ClassifyResult structure
    *
    * \param[in] tensors The inference result from runtime
@@ -40,10 +35,11 @@ class FASTDEPLOY_DECL ClassifierPostprocessor {
   bool Run(const std::vector<FDTensor>& tensors,
            std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores);
 
-  float cls_thresh_ = 0.9;
+  bool Run(const std::vector<FDTensor>& tensors,
+           std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores,
+           size_t start_index, size_t total_size);
 
- private:
-  bool initialized_ = false;
+  float cls_thresh_ = 0.9;
 };
 
 }  // namespace ocr

fastdeploy/vision/ocr/ppocr/cls_preprocessor.cc (53 changes, Normal file → Executable file)

@@ -21,58 +21,53 @@ namespace fastdeploy {
 namespace vision {
 namespace ocr {
 
-ClassifierPreprocessor::ClassifierPreprocessor() {
-  initialized_ = true;
-}
-
 void OcrClassifierResizeImage(FDMat* mat,
                               const std::vector<int>& cls_image_shape) {
-  int imgC = cls_image_shape[0];
-  int imgH = cls_image_shape[1];
-  int imgW = cls_image_shape[2];
+  int img_c = cls_image_shape[0];
+  int img_h = cls_image_shape[1];
+  int img_w = cls_image_shape[2];
 
   float ratio = float(mat->Width()) / float(mat->Height());
   int resize_w;
-  if (ceilf(imgH * ratio) > imgW)
-    resize_w = imgW;
+  if (ceilf(img_h * ratio) > img_w)
+    resize_w = img_w;
   else
-    resize_w = int(ceilf(imgH * ratio));
+    resize_w = int(ceilf(img_h * ratio));
 
-  Resize::Run(mat, resize_w, imgH);
-
-  std::vector<float> value = {0, 0, 0};
-  if (resize_w < imgW) {
-    Pad::Run(mat, 0, 0, 0, imgW - resize_w, value);
-  }
+  Resize::Run(mat, resize_w, img_h);
 }
 
 bool ClassifierPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
-  if (!initialized_) {
-    FDERROR << "The preprocessor is not initialized." << std::endl;
-    return false;
-  }
-  if (images->size() == 0) {
-    FDERROR << "The size of input images should be greater than 0." << std::endl;
+  return Run(images, outputs, 0, images->size());
+}
+
+bool ClassifierPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
+                                 size_t start_index, size_t end_index) {
+  if (images->size() == 0 || start_index < 0 || end_index <= start_index || end_index > images->size()) {
+    FDERROR << "images->size() or index error. Correct is: 0 <= start_index < end_index <= images->size()" << std::endl;
     return false;
   }
 
-  for (size_t i = 0; i < images->size(); ++i) {
+  for (size_t i = start_index; i < end_index; ++i) {
     FDMat* mat = &(images->at(i));
     OcrClassifierResizeImage(mat, cls_image_shape_);
+    NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
+    /*
     Normalize::Run(mat, mean_, scale_, is_scale_);
+    std::vector<float> value = {0, 0, 0};
+    if (mat->Width() < cls_image_shape_[2]) {
+      Pad::Run(mat, 0, 0, 0, cls_image_shape_[2] - mat->Width(), value);
+    }
     HWC2CHW::Run(mat);
     Cast::Run(mat, "float");
+    */
   }
   // Only have 1 output Tensor.
   outputs->resize(1);
   // Concat all the preprocessed data to a batch tensor
-  std::vector<FDTensor> tensors(images->size());
-  for (size_t i = 0; i < images->size(); ++i) {
-    (*images)[i].ShareWithTensor(&(tensors[i]));
+  size_t tensor_size = end_index - start_index;
+  std::vector<FDTensor> tensors(tensor_size);
+  for (size_t i = 0; i < tensor_size; ++i) {
+    (*images)[i + start_index].ShareWithTensor(&(tensors[i]));
     tensors[i].ExpandDim(0);
   }
   if (tensors.size() == 1) {
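The fused NormalizeAndPermute step used above stands in for the commented-out Normalize / HWC2CHW / Cast sequence; conceptually it is the following transform (an illustrative numpy sketch, under the assumption that Normalize scales pixels to [0, 1] first when is_scale is true; this is not the actual FastDeploy kernel):

    import numpy as np

    def normalize_and_permute(img_hwc, mean=(0.5, 0.5, 0.5), scale=(0.5, 0.5, 0.5), is_scale=True):
        x = img_hwc.astype(np.float32)
        if is_scale:
            x /= 255.0                     # assumed [0, 255] -> [0, 1] scaling
        x = (x - np.asarray(mean, np.float32)) / np.asarray(scale, np.float32)
        return np.transpose(x, (2, 0, 1))  # HWC -> CHW, already float32 (the Cast step)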


@@ -24,11 +24,6 @@ namespace ocr {
  */
 class FASTDEPLOY_DECL ClassifierPreprocessor {
  public:
-  /** \brief Create a preprocessor instance for Classifier serials model
-   *
-   */
-  ClassifierPreprocessor();
-
   /** \brief Process the input image and prepare input tensors for runtime
    *
    * \param[in] images The input image data list, all the elements are returned by cv::imread()
@@ -36,14 +31,13 @@ class FASTDEPLOY_DECL ClassifierPreprocessor {
    * \return true if the preprocess successed, otherwise false
    */
   bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
+  bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
+           size_t start_index, size_t end_index);
 
   std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
   std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
   bool is_scale_ = true;
   std::vector<int> cls_image_shape_ = {3, 48, 192};
-
- private:
-  bool initialized_ = false;
 };
 
 }  // namespace ocr


@@ -50,14 +50,6 @@ bool DBDetector::Initialize() {
   return true;
 }
 
-bool DBDetector::Predict(cv::Mat* img,
-                         std::vector<std::array<int, 8>>* boxes_result) {
-  if (!Predict(*img, boxes_result)) {
-    return false;
-  }
-  return true;
-}
-
 bool DBDetector::Predict(const cv::Mat& img,
                          std::vector<std::array<int, 8>>* boxes_result) {
   std::vector<std::vector<std::array<int, 8>>> det_results;


@@ -44,14 +44,6 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel {
                      const ModelFormat& model_format = ModelFormat::PADDLE);
   /// Get model's name
   std::string ModelName() const { return "ppocr/ocr_det"; }
-  /** \brief Predict the input image and get OCR detection model result.
-   *
-   * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
-   * \param[in] boxes_result The output of OCR detection model result will be writen to this structure.
-   * \return true if the prediction is successed, otherwise false.
-   */
-  virtual bool Predict(cv::Mat* img,
-                       std::vector<std::array<int, 8>>* boxes_result);
   /** \brief Predict the input image and get OCR detection model result.
    *
    * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.


@@ -20,10 +20,6 @@ namespace fastdeploy {
 namespace vision {
 namespace ocr {
 
-DBDetectorPostprocessor::DBDetectorPostprocessor() {
-  initialized_ = true;
-}
-
 bool DBDetectorPostprocessor::SingleBatchPostprocessor(
     const float* out_data,
     int n2,
@@ -57,10 +53,10 @@ bool DBDetectorPostprocessor::SingleBatchPostprocessor(
   std::vector<std::vector<std::vector<int>>> boxes;
 
   boxes =
-      post_processor_.BoxesFromBitmap(pred_map, bit_map, det_db_box_thresh_,
+      util_post_processor_.BoxesFromBitmap(pred_map, bit_map, det_db_box_thresh_,
                                       det_db_unclip_ratio_, det_db_score_mode_);
 
-  boxes = post_processor_.FilterTagDetRes(boxes, det_img_info);
+  boxes = util_post_processor_.FilterTagDetRes(boxes, det_img_info);
 
   // boxes to boxes_result
   for (int i = 0; i < boxes.size(); i++) {
@@ -80,10 +76,6 @@ bool DBDetectorPostprocessor::SingleBatchPostprocessor(
 bool DBDetectorPostprocessor::Run(const std::vector<FDTensor>& tensors,
                                   std::vector<std::vector<std::array<int, 8>>>* results,
                                   const std::vector<std::array<int,4>>& batch_det_img_info) {
-  if (!initialized_) {
-    FDERROR << "Postprocessor is not initialized." << std::endl;
-    return false;
-  }
   // DBDetector have only 1 output tensor.
   const FDTensor& tensor = tensors[0];


@@ -25,11 +25,6 @@ namespace ocr {
  */
 class FASTDEPLOY_DECL DBDetectorPostprocessor {
  public:
-  /** \brief Create a postprocessor instance for DBDetector serials model
-   *
-   */
-  DBDetectorPostprocessor();
-
   /** \brief Process the result of runtime and fill to results structure
    *
    * \param[in] tensors The inference result from runtime
@@ -48,8 +43,7 @@ class FASTDEPLOY_DECL DBDetectorPostprocessor {
   bool use_dilation_ = false;
 
  private:
-  bool initialized_ = false;
-  PostProcessor post_processor_;
+  PostProcessor util_post_processor_;
   bool SingleBatchPostprocessor(const float* out_data,
                                 int n2,
                                 int n3,


@@ -21,10 +21,6 @@ namespace fastdeploy {
 namespace vision {
 namespace ocr {
 
-DBDetectorPreprocessor::DBDetectorPreprocessor() {
-  initialized_ = true;
-}
-
 std::array<int, 4> OcrDetectorGetInfo(FDMat* img, int max_size_len) {
   int w = img->Width();
   int h = img->Height();
@@ -63,10 +59,6 @@ bool OcrDetectorResizeImage(FDMat* img,
 bool DBDetectorPreprocessor::Run(std::vector<FDMat>* images,
                                  std::vector<FDTensor>* outputs,
                                  std::vector<std::array<int, 4>>* batch_det_img_info_ptr) {
-  if (!initialized_) {
-    FDERROR << "The preprocessor is not initialized." << std::endl;
-    return false;
-  }
   if (images->size() == 0) {
     FDERROR << "The size of input images should be greater than 0." << std::endl;
     return false;


@@ -24,11 +24,6 @@ namespace ocr {
  */
 class FASTDEPLOY_DECL DBDetectorPreprocessor {
  public:
-  /** \brief Create a preprocessor instance for DBDetector serials model
-   *
-   */
-  DBDetectorPreprocessor();
-
   /** \brief Process the input image and prepare input tensors for runtime
    *
    * \param[in] images The input image data list, all the elements are returned by cv::imread()
@@ -44,9 +39,6 @@ class FASTDEPLOY_DECL DBDetectorPreprocessor {
   std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
   std::vector<float> scale_ = {0.229f, 0.224f, 0.225f};
   bool is_scale_ = true;
-
- private:
-  bool initialized_ = false;
 };
} // namespace ocr } // namespace ocr


@@ -20,14 +20,8 @@ void BindPPOCRModel(pybind11::module& m) {
         vision::ocr::SortBoxes(&boxes);
         return boxes;
       });
 
   // DBDetector
-  pybind11::class_<vision::ocr::DBDetector, FastDeployModel>(m, "DBDetector")
-      .def(pybind11::init<std::string, std::string, RuntimeOption,
-                          ModelFormat>())
-      .def(pybind11::init<>())
-      .def_readwrite("preprocessor", &vision::ocr::DBDetector::preprocessor_)
-      .def_readwrite("postprocessor", &vision::ocr::DBDetector::postprocessor_);
-
   pybind11::class_<vision::ocr::DBDetectorPreprocessor>(m, "DBDetectorPreprocessor")
       .def(pybind11::init<>())
       .def_readwrite("max_side_len", &vision::ocr::DBDetectorPreprocessor::max_side_len_)
@@ -45,7 +39,7 @@ void BindPPOCRModel(pybind11::module& m) {
         for (size_t i = 0; i < outputs.size(); ++i) {
           outputs[i].StopSharing();
         }
-        return make_pair(outputs, batch_det_img_info);
+        return std::make_pair(outputs, batch_det_img_info);
       });
 
   pybind11::class_<vision::ocr::DBDetectorPostprocessor>(m, "DBDetectorPostprocessor")
@@ -77,15 +71,31 @@ void BindPPOCRModel(pybind11::module& m) {
         return results;
       });
 
-  // Classifier
-  pybind11::class_<vision::ocr::Classifier, FastDeployModel>(m, "Classifier")
+  pybind11::class_<vision::ocr::DBDetector, FastDeployModel>(m, "DBDetector")
       .def(pybind11::init<std::string, std::string, RuntimeOption,
                           ModelFormat>())
       .def(pybind11::init<>())
-      .def_readwrite("preprocessor", &vision::ocr::Classifier::preprocessor_)
-      .def_readwrite("postprocessor", &vision::ocr::Classifier::postprocessor_);
+      .def_readwrite("preprocessor", &vision::ocr::DBDetector::preprocessor_)
+      .def_readwrite("postprocessor", &vision::ocr::DBDetector::postprocessor_)
+      .def("predict", [](vision::ocr::DBDetector& self,
+                         pybind11::array& data) {
+        auto mat = PyArrayToCvMat(data);
+        std::vector<std::array<int, 8>> boxes_result;
+        self.Predict(mat, &boxes_result);
+        return boxes_result;
+      })
+      .def("batch_predict", [](vision::ocr::DBDetector& self, std::vector<pybind11::array>& data) {
+        std::vector<cv::Mat> images;
+        std::vector<std::vector<std::array<int, 8>>> det_results;
+        for (size_t i = 0; i < data.size(); ++i) {
+          images.push_back(PyArrayToCvMat(data[i]));
+        }
+        self.BatchPredict(images, &det_results);
+        return det_results;
+      });
 
+  // Classifier
   pybind11::class_<vision::ocr::ClassifierPreprocessor>(m, "ClassifierPreprocessor")
       .def(pybind11::init<>())
       .def_readwrite("cls_image_shape", &vision::ocr::ClassifierPreprocessor::cls_image_shape_)
       .def_readwrite("mean", &vision::ocr::ClassifierPreprocessor::mean_)
@@ -116,7 +126,7 @@ void BindPPOCRModel(pybind11::module& m) {
         if (!self.Run(inputs, &cls_labels, &cls_scores)) {
          throw std::runtime_error("Failed to preprocess the input data in ClassifierPostprocessor.");
        }
-        return make_pair(cls_labels, cls_scores);
+        return std::make_pair(cls_labels, cls_scores);
      })
      .def("run", [](vision::ocr::ClassifierPostprocessor& self,
                     std::vector<pybind11::array>& input_array) {
@@ -127,39 +137,56 @@ void BindPPOCRModel(pybind11::module& m) {
        if (!self.Run(inputs, &cls_labels, &cls_scores)) {
          throw std::runtime_error("Failed to preprocess the input data in ClassifierPostprocessor.");
        }
-        return make_pair(cls_labels, cls_scores);
+        return std::make_pair(cls_labels, cls_scores);
      });
 
-  // Recognizer
-  pybind11::class_<vision::ocr::Recognizer, FastDeployModel>(m, "Recognizer")
-      .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
+  pybind11::class_<vision::ocr::Classifier, FastDeployModel>(m, "Classifier")
+      .def(pybind11::init<std::string, std::string, RuntimeOption,
                           ModelFormat>())
       .def(pybind11::init<>())
-      .def_readwrite("preprocessor", &vision::ocr::Recognizer::preprocessor_)
-      .def_readwrite("postprocessor", &vision::ocr::Recognizer::postprocessor_);
-
-  pybind11::class_<vision::ocr::RecognizerPreprocessor>(m, "RecognizerPreprocessor")
-      .def(pybind11::init<>())
-      .def_readwrite("rec_image_shape", &vision::ocr::RecognizerPreprocessor::rec_image_shape_)
-      .def_readwrite("mean", &vision::ocr::RecognizerPreprocessor::mean_)
-      .def_readwrite("scale", &vision::ocr::RecognizerPreprocessor::scale_)
-      .def_readwrite("is_scale", &vision::ocr::RecognizerPreprocessor::is_scale_)
-      .def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector<pybind11::array>& im_list) {
-        std::vector<vision::FDMat> images;
-        for (size_t i = 0; i < im_list.size(); ++i) {
-          images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
+      .def_readwrite("preprocessor", &vision::ocr::Classifier::preprocessor_)
+      .def_readwrite("postprocessor", &vision::ocr::Classifier::postprocessor_)
+      .def("predict", [](vision::ocr::Classifier& self,
+                         pybind11::array& data) {
+        auto mat = PyArrayToCvMat(data);
+        int32_t cls_label;
+        float cls_score;
+        self.Predict(mat, &cls_label, &cls_score);
+        return std::make_pair(cls_label, cls_score);
+      })
+      .def("batch_predict", [](vision::ocr::Classifier& self, std::vector<pybind11::array>& data) {
+        std::vector<cv::Mat> images;
+        std::vector<int32_t> cls_labels;
+        std::vector<float> cls_scores;
+        for (size_t i = 0; i < data.size(); ++i) {
+          images.push_back(PyArrayToCvMat(data[i]));
         }
-        std::vector<FDTensor> outputs;
-        if (!self.Run(&images, &outputs)) {
-          throw std::runtime_error("Failed to preprocess the input data in RecognizerPreprocessor.");
-        }
-        for (size_t i = 0; i < outputs.size(); ++i) {
-          outputs[i].StopSharing();
-        }
-        return outputs;
+        self.BatchPredict(images, &cls_labels, &cls_scores);
+        return std::make_pair(cls_labels, cls_scores);
       });
 
+  // Recognizer
+  pybind11::class_<vision::ocr::RecognizerPreprocessor>(m, "RecognizerPreprocessor")
+      .def(pybind11::init<>())
+      .def_readwrite("rec_image_shape", &vision::ocr::RecognizerPreprocessor::rec_image_shape_)
+      .def_readwrite("mean", &vision::ocr::RecognizerPreprocessor::mean_)
+      .def_readwrite("scale", &vision::ocr::RecognizerPreprocessor::scale_)
+      .def_readwrite("is_scale", &vision::ocr::RecognizerPreprocessor::is_scale_)
+      .def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector<pybind11::array>& im_list) {
+        std::vector<vision::FDMat> images;
+        for (size_t i = 0; i < im_list.size(); ++i) {
+          images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
+        }
+        std::vector<FDTensor> outputs;
+        if (!self.Run(&images, &outputs)) {
+          throw std::runtime_error("Failed to preprocess the input data in RecognizerPreprocessor.");
+        }
+        for (size_t i = 0; i < outputs.size(); ++i) {
+          outputs[i].StopSharing();
+        }
+        return outputs;
+      });
+
   pybind11::class_<vision::ocr::RecognizerPostprocessor>(m, "RecognizerPostprocessor")
       .def(pybind11::init<std::string>())
       .def("run", [](vision::ocr::RecognizerPostprocessor& self,
@@ -169,7 +196,7 @@ void BindPPOCRModel(pybind11::module& m) {
        if (!self.Run(inputs, &texts, &rec_scores)) {
          throw std::runtime_error("Failed to preprocess the input data in RecognizerPostprocessor.");
        }
-        return make_pair(texts, rec_scores);
+        return std::make_pair(texts, rec_scores);
      })
      .def("run", [](vision::ocr::RecognizerPostprocessor& self,
                     std::vector<pybind11::array>& input_array) {
@@ -180,7 +207,32 @@ void BindPPOCRModel(pybind11::module& m) {
        if (!self.Run(inputs, &texts, &rec_scores)) {
          throw std::runtime_error("Failed to preprocess the input data in RecognizerPostprocessor.");
        }
-        return make_pair(texts, rec_scores);
+        return std::make_pair(texts, rec_scores);
+      });
+
+  pybind11::class_<vision::ocr::Recognizer, FastDeployModel>(m, "Recognizer")
+      .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
+                          ModelFormat>())
+      .def(pybind11::init<>())
+      .def_readwrite("preprocessor", &vision::ocr::Recognizer::preprocessor_)
+      .def_readwrite("postprocessor", &vision::ocr::Recognizer::postprocessor_)
+      .def("predict", [](vision::ocr::Recognizer& self,
+                         pybind11::array& data) {
+        auto mat = PyArrayToCvMat(data);
+        std::string text;
+        float rec_score;
+        self.Predict(mat, &text, &rec_score);
+        return std::make_pair(text, rec_score);
+      })
+      .def("batch_predict", [](vision::ocr::Recognizer& self, std::vector<pybind11::array>& data) {
+        std::vector<cv::Mat> images;
+        std::vector<std::string> texts;
+        std::vector<float> rec_scores;
+        for (size_t i = 0; i < data.size(); ++i) {
+          images.push_back(PyArrayToCvMat(data[i]));
+        }
+        self.BatchPredict(images, &texts, &rec_scores);
+        return std::make_pair(texts, rec_scores);
       });
 }
 }  // namespace fastdeploy
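The predict/batch_predict lambdas registered above make the three single models directly callable from Python; a hedged usage sketch (model and dictionary paths are placeholders, not part of this diff):

    import cv2
    import fastdeploy as fd

    det = fd.vision.ocr.DBDetector("det_infer.pdmodel", "det_infer.pdiparams")
    rec = fd.vision.ocr.Recognizer("rec_infer.pdmodel", "rec_infer.pdiparams", "ppocr_keys_v1.txt")
    im = cv2.imread("12.jpg")
    boxes = det.predict(im)                 # one 8-int quad per detected text region
    batch_boxes = det.batch_predict([im, im])
    crop = im                               # in practice a cropped text-line image taken from boxes
    text, rec_score = rec.predict(crop)     # single-crop recognition, returns (text, score)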


@@ -25,6 +25,8 @@ void BindPPOCRv3(pybind11::module& m) {
                           fastdeploy::vision::ocr::Recognizer*>())
       .def(pybind11::init<fastdeploy::vision::ocr::DBDetector*,
                           fastdeploy::vision::ocr::Recognizer*>())
+      .def_property("cls_batch_size", &pipeline::PPOCRv3::GetClsBatchSize, &pipeline::PPOCRv3::SetClsBatchSize)
+      .def_property("rec_batch_size", &pipeline::PPOCRv3::GetRecBatchSize, &pipeline::PPOCRv3::SetRecBatchSize)
       .def("predict", [](pipeline::PPOCRv3& self,
                          pybind11::array& data) {
         auto mat = PyArrayToCvMat(data);
@@ -52,6 +54,8 @@ void BindPPOCRv2(pybind11::module& m) {
                           fastdeploy::vision::ocr::Recognizer*>())
       .def(pybind11::init<fastdeploy::vision::ocr::DBDetector*,
                           fastdeploy::vision::ocr::Recognizer*>())
+      .def_property("cls_batch_size", &pipeline::PPOCRv2::GetClsBatchSize, &pipeline::PPOCRv2::SetClsBatchSize)
+      .def_property("rec_batch_size", &pipeline::PPOCRv2::GetRecBatchSize, &pipeline::PPOCRv2::SetRecBatchSize)
       .def("predict", [](pipeline::PPOCRv2& self,
                          pybind11::array& data) {
         auto mat = PyArrayToCvMat(data);
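These def_property bindings expose the new pipeline batch sizes to Python; a hedged tuning sketch (the pipeline construction from det/cls/rec models follows the existing PP-OCRv3 examples and is only assumed here):

    # ppocr_v3 = fd.vision.ocr.PPOCRv3(det_model=det, cls_model=cls, rec_model=rec)  # assumed setup
    ppocr_v3.cls_batch_size = 1   # default 1; the setter accepts -1 or any positive value
    ppocr_v3.rec_batch_size = 6   # default 6; 0 and values below -1 are rejected
    result = ppocr_v3.predict(im)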


@@ -33,6 +33,32 @@ PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model,
   recognizer_->preprocessor_.rec_image_shape_[1] = 32;
 }
 
+bool PPOCRv2::SetClsBatchSize(int cls_batch_size) {
+  if (cls_batch_size < -1 || cls_batch_size == 0) {
+    FDERROR << "batch_size > 0 or batch_size == -1." << std::endl;
+    return false;
+  }
+  cls_batch_size_ = cls_batch_size;
+  return true;
+}
+
+int PPOCRv2::GetClsBatchSize() {
+  return cls_batch_size_;
+}
+
+bool PPOCRv2::SetRecBatchSize(int rec_batch_size) {
+  if (rec_batch_size < -1 || rec_batch_size == 0) {
+    FDERROR << "batch_size > 0 or batch_size == -1." << std::endl;
+    return false;
+  }
+  rec_batch_size_ = rec_batch_size;
+  return true;
+}
+
+int PPOCRv2::GetRecBatchSize() {
+  return rec_batch_size_;
+}
+
 bool PPOCRv2::Initialized() const {
 
   if (detector_ != nullptr && !detector_->Initialized()) {
@@ -52,7 +78,10 @@ bool PPOCRv2::Initialized() const {
 bool PPOCRv2::Predict(cv::Mat* img,
                       fastdeploy::vision::OCRResult* result) {
   std::vector<fastdeploy::vision::OCRResult> batch_result(1);
-  BatchPredict({*img}, &batch_result);
+  bool success = BatchPredict({*img}, &batch_result);
+  if (!success) {
+    return success;
+  }
   *result = std::move(batch_result[0]);
   return true;
 };
@@ -67,12 +96,12 @@ bool PPOCRv2::BatchPredict(const std::vector<cv::Mat>& images,
     FDERROR << "There's error while detecting image in PPOCR." << std::endl;
     return false;
   }
   for (int i_batch = 0; i_batch < batch_boxes.size(); ++i_batch) {
     vision::ocr::SortBoxes(&(batch_boxes[i_batch]));
     (*batch_result)[i_batch].boxes = batch_boxes[i_batch];
   }
   for (int i_batch = 0; i_batch < images.size(); ++i_batch) {
     fastdeploy::vision::OCRResult& ocr_result = (*batch_result)[i_batch];
     // Get croped images by detection result
@@ -93,22 +122,34 @@ bool PPOCRv2::BatchPredict(const std::vector<cv::Mat>& images,
     std::vector<std::string>* text_ptr = &ocr_result.text;
     std::vector<float>* rec_scores_ptr = &ocr_result.rec_scores;
 
-    if (nullptr != classifier_){
-      if (!classifier_->BatchPredict(image_list, cls_labels_ptr, cls_scores_ptr)) {
-        FDERROR << "There's error while recognizing image in PPOCR." << std::endl;
-        return false;
-      }else{
-        for (size_t i_img = 0; i_img < image_list.size(); ++i_img) {
-          if (cls_labels_ptr->at(i_img) % 2 == 1 && cls_scores_ptr->at(i_img) > classifier_->postprocessor_.cls_thresh_) {
-            cv::rotate(image_list[i_img], image_list[i_img], 1);
+    if (nullptr != classifier_) {
+      for (size_t start_index = 0; start_index < image_list.size(); start_index += cls_batch_size_) {
+        size_t end_index = std::min(start_index + cls_batch_size_, image_list.size());
+        if (!classifier_->BatchPredict(image_list, cls_labels_ptr, cls_scores_ptr, start_index, end_index)) {
+          FDERROR << "There's error while recognizing image in PPOCR." << std::endl;
+          return false;
+        } else {
+          for (size_t i_img = start_index; i_img < end_index; ++i_img) {
+            if (cls_labels_ptr->at(i_img) % 2 == 1 && cls_scores_ptr->at(i_img) > classifier_->postprocessor_.cls_thresh_) {
+              cv::rotate(image_list[i_img], image_list[i_img], 1);
+            }
           }
         }
       }
     }
-    if (!recognizer_->BatchPredict(image_list, text_ptr, rec_scores_ptr)) {
-      FDERROR << "There's error while recognizing image in PPOCR." << std::endl;
-      return false;
+
+    std::vector<float> width_list;
+    for (int i = 0; i < image_list.size(); i++) {
+      width_list.push_back(float(image_list[i].cols) / image_list[i].rows);
+    }
+    std::vector<int> indices = vision::ocr::ArgSort(width_list);
+
+    for (size_t start_index = 0; start_index < image_list.size(); start_index += rec_batch_size_) {
+      size_t end_index = std::min(start_index + rec_batch_size_, image_list.size());
+      if (!recognizer_->BatchPredict(image_list, text_ptr, rec_scores_ptr, start_index, end_index, indices)) {
+        FDERROR << "There's error while recognizing image in PPOCR." << std::endl;
+        return false;
+      }
     }
   }
   return true;
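The reworked loop above classifies and recognizes the detected crops in cls_batch_size_/rec_batch_size_ sized chunks, and feeds the recognizer in ascending aspect-ratio order so each chunk pads to a similar width. The control flow is roughly the following Python sketch (illustrative only, not the FastDeploy implementation):

    def recognize_in_batches(image_list, rec_batch_size, batch_predict):
        # Order crop indices by width/height so each chunk needs similar padding.
        ratios = [img.shape[1] / img.shape[0] for img in image_list]
        indices = sorted(range(len(image_list)), key=lambda i: ratios[i])

        texts = [""] * len(image_list)
        scores = [0.0] * len(image_list)
        for start in range(0, len(image_list), rec_batch_size):
            end = min(start + rec_batch_size, len(image_list))
            # Mirrors Recognizer::BatchPredict(images, &texts, &scores, start, end, indices):
            # results land back at their original positions via the indices mapping.
            batch_predict(image_list, texts, scores, start, end, indices)
        return texts, scores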


@@ -68,11 +68,19 @@ class FASTDEPLOY_DECL PPOCRv2 : public FastDeployModel {
   virtual bool BatchPredict(const std::vector<cv::Mat>& images,
                             std::vector<fastdeploy::vision::OCRResult>* batch_result);
   bool Initialized() const override;
+  bool SetClsBatchSize(int cls_batch_size);
+  int GetClsBatchSize();
+  bool SetRecBatchSize(int rec_batch_size);
+  int GetRecBatchSize();
 
  protected:
   fastdeploy::vision::ocr::DBDetector* detector_ = nullptr;
   fastdeploy::vision::ocr::Classifier* classifier_ = nullptr;
   fastdeploy::vision::ocr::Recognizer* recognizer_ = nullptr;
+
+ private:
+  int cls_batch_size_ = 1;
+  int rec_batch_size_ = 6;
   /// Launch the detection process in OCR.
 };


@@ -34,7 +34,7 @@ std::vector<std::string> ReadDict(const std::string& path) {
 }
 
 RecognizerPostprocessor::RecognizerPostprocessor() {
-  initialized_ = true;
+  initialized_ = false;
 }
 
 RecognizerPostprocessor::RecognizerPostprocessor(const std::string& label_path) {
@@ -84,24 +84,53 @@ bool RecognizerPostprocessor::SingleBatchPostprocessor(const float* out_data,
 
 bool RecognizerPostprocessor::Run(const std::vector<FDTensor>& tensors,
                                   std::vector<std::string>* texts, std::vector<float>* rec_scores) {
+  // Recognizer have only 1 output tensor.
+  // For Recognizer, the output tensor shape = [batch, ?, 6625]
+  size_t total_size = tensors[0].shape[0];
+  return Run(tensors, texts, rec_scores, 0, total_size, {});
+}
+
+bool RecognizerPostprocessor::Run(const std::vector<FDTensor>& tensors,
+                                  std::vector<std::string>* texts, std::vector<float>* rec_scores,
+                                  size_t start_index, size_t total_size, const std::vector<int>& indices) {
   if (!initialized_) {
     FDERROR << "Postprocessor is not initialized." << std::endl;
     return false;
   }
   // Recognizer have only 1 output tensor.
   const FDTensor& tensor = tensors[0];
   // For Recognizer, the output tensor shape = [batch, ?, 6625]
   size_t batch = tensor.shape[0];
   size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies<int>());
 
-  texts->resize(batch);
-  rec_scores->resize(batch);
+  if (batch <= 0) {
+    FDERROR << "The infer outputTensor.shape[0] <=0, wrong infer result." << std::endl;
+    return false;
+  }
+  if (start_index < 0 || total_size <= 0) {
+    FDERROR << "start_index or total_size error. Correct is: 0 <= start_index < total_size" << std::endl;
+    return false;
+  }
+  if ((start_index + batch) > total_size) {
+    FDERROR << "start_index or total_size error. Correct is: start_index + batch(outputTensor.shape[0]) <= total_size" << std::endl;
+    return false;
+  }
+  texts->resize(total_size);
+  rec_scores->resize(total_size);
 
   const float* tensor_data = reinterpret_cast<const float*>(tensor.Data());
   for (int i_batch = 0; i_batch < batch; ++i_batch) {
-    if (!SingleBatchPostprocessor(tensor_data, tensor.shape, &texts->at(i_batch), &rec_scores->at(i_batch))) {
+    size_t real_index = i_batch + start_index;
+    if (indices.size() != 0) {
+      real_index = indices[i_batch + start_index];
+    }
+    if (!SingleBatchPostprocessor(tensor_data + i_batch * length,
+                                  tensor.shape,
+                                  &texts->at(real_index),
+                                  &rec_scores->at(real_index))) {
       return false;
     }
-    tensor_data = tensor_data + length;
   }
 
   return true;


@@ -32,7 +32,7 @@ class FASTDEPLOY_DECL RecognizerPostprocessor {
   */
   explicit RecognizerPostprocessor(const std::string& label_path);
 
-  /** \brief Process the result of runtime and fill to ClassifyResult structure
+  /** \brief Process the result of runtime and fill to RecognizerResult
    *
    * \param[in] tensors The inference result from runtime
    * \param[in] texts The output result of recognizer
@@ -42,6 +42,11 @@ class FASTDEPLOY_DECL RecognizerPostprocessor {
   bool Run(const std::vector<FDTensor>& tensors,
            std::vector<std::string>* texts, std::vector<float>* rec_scores);
 
+  bool Run(const std::vector<FDTensor>& tensors,
+           std::vector<std::string>* texts, std::vector<float>* rec_scores,
+           size_t start_index, size_t total_size,
+           const std::vector<int>& indices);
+
  private:
   bool SingleBatchPostprocessor(const float* out_data,
                                 const std::vector<int64_t>& output_shape,


@@ -21,69 +21,75 @@ namespace fastdeploy {
 namespace vision {
 namespace ocr {
 
-RecognizerPreprocessor::RecognizerPreprocessor() {
-  initialized_ = true;
-}
-
 void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio,
                               const std::vector<int>& rec_image_shape) {
-  int imgC, imgH, imgW;
-  imgC = rec_image_shape[0];
-  imgH = rec_image_shape[1];
-  imgW = rec_image_shape[2];
+  int img_c, img_h, img_w;
+  img_c = rec_image_shape[0];
+  img_h = rec_image_shape[1];
+  img_w = rec_image_shape[2];
 
-  imgW = int(imgH * max_wh_ratio);
+  img_w = int(img_h * max_wh_ratio);
 
   float ratio = float(mat->Width()) / float(mat->Height());
   int resize_w;
-  if (ceilf(imgH * ratio) > imgW) {
-    resize_w = imgW;
+  if (ceilf(img_h * ratio) > img_w) {
+    resize_w = img_w;
   }else{
-    resize_w = int(ceilf(imgH * ratio));
+    resize_w = int(ceilf(img_h * ratio));
   }
-  Resize::Run(mat, resize_w, imgH);
+  Resize::Run(mat, resize_w, img_h);
 
   std::vector<float> value = {0, 0, 0};
-  Pad::Run(mat, 0, 0, 0, int(imgW - mat->Width()), value);
+  Pad::Run(mat, 0, 0, 0, int(img_w - mat->Width()), value);
 }
 
 bool RecognizerPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
-  if (!initialized_) {
-    FDERROR << "The preprocessor is not initialized." << std::endl;
-    return false;
-  }
-  if (images->size() == 0) {
-    FDERROR << "The size of input images should be greater than 0." << std::endl;
+  return Run(images, outputs, 0, images->size(), {});
+}
+
+bool RecognizerPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
+                                 size_t start_index, size_t end_index, const std::vector<int>& indices) {
+  if (images->size() == 0 || end_index <= start_index || end_index > images->size()) {
+    FDERROR << "images->size() or index error. Correct is: 0 <= start_index < end_index <= images->size()" << std::endl;
     return false;
   }
 
-  int imgH = rec_image_shape_[1];
-  int imgW = rec_image_shape_[2];
-  float max_wh_ratio = imgW * 1.0 / imgH;
+  int img_h = rec_image_shape_[1];
+  int img_w = rec_image_shape_[2];
+  float max_wh_ratio = img_w * 1.0 / img_h;
   float ori_wh_ratio;
 
-  for (size_t i = 0; i < images->size(); ++i) {
-    FDMat* mat = &(images->at(i));
+  for (size_t i = start_index; i < end_index; ++i) {
+    size_t real_index = i;
+    if (indices.size() != 0) {
+      real_index = indices[i];
+    }
+    FDMat* mat = &(images->at(real_index));
     ori_wh_ratio = mat->Width() * 1.0 / mat->Height();
     max_wh_ratio = std::max(max_wh_ratio, ori_wh_ratio);
   }
 
-  for (size_t i = 0; i < images->size(); ++i) {
-    FDMat* mat = &(images->at(i));
+  for (size_t i = start_index; i < end_index; ++i) {
+    size_t real_index = i;
+    if (indices.size() != 0) {
+      real_index = indices[i];
+    }
+    FDMat* mat = &(images->at(real_index));
     OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_);
     NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
+    /*
+    Normalize::Run(mat, mean_, scale_, is_scale_);
+    HWC2CHW::Run(mat);
+    Cast::Run(mat, "float");
+    */
   }
   // Only have 1 output Tensor.
   outputs->resize(1);
+  size_t tensor_size = end_index - start_index;
   // Concat all the preprocessed data to a batch tensor
-  std::vector<FDTensor> tensors(images->size());
-  for (size_t i = 0; i < images->size(); ++i) {
-    (*images)[i].ShareWithTensor(&(tensors[i]));
+  std::vector<FDTensor> tensors(tensor_size);
+  for (size_t i = 0; i < tensor_size; ++i) {
+    size_t real_index = i + start_index;
+    if (indices.size() != 0) {
+      real_index = indices[i + start_index];
+    }
+    (*images)[real_index].ShareWithTensor(&(tensors[i]));
     tensors[i].ExpandDim(0);
   }
   if (tensors.size() == 1) {


@@ -24,12 +24,6 @@ namespace ocr {
  */
 class FASTDEPLOY_DECL RecognizerPreprocessor {
  public:
-  /** \brief Create a preprocessor instance for PaddleClas serials model
-   *
-   * \param[in] config_file Path of configuration file for deployment, e.g resnet/infer_cfg.yml
-   */
-  RecognizerPreprocessor();
-
   /** \brief Process the input image and prepare input tensors for runtime
    *
    * \param[in] images The input image data list, all the elements are returned by cv::imread()
@@ -37,14 +31,14 @@ class FASTDEPLOY_DECL RecognizerPreprocessor {
    * \return true if the preprocess successed, otherwise false
    */
   bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
+  bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
+           size_t start_index, size_t end_index,
+           const std::vector<int>& indices);
 
   std::vector<int> rec_image_shape_ = {3, 48, 320};
   std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
   std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
   bool is_scale_ = true;
-
- private:
-  bool initialized_ = false;
 };
 
 }  // namespace ocr


@@ -53,10 +53,33 @@ bool Recognizer::Initialize() {
   return true;
 }
 
+bool Recognizer::Predict(cv::Mat& img, std::string* text, float* rec_score) {
+  std::vector<std::string> texts(1);
+  std::vector<float> rec_scores(1);
+  bool success = BatchPredict({img}, &texts, &rec_scores);
+  if (!success) {
+    return success;
+  }
+  *text = std::move(texts[0]);
+  *rec_score = rec_scores[0];
+  return true;
+}
+
 bool Recognizer::BatchPredict(const std::vector<cv::Mat>& images,
                               std::vector<std::string>* texts, std::vector<float>* rec_scores) {
+  return BatchPredict(images, texts, rec_scores, 0, images.size(), {});
+}
+
+bool Recognizer::BatchPredict(const std::vector<cv::Mat>& images,
+                              std::vector<std::string>* texts, std::vector<float>* rec_scores,
+                              size_t start_index, size_t end_index, const std::vector<int>& indices) {
+  size_t total_size = images.size();
+  if (indices.size() != 0 && indices.size() != total_size) {
+    FDERROR << "indices.size() should be 0 or images.size()." << std::endl;
+    return false;
+  }
   std::vector<FDMat> fd_images = WrapMat(images);
-  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
+  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, start_index, end_index, indices)) {
     FDERROR << "Failed to preprocess the input image." << std::endl;
     return false;
   }
@@ -66,7 +89,7 @@ bool Recognizer::BatchPredict(const std::vector<cv::Mat>& images,
     return false;
   }
 
-  if (!postprocessor_.Run(reused_output_tensors_, texts, rec_scores)) {
+  if (!postprocessor_.Run(reused_output_tensors_, texts, rec_scores, start_index, total_size, indices)) {
     FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
     return false;
   }


@@ -45,6 +45,7 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
                      const ModelFormat& model_format = ModelFormat::PADDLE);
   /// Get model's name
   std::string ModelName() const { return "ppocr/ocr_rec"; }
+  virtual bool Predict(cv::Mat& img, std::string* text, float* rec_score);
   /** \brief BatchPredict the input image and get OCR recognition model result.
    *
    * \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
@@ -53,6 +54,10 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
    */
   virtual bool BatchPredict(const std::vector<cv::Mat>& images,
                             std::vector<std::string>* texts, std::vector<float>* rec_scores);
+  virtual bool BatchPredict(const std::vector<cv::Mat>& images,
+                            std::vector<std::string>* texts, std::vector<float>* rec_scores,
+                            size_t start_index, size_t end_index,
+                            const std::vector<int>& indices);
 
   RecognizerPreprocessor preprocessor_;
   RecognizerPostprocessor postprocessor_;

fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h (2 changes, Normal file → Executable file)

@@ -33,6 +33,8 @@ FASTDEPLOY_DECL cv::Mat GetRotateCropImage(const cv::Mat& srcimage,
 
 FASTDEPLOY_DECL void SortBoxes(std::vector<std::array<int, 8>>* boxes);
 
+FASTDEPLOY_DECL std::vector<int> ArgSort(const std::vector<float> &array);
+
 }  // namespace ocr
 }  // namespace vision
 }  // namespace fastdeploy

fastdeploy/vision/ocr/ppocr/utils/sorted_boxes.cc (13 changes, Normal file → Executable file)

@@ -44,6 +44,19 @@ void SortBoxes(std::vector<std::array<int, 8>>* boxes) {
   }
 }
 
+std::vector<int> ArgSort(const std::vector<float> &array) {
+  const int array_len(array.size());
+  std::vector<int> array_index(array_len, 0);
+  for (int i = 0; i < array_len; ++i)
+    array_index[i] = i;
+
+  std::sort(
+      array_index.begin(), array_index.end(),
+      [&array](int pos1, int pos2) { return (array[pos1] < array[pos2]); });
+
+  return array_index;
+}
+
 }  // namespace ocr
 }  // namespace vision
 }  // namespace fastdeploy
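ArgSort returns the indices that would sort the input in ascending order, i.e. the same contract as numpy.argsort; for example:

    import numpy as np
    assert list(np.argsort([0.9, 0.3, 1.7])) == [1, 0, 2]  # ArgSort({0.9f, 0.3f, 1.7f}) gives the same order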


@@ -18,6 +18,134 @@ from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C from .... import c_lib_wrap as C
def sort_boxes(boxes):
return C.vision.ocr.sort_boxes(boxes)
class DBDetectorPreprocessor:
def __init__(self):
"""Create a preprocessor for DBDetectorModel
"""
self._preprocessor = C.vision.ocr.DBDetectorPreprocessor()
def run(self, input_ims):
"""Preprocess input images for DBDetectorModel
:param: input_ims: (list of numpy.ndarray) The input image
:return: pair(list of FDTensor, list of std::array<int, 4>)
"""
return self._preprocessor.run(input_ims)
@property
def max_side_len(self):
return self._preprocessor.max_side_len
@max_side_len.setter
def max_side_len(self, value):
assert isinstance(
value, int), "The value to set `max_side_len` must be type of int."
self._preprocessor.max_side_len = value
@property
def is_scale(self):
return self._preprocessor.is_scale
@is_scale.setter
def is_scale(self, value):
assert isinstance(
value, bool), "The value to set `is_scale` must be type of bool."
self._preprocessor.is_scale = value
@property
def scale(self):
return self._preprocessor.scale
@scale.setter
def scale(self, value):
assert isinstance(
value, list), "The value to set `scale` must be type of list."
self._preprocessor.scale = value
@property
def mean(self):
return self._preprocessor.mean
@mean.setter
def mean(self, value):
assert isinstance(
value, list), "The value to set `mean` must be type of list."
self._preprocessor.mean = value
class DBDetectorPostprocessor:
def __init__(self):
"""Create a postprocessor for DBDetectorModel
"""
self._postprocessor = C.vision.ocr.DBDetectorPostprocessor()
def run(self, runtime_results, batch_det_img_info):
"""Postprocess the runtime results for DBDetectorModel
:param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime
:param: batch_det_img_info: (list of std::array<int, 4>)The output of det_preprocessor
:return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results, batch_det_img_info)
@property
def det_db_thresh(self):
return self._postprocessor.det_db_thresh
@det_db_thresh.setter
def det_db_thresh(self, value):
assert isinstance(
value,
float), "The value to set `det_db_thresh` must be type of float."
self._postprocessor.det_db_thresh = value
@property
def det_db_box_thresh(self):
return self._postprocessor.det_db_box_thresh
@det_db_box_thresh.setter
def det_db_box_thresh(self, value):
assert isinstance(
value, float
), "The value to set `det_db_box_thresh` must be type of float."
self._postprocessor.det_db_box_thresh = value
@property
def det_db_unclip_ratio(self):
return self._postprocessor.det_db_unclip_ratio
@det_db_unclip_ratio.setter
def det_db_unclip_ratio(self, value):
assert isinstance(
value, float
), "The value to set `det_db_unclip_ratio` must be type of float."
self._postprocessor.det_db_unclip_ratio = value
@property
def det_db_score_mode(self):
return self._postprocessor.det_db_score_mode
@det_db_score_mode.setter
def det_db_score_mode(self, value):
assert isinstance(
value,
str), "The value to set `det_db_score_mode` must be type of str."
self._postprocessor.det_db_score_mode = value
@property
def use_dilation(self):
return self._postprocessor.use_dilation
@use_dilation.setter
def use_dilation(self, value):
assert isinstance(
value,
bool), "The value to set `use_dilation` must be type of bool."
self._postprocessor.use_dilation = value
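A hedged sketch of how the decoupled detector preprocessor/postprocessor pair is meant to wrap a runtime; the model paths and the "x" input name are assumptions borrowed from tests/models/test_ppocrv3.py below.

import cv2
import fastdeploy as fd

# Assumed local export of ch_PP-OCRv3_det_infer.
option = fd.RuntimeOption()
option.set_model_path("ch_PP-OCRv3_det_infer/inference.pdmodel",
                      "ch_PP-OCRv3_det_infer/inference.pdiparams")
det_runtime = fd.Runtime(option)

det_preprocessor = fd.vision.ocr.DBDetectorPreprocessor()
det_postprocessor = fd.vision.ocr.DBDetectorPostprocessor()
det_postprocessor.det_db_thresh = 0.3  # tune via the new properties

images = [cv2.imread("12.jpg")]
input_tensors, batch_det_img_info = det_preprocessor.run(images)
output_tensors = det_runtime.infer({"x": input_tensors[0]})
batch_boxes = det_postprocessor.run(output_tensors, batch_det_img_info)
print(batch_boxes[0])  # 8-value boxes detected in the first image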
class DBDetector(FastDeployModel):
def __init__(self,
model_file="",
@@ -35,88 +163,223 @@ class DBDetector(FastDeployModel):
if (len(model_file) == 0):
self._model = C.vision.ocr.DBDetector()
self._runnable = False
else:
self._model = C.vision.ocr.DBDetector(
model_file, params_file, self._runtime_option, model_format)
assert self.initialized, "DBDetector initialize failed."
self._runnable = True
def predict(self, input_image):
"""Predict an input image
:param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
:return: boxes
"""
if self._runnable:
return self._model.predict(input_image)
return False
def batch_predict(self, images):
"""Predict a batch of input image
:param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
:return: batch_boxes
"""
if self._runnable:
return self._model.batch_predict(images)
return False
@property
def preprocessor(self):
return self._model.preprocessor
@preprocessor.setter
def preprocessor(self, value):
self._model.preprocessor = value
@property
def postprocessor(self):
return self._model.postprocessor
@postprocessor.setter
def postprocessor(self, value):
self._model.postprocessor = value
# Det Preprocessor Property
@property
def max_side_len(self):
return self._model.preprocessor.max_side_len
@max_side_len.setter
def max_side_len(self, value):
assert isinstance(
value, int), "The value to set `max_side_len` must be type of int."
self._model.preprocessor.max_side_len = value
@property
def is_scale(self):
return self._model.preprocessor.is_scale
@is_scale.setter
def is_scale(self, value):
assert isinstance(
value, bool), "The value to set `is_scale` must be type of bool."
self._model.preprocessor.is_scale = value
@property
def scale(self):
return self._model.preprocessor.scale
@scale.setter
def scale(self, value):
assert isinstance(
value, list), "The value to set `scale` must be type of list."
self._model.preprocessor.scale = value
@property
def mean(self):
return self._model.preprocessor.mean
@mean.setter
def mean(self, value):
assert isinstance(
value, list), "The value to set `mean` must be type of list."
self._model.preprocessor.mean = value
# Det Postprocessor Property
@property
def det_db_thresh(self):
return self._model.postprocessor.det_db_thresh
@det_db_thresh.setter
def det_db_thresh(self, value):
assert isinstance(
value,
float), "The value to set `det_db_thresh` must be type of float."
self._model.postprocessor.det_db_thresh = value
@property
def det_db_box_thresh(self):
return self._model.postprocessor.det_db_box_thresh
@det_db_box_thresh.setter
def det_db_box_thresh(self, value):
assert isinstance(
value, float
), "The value to set `det_db_box_thresh` must be type of float."
self._model.postprocessor.det_db_box_thresh = value
@property
def det_db_unclip_ratio(self):
return self._model.postprocessor.det_db_unclip_ratio
@det_db_unclip_ratio.setter
def det_db_unclip_ratio(self, value):
assert isinstance(
value, float
), "The value to set `det_db_unclip_ratio` must be type of float."
self._model.postprocessor.det_db_unclip_ratio = value
@property
def det_db_score_mode(self):
return self._model.postprocessor.det_db_score_mode
@det_db_score_mode.setter
def det_db_score_mode(self, value):
assert isinstance(
value,
str), "The value to set `det_db_score_mode` must be type of str."
self._model.postprocessor.det_db_score_mode = value
@property
def use_dilation(self):
return self._model.postprocessor.use_dilation
@use_dilation.setter
def use_dilation(self, value):
assert isinstance(
value,
bool), "The value to set `use_dilation` must be type of bool."
self._model.postprocessor.use_dilation = value
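A short, hedged sketch of the refactored DBDetector itself: the max_side_len and det_db_* properties now forward to the attached preprocessor/postprocessor, and predict/batch_predict only run when the model was built from files. The model paths below are assumptions.

import cv2
import fastdeploy as fd

det_model = fd.vision.ocr.DBDetector(
    "ch_PP-OCRv3_det_infer/inference.pdmodel",
    "ch_PP-OCRv3_det_infer/inference.pdiparams",
    runtime_option=fd.RuntimeOption())
det_model.max_side_len = 960         # forwarded to the preprocessor
det_model.det_db_unclip_ratio = 1.5  # forwarded to the postprocessor

boxes = det_model.predict(cv2.imread("12.jpg"))
batch_boxes = det_model.batch_predict([cv2.imread("12.jpg")] * 2)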
class ClassifierPreprocessor:
def __init__(self):
"""Create a preprocessor for ClassifierModel
"""
self._preprocessor = C.vision.ocr.ClassifierPreprocessor()
def run(self, input_ims):
"""Preprocess input images for ClassifierModel
:param: input_ims: (list of numpy.ndarray)The input image
:return: list of FDTensor
"""
return self._preprocessor.run(input_ims)
@property
def is_scale(self):
return self._preprocessor.is_scale
@is_scale.setter
def is_scale(self, value):
assert isinstance(
value, bool), "The value to set `is_scale` must be type of bool."
self._preprocessor.is_scale = value
@property
def scale(self):
return self._preprocessor.scale
@scale.setter
def scale(self, value):
assert isinstance(
value, list), "The value to set `scale` must be type of list."
self._preprocessor.scale = value
@property
def mean(self):
return self._preprocessor.mean
@mean.setter
def mean(self, value):
assert isinstance(
value, list), "The value to set `mean` must be type of list."
self._preprocessor.mean = value
@property
def cls_image_shape(self):
return self._preprocessor.cls_image_shape
@cls_image_shape.setter
def cls_image_shape(self, value):
assert isinstance(
value,
list), "The value to set `cls_image_shape` must be type of list."
self._preprocessor.cls_image_shape = value
class ClassifierPostprocessor:
def __init__(self):
"""Create a postprocessor for ClassifierModel
"""
self._postprocessor = C.vision.ocr.ClassifierPostprocessor()
def run(self, runtime_results):
"""Postprocess the runtime results for ClassifierModel
:param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime
:return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results)
@property
def cls_thresh(self):
return self._postprocessor.cls_thresh
@cls_thresh.setter
def cls_thresh(self, value):
assert isinstance(
value,
float), "The value to set `cls_thresh` must be type of float."
self._postprocessor.cls_thresh = value
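The classifier pair follows the same pattern. A hedged sketch with assumed paths, where crops stands in for the rotate-cropped text regions; the final loop mirrors how the test below uses cls_thresh to decide whether a crop should be flipped.

import cv2
import fastdeploy as fd

cls_option = fd.RuntimeOption()
cls_option.set_model_path("ch_ppocr_mobile_v2.0_cls_infer/inference.pdmodel",
                          "ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams")
cls_runtime = fd.Runtime(cls_option)

cls_preprocessor = fd.vision.ocr.ClassifierPreprocessor()
cls_postprocessor = fd.vision.ocr.ClassifierPostprocessor()

crops = [cv2.imread("12.jpg")]  # stand-in for cropped text regions
cls_tensors = cls_preprocessor.run(crops)
cls_outputs = cls_runtime.infer({"x": cls_tensors[0]})
cls_labels, cls_scores = cls_postprocessor.run(cls_outputs)

# Flip crops the classifier marks as upside-down with enough confidence.
for i, crop in enumerate(crops):
    if cls_labels[i] == 1 and cls_scores[i] > cls_postprocessor.cls_thresh:
        crops[i] = cv2.rotate(crop, cv2.ROTATE_180)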
class Classifier(FastDeployModel):
@@ -136,44 +399,170 @@ class Classifier(FastDeployModel):
if (len(model_file) == 0):
self._model = C.vision.ocr.Classifier()
self._runnable = False
else:
self._model = C.vision.ocr.Classifier(
model_file, params_file, self._runtime_option, model_format)
assert self.initialized, "Classifier initialize failed."
self._runnable = True
def predict(self, input_image):
"""Predict an input image
:param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
:return: cls_label, cls_score
"""
if self._runnable:
return self._model.predict(input_image)
return False
def batch_predict(self, images):
"""Predict a batch of input image
:param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
:return: list of cls_label, list of cls_score
"""
if self._runnable:
return self._model.batch_predict(images)
return False
@property
def preprocessor(self):
return self._model.preprocessor
@preprocessor.setter
def preprocessor(self, value):
self._model.preprocessor = value
@property
def postprocessor(self):
return self._model.postprocessor
@postprocessor.setter
def postprocessor(self, value):
self._model.postprocessor = value
# Cls Preprocessor Property
@property
def is_scale(self):
return self._model.preprocessor.is_scale
@is_scale.setter
def is_scale(self, value):
assert isinstance(
value, bool), "The value to set `is_scale` must be type of bool."
self._model.preprocessor.is_scale = value
@property
def scale(self):
return self._model.preprocessor.scale
@scale.setter
def scale(self, value):
assert isinstance(
value, list), "The value to set `scale` must be type of list."
self._model.preprocessor.scale = value
@property
def mean(self):
return self._model.preprocessor.mean
@mean.setter
def mean(self, value):
assert isinstance(
value, list), "The value to set `mean` must be type of list."
self._model.preprocessor.mean = value
@property
def cls_image_shape(self):
return self._model.preprocessor.cls_image_shape
@cls_image_shape.setter
def cls_image_shape(self, value):
assert isinstance(
value,
list), "The value to set `cls_image_shape` must be type of list."
self._model.preprocessor.cls_image_shape = value
# Cls Postprocessor Property
@property
def cls_thresh(self):
return self._model.postprocessor.cls_thresh
@cls_thresh.setter
def cls_thresh(self, value):
assert isinstance(
value,
float), "The value to set `cls_thresh` must be type of float."
self._model.postprocessor.cls_thresh = value
class RecognizerPreprocessor:
def __init__(self):
"""Create a preprocessor for RecognizerModel
"""
self._preprocessor = C.vision.ocr.RecognizerPreprocessor()
def run(self, input_ims):
"""Preprocess input images for RecognizerModel
:param: input_ims: (list of numpy.ndarray)The input image
:return: list of FDTensor
"""
return self._preprocessor.run(input_ims)
@property
def is_scale(self):
return self._preprocessor.is_scale
@is_scale.setter
def is_scale(self, value):
assert isinstance(
value, bool), "The value to set `is_scale` must be type of bool."
self._preprocessor.is_scale = value
@property
def scale(self):
return self._preprocessor.scale
@scale.setter
def scale(self, value):
assert isinstance(
value, list), "The value to set `scale` must be type of list."
self._preprocessor.scale = value
@property
def mean(self):
return self._preprocessor.mean
@mean.setter
def mean(self, value):
assert isinstance(
value, list), "The value to set `mean` must be type of list."
self._preprocessor.mean = value
@property
def rec_image_shape(self):
return self._preprocessor.rec_image_shape
@rec_image_shape.setter
def rec_image_shape(self, value):
assert isinstance(
value,
list), "The value to set `rec_image_shape` must be type of list."
self._preprocessor.rec_image_shape = value
class RecognizerPostprocessor:
def __init__(self, label_path):
"""Create a postprocessor for RecognizerModel
:param label_path: (str)Path of label file
"""
self._postprocessor = C.vision.ocr.RecognizerPostprocessor(label_path)
def run(self, runtime_results):
"""Postprocess the runtime results for RecognizerModel
:param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime
:return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results)
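And the recognizer pair, again hedged (model and dictionary paths are assumptions); RecognizerPostprocessor takes the label file so it can map model output indices back to characters.

import cv2
import fastdeploy as fd

rec_option = fd.RuntimeOption()
rec_option.set_model_path("ch_PP-OCRv3_rec_infer/inference.pdmodel",
                          "ch_PP-OCRv3_rec_infer/inference.pdiparams")
rec_runtime = fd.Runtime(rec_option)

rec_preprocessor = fd.vision.ocr.RecognizerPreprocessor()
rec_postprocessor = fd.vision.ocr.RecognizerPostprocessor("ppocr_keys_v1.txt")

crops = [cv2.imread("12.jpg")]  # stand-in for cropped text-line images
rec_tensors = rec_preprocessor.run(crops)
rec_outputs = rec_runtime.infer({"x": rec_tensors[0]})
rec_texts, rec_scores = rec_postprocessor.run(rec_outputs)
print(rec_texts, rec_scores)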
class Recognizer(FastDeployModel):
@@ -195,44 +584,88 @@ class Recognizer(FastDeployModel):
if (len(model_file) == 0):
self._model = C.vision.ocr.Recognizer()
self._runnable = False
else:
self._model = C.vision.ocr.Recognizer(
model_file, params_file, label_path, self._runtime_option,
model_format)
assert self.initialized, "Recognizer initialize failed."
self._runnable = True
def predict(self, input_image):
"""Predict an input image
:param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
:return: rec_text, rec_score
"""
if self._runnable:
return self._model.predict(input_image)
return False
def batch_predict(self, images):
"""Predict a batch of input image
:param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
:return: list of rec_text, list of rec_score
"""
if self._runnable:
return self._model.batch_predict(images)
return False
@property
def preprocessor(self):
return self._model.preprocessor
@preprocessor.setter
def preprocessor(self, value):
self._model.preprocessor = value
@property
def postprocessor(self):
return self._model.postprocessor
@postprocessor.setter
def postprocessor(self, value):
self._model.postprocessor = value
@property
def is_scale(self):
return self._model.preprocessor.is_scale
@is_scale.setter
def is_scale(self, value):
assert isinstance(
value, bool), "The value to set `is_scale` must be type of bool."
self._model.preprocessor.is_scale = value
@property
def scale(self):
return self._model.preprocessor.scale
@scale.setter
def scale(self, value):
assert isinstance(
value, list), "The value to set `scale` must be type of list."
self._model.preprocessor.scale = value
@property
def mean(self):
return self._model.preprocessor.mean
@mean.setter
def mean(self, value):
assert isinstance(
value, list), "The value to set `mean` must be type of list."
self._model.preprocessor.mean = value
@property
def rec_image_shape(self):
return self._model.preprocessor.rec_image_shape
@rec_image_shape.setter
def rec_image_shape(self, value):
assert isinstance(
value,
list), "The value to set `rec_image_shape` must be type of list."
self._model.preprocessor.rec_image_shape = value
class PPOCRv3(FastDeployModel):
@@ -253,7 +686,6 @@ class PPOCRv3(FastDeployModel):
def predict(self, input_image):
"""Predict an input image
:param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
:return: OCRResult
"""
@@ -264,9 +696,30 @@ class PPOCRv3(FastDeployModel):
:param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
:return: OCRBatchResult
"""
return self.system.batch_predict(images)
@property
def cls_batch_size(self):
return self.system.cls_batch_size
@cls_batch_size.setter
def cls_batch_size(self, value):
assert isinstance(
value,
int), "The value to set `cls_batch_size` must be type of int."
self.system.cls_batch_size = value
@property
def rec_batch_size(self):
return self.system.rec_batch_size
@rec_batch_size.setter
def rec_batch_size(self, value):
assert isinstance(
value,
int), "The value to set `rec_batch_size` must be type of int."
self.system.rec_batch_size = value
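The new cls_batch_size and rec_batch_size knobs on the pipeline wrapper are meant to control how many cropped regions are batched into each classifier/recognizer inference call. A hedged end-to-end sketch with assumed local model paths; the test below exercises both -1 and 2.

import cv2
import fastdeploy as fd

det = fd.vision.ocr.DBDetector(
    "ch_PP-OCRv3_det_infer/inference.pdmodel",
    "ch_PP-OCRv3_det_infer/inference.pdiparams",
    runtime_option=fd.RuntimeOption())
cls = fd.vision.ocr.Classifier(
    "ch_ppocr_mobile_v2.0_cls_infer/inference.pdmodel",
    "ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams",
    runtime_option=fd.RuntimeOption())
rec = fd.vision.ocr.Recognizer(
    "ch_PP-OCRv3_rec_infer/inference.pdmodel",
    "ch_PP-OCRv3_rec_infer/inference.pdiparams",
    "ppocr_keys_v1.txt",
    runtime_option=fd.RuntimeOption())

ppocr_v3 = fd.vision.ocr.PPOCRv3(det_model=det, cls_model=cls, rec_model=rec)
ppocr_v3.cls_batch_size = 2  # crops per classifier call
ppocr_v3.rec_batch_size = 2  # crops per recognizer call
print(ppocr_v3.predict(cv2.imread("12.jpg")))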
class PPOCRSystemv3(PPOCRv3):
def __init__(self, det_model=None, cls_model=None, rec_model=None):
@@ -311,6 +764,28 @@ class PPOCRv2(FastDeployModel):
return self.system.batch_predict(images)
@property
def cls_batch_size(self):
return self.system.cls_batch_size
@cls_batch_size.setter
def cls_batch_size(self, value):
assert isinstance(
value,
int), "The value to set `cls_batch_size` must be type of int."
self.system.cls_batch_size = value
@property
def rec_batch_size(self):
return self.system.rec_batch_size
@rec_batch_size.setter
def rec_batch_size(self, value):
assert isinstance(
value,
int), "The value to set `rec_batch_size` must be type of int."
self.system.rec_batch_size = value
class PPOCRSystemv2(PPOCRv2):
def __init__(self, det_model=None, cls_model=None, rec_model=None):
@@ -321,93 +796,3 @@ class PPOCRSystemv2(PPOCRv2):
def predict(self, input_image):
return super(PPOCRSystemv2, self).predict(input_image)
class DBDetectorPreprocessor:
def __init__(self):
"""Create a preprocessor for DBDetectorModel
"""
self._preprocessor = C.vision.ocr.DBDetectorPreprocessor()
def run(self, input_ims):
"""Preprocess input images for DBDetectorModel
:param: input_ims: (list of numpy.ndarray) The input image
:return: pair(list of FDTensor, list of std::array<int, 4>)
"""
return self._preprocessor.run(input_ims)
class DBDetectorPostprocessor:
def __init__(self):
"""Create a postprocessor for DBDetectorModel
"""
self._postprocessor = C.vision.ocr.DBDetectorPostprocessor()
def run(self, runtime_results, batch_det_img_info):
"""Postprocess the runtime results for DBDetectorModel
:param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime
:param: batch_det_img_info: (list of std::array<int, 4>)The output of det_preprocessor
:return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results, batch_det_img_info)
class RecognizerPreprocessor:
def __init__(self):
"""Create a preprocessor for RecognizerModel
"""
self._preprocessor = C.vision.ocr.RecognizerPreprocessor()
def run(self, input_ims):
"""Preprocess input images for RecognizerModel
:param: input_ims: (list of numpy.ndarray)The input image
:return: list of FDTensor
"""
return self._preprocessor.run(input_ims)
class RecognizerPostprocessor:
def __init__(self, label_path):
"""Create a postprocessor for RecognizerModel
:param label_path: (str)Path of label file
"""
self._postprocessor = C.vision.ocr.RecognizerPostprocessor(label_path)
def run(self, runtime_results):
"""Postprocess the runtime results for RecognizerModel
:param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime
:return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results)
class ClassifierPreprocessor:
def __init__(self):
"""Create a preprocessor for ClassifierModel
"""
self._preprocessor = C.vision.ocr.ClassifierPreprocessor()
def run(self, input_ims):
"""Preprocess input images for ClassifierModel
:param: input_ims: (list of numpy.ndarray)The input image
:return: list of FDTensor
"""
return self._preprocessor.run(input_ims)
class ClassifierPostprocessor:
def __init__(self):
"""Create a postprocessor for ClassifierModel
"""
self._postprocessor = C.vision.ocr.ClassifierPostprocessor()
def run(self, runtime_results):
"""Postprocess the runtime results for ClassifierModel
:param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime
:return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results)
def sort_boxes(boxes):
return C.vision.ocr.sort_boxes(boxes)

256
tests/models/test_ppocrv3.py Executable file

@@ -0,0 +1,256 @@
import fastdeploy as fd
import cv2
import os
import runtime_config as rc
import numpy as np
import math
import pickle
det_model_url = "https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar"
cls_model_url = "https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar"
rec_model_url = "https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar"
img_url = "https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg"
label_url = "https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt"
result_url = "https://bj.bcebos.com/fastdeploy/tests/data/ocr_result.pickle"
fd.download_and_decompress(det_model_url, "resources")
fd.download_and_decompress(cls_model_url, "resources")
fd.download_and_decompress(rec_model_url, "resources")
fd.download(img_url, "resources")
fd.download(result_url, "resources")
fd.download(label_url, "resources")
def get_rotate_crop_image(img, box):
points = []
for i in range(4):
points.append([box[2 * i], box[2 * i + 1]])
points = np.array(points, dtype=np.float32)
img = img.astype(np.float32)
assert len(points) == 4, "shape of points must be 4*2"
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
option = fd.RuntimeOption()
# det_model
det_model_path = "resources/ch_PP-OCRv3_det_infer/"
det_model_file = det_model_path + "inference.pdmodel"
det_params_file = det_model_path + "inference.pdiparams"
det_preprocessor = fd.vision.ocr.DBDetectorPreprocessor()
rc.test_option.set_model_path(det_model_file, det_params_file)
det_runtime = fd.Runtime(rc.test_option)
det_postprocessor = fd.vision.ocr.DBDetectorPostprocessor()
det_model = fd.vision.ocr.DBDetector(
det_model_file, det_params_file, runtime_option=option)
# cls_model
cls_model_path = "resources/ch_ppocr_mobile_v2.0_cls_infer/"
cls_model_file = cls_model_path + "inference.pdmodel"
cls_params_file = cls_model_path + "inference.pdiparams"
cls_preprocessor = fd.vision.ocr.ClassifierPreprocessor()
rc.test_option.set_model_path(cls_model_file, cls_params_file)
cls_runtime = fd.Runtime(rc.test_option)
cls_postprocessor = fd.vision.ocr.ClassifierPostprocessor()
cls_model = fd.vision.ocr.Classifier(
cls_model_file, cls_params_file, runtime_option=option)
#rec_model
rec_model_path = "resources/ch_PP-OCRv3_rec_infer/"
rec_model_file = rec_model_path + "inference.pdmodel"
rec_params_file = rec_model_path + "inference.pdiparams"
rec_label_file = "resources/ppocr_keys_v1.txt"
rec_preprocessor = fd.vision.ocr.RecognizerPreprocessor()
rc.test_option.set_model_path(rec_model_file, rec_params_file)
rec_runtime = fd.Runtime(rc.test_option)
rec_postprocessor = fd.vision.ocr.RecognizerPostprocessor(rec_label_file)
rec_model = fd.vision.ocr.Recognizer(
rec_model_file, rec_params_file, rec_label_file, runtime_option=option)
#pp_ocrv3
ppocr_v3 = fd.vision.ocr.PPOCRv3(
det_model=det_model, cls_model=cls_model, rec_model=rec_model)
#pp_ocrv3_no_cls
ppocr_v3_no_cls = fd.vision.ocr.PPOCRv3(
det_model=det_model, rec_model=rec_model)
#input image
img_file = "resources/12.jpg"
im = []
im.append(cv2.imread(img_file))
im.append(cv2.imread(img_file))
result_file = "resources/ocr_result.pickle"
with open(result_file, 'rb') as f:
boxes, cls_labels, cls_scores, text, rec_scores = pickle.load(f)
base_boxes = np.array(boxes)
base_cls_labels = np.array(cls_labels)
base_cls_scores = np.array(cls_scores)
base_text = text
base_rec_scores = np.array(rec_scores)
def compare_result(pred_boxes, pred_cls_labels, pred_cls_scores, pred_text,
pred_rec_scores):
pred_boxes = np.array(pred_boxes)
pred_cls_labels = np.array(pred_cls_labels)
pred_cls_scores = np.array(pred_cls_scores)
pred_text = pred_text
pred_rec_scores = np.array(pred_rec_scores)
diff_boxes = np.fabs(base_boxes - pred_boxes).max()
diff_cls_labels = np.fabs(base_cls_labels - pred_cls_labels).max()
diff_cls_scores = np.fabs(base_cls_scores - pred_cls_scores).max()
diff_text = (base_text != pred_text)
diff_rec_scores = np.fabs(base_rec_scores - pred_rec_scores).max()
print('diff:', diff_boxes, diff_cls_labels, diff_cls_scores, diff_text,
diff_rec_scores)
diff_threshold = 1e-01
assert diff_boxes < diff_threshold, "There is diff in boxes"
assert diff_cls_labels < diff_threshold, "There is diff in cls_label"
assert diff_cls_scores < diff_threshold, "There is diff in cls_scores"
assert diff_text < diff_threshold, "There is diff in text"
assert diff_rec_scores < diff_threshold, "There is diff in rec_scores"
def compare_result_no_cls(pred_boxes, pred_text, pred_rec_scores):
pred_boxes = np.array(pred_boxes)
pred_text = pred_text
pred_rec_scores = np.array(pred_rec_scores)
diff_boxes = np.fabs(base_boxes - pred_boxes).max()
diff_text = (base_text != pred_text)
diff_rec_scores = np.fabs(base_rec_scores - pred_rec_scores).max()
print('diff:', diff_boxes, diff_text, diff_rec_scores)
diff_threshold = 1e-01
assert diff_boxes < diff_threshold, "There is diff in boxes"
assert diff_text < diff_threshold, "There is diff in text"
assert diff_rec_scores < diff_threshold, "There is diff in rec_scores"
def test_ppocr_v3():
ppocr_v3.cls_batch_size = -1
ppocr_v3.rec_batch_size = -1
ocr_result = ppocr_v3.predict(im[0])
compare_result(ocr_result.boxes, ocr_result.cls_labels,
ocr_result.cls_scores, ocr_result.text,
ocr_result.rec_scores)
ppocr_v3.cls_batch_size = 2
ppocr_v3.rec_batch_size = 2
ocr_result = ppocr_v3.predict(im[0])
compare_result(ocr_result.boxes, ocr_result.cls_labels,
ocr_result.cls_scores, ocr_result.text,
ocr_result.rec_scores)
def test_ppocr_v3_1():
ppocr_v3_no_cls.cls_batch_size = -1
ppocr_v3_no_cls.rec_batch_size = -1
ocr_result = ppocr_v3_no_cls.predict(im[0])
compare_result_no_cls(ocr_result.boxes, ocr_result.text,
ocr_result.rec_scores)
ppocr_v3_no_cls.cls_batch_size = 2
ppocr_v3_no_cls.rec_batch_size = 2
ocr_result = ppocr_v3_no_cls.predict(im[0])
compare_result_no_cls(ocr_result.boxes, ocr_result.text,
ocr_result.rec_scores)
def test_ppocr_v3_2():
det_input_tensors, batch_det_img_info = det_preprocessor.run(im)
det_output_tensors = det_runtime.infer({"x": det_input_tensors[0]})
det_results = det_postprocessor.run(det_output_tensors, batch_det_img_info)
batch_boxes = []
batch_cls_labels = []
batch_cls_scores = []
batch_rec_texts = []
batch_rec_scores = []
for i_batch in range(len(det_results)):
cls_labels = []
cls_scores = []
rec_texts = []
rec_scores = []
box_list = fd.vision.ocr.sort_boxes(det_results[i_batch])
batch_boxes.append(box_list)
image_list = []
if len(box_list) == 0:
image_list.append(im[i_batch])
else:
for box in box_list:
crop_img = get_rotate_crop_image(im[i_batch], box)
image_list.append(crop_img)
cls_input_tensors = cls_preprocessor.run(image_list)
cls_output_tensors = cls_runtime.infer({"x": cls_input_tensors[0]})
cls_labels, cls_scores = cls_postprocessor.run(cls_output_tensors)
batch_cls_labels.append(cls_labels)
batch_cls_scores.append(cls_scores)
for index in range(len(image_list)):
if cls_labels[index] == 1 and cls_scores[
index] > cls_postprocessor.cls_thresh:
image_list[index] = cv2.rotate(
image_list[index].astype(np.float32), 1)
image_list[index] = image_list[index].astype(np.uint8)
rec_input_tensors = rec_preprocessor.run(image_list)
rec_output_tensors = rec_runtime.infer({"x": rec_input_tensors[0]})
rec_texts, rec_scores = rec_postprocessor.run(rec_output_tensors)
batch_rec_texts.append(rec_texts)
batch_rec_scores.append(rec_scores)
compare_result(box_list, cls_labels, cls_scores, rec_texts, rec_scores)
if __name__ == "__main__":
print("test test_ppocr_v3")
test_ppocr_v3()
test_ppocr_v3()
print("test test_ppocr_v3_1")
test_ppocr_v3_1()
test_ppocr_v3_1()
print("test test_ppocr_v3_2")
test_ppocr_v3_2()
test_ppocr_v3_2()