diff --git a/cmake/paddle2onnx.cmake b/cmake/paddle2onnx.cmake
index 9c0d4cb1e..baaac8759 100755
--- a/cmake/paddle2onnx.cmake
+++ b/cmake/paddle2onnx.cmake
@@ -43,7 +43,7 @@ else()
 endif(WIN32)
 set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/fastdeploy/third_libs/")
-set(PADDLE2ONNX_VERSION "1.0.4")
+set(PADDLE2ONNX_VERSION "1.0.5")
 if(WIN32)
   set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip")
   if(NOT CMAKE_CL_64)
diff --git a/examples/vision/ocr/PP-OCRv3/serving/README.md b/examples/vision/ocr/PP-OCRv3/serving/README.md
old mode 100644
new mode 100755
index 0762a2359..31db5a72c
--- a/examples/vision/ocr/PP-OCRv3/serving/README.md
+++ b/examples/vision/ocr/PP-OCRv3/serving/README.md
@@ -47,6 +47,8 @@ tar xvf ch_PP-OCRv3_rec_infer.tar && mv ch_PP-OCRv3_rec_infer 1
 mv 1/inference.pdiparams 1/model.pdiparams && mv 1/inference.pdmodel 1/model.pdmodel
 mv 1 models/rec_runtime/ && rm -rf ch_PP-OCRv3_rec_infer.tar
 
+mkdir models/pp_ocr/1 && mkdir models/rec_pp/1 && mkdir models/cls_pp/1
+
 wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
 mv ppocr_keys_v1.txt models/rec_postprocess/1/
diff --git a/examples/vision/ocr/PP-OCRv3/serving/client.py b/examples/vision/ocr/PP-OCRv3/serving/client.py
old mode 100644
new mode 100755
index b7b20644f..1b150b7d0
--- a/examples/vision/ocr/PP-OCRv3/serving/client.py
+++ b/examples/vision/ocr/PP-OCRv3/serving/client.py
@@ -91,7 +91,7 @@ class SyncGRPCTritonRunner:
 if __name__ == "__main__":
     model_name = "pp_ocr"
     model_version = "1"
-    url = "localhost:9001"
+    url = "localhost:8001"
     runner = SyncGRPCTritonRunner(url, model_name, model_version)
     im = cv2.imread("12.jpg")
     im = np.array([im, ])
diff --git a/examples/vision/ocr/PP-OCRv3/serving/models/cls_runtime/config.pbtxt b/examples/vision/ocr/PP-OCRv3/serving/models/cls_runtime/config.pbtxt
old mode 100644
new mode 100755
index aa66b1f9b..eb7b25503
--- a/examples/vision/ocr/PP-OCRv3/serving/models/cls_runtime/config.pbtxt
+++ b/examples/vision/ocr/PP-OCRv3/serving/models/cls_runtime/config.pbtxt
@@ -35,3 +35,18 @@ instance_group [
     gpus: [0]
   }
 ]
+
+optimization {
+  execution_accelerators {
+    # GPU inference configuration, used together with KIND_GPU
+    gpu_execution_accelerator : [
+      {
+        name : "paddle"
+        # Set the number of parallel inference threads to 4
+        parameters { key: "cpu_threads" value: "4" }
+        # Enable mkldnn acceleration; set to 0 to disable mkldnn
+        parameters { key: "use_mkldnn" value: "1" }
+      }
+    ]
+  }
+}
diff --git a/examples/vision/ocr/PP-OCRv3/serving/models/det_runtime/config.pbtxt b/examples/vision/ocr/PP-OCRv3/serving/models/det_runtime/config.pbtxt
old mode 100644
new mode 100755
index c3ffe15c1..96d85e3e1
--- a/examples/vision/ocr/PP-OCRv3/serving/models/det_runtime/config.pbtxt
+++ b/examples/vision/ocr/PP-OCRv3/serving/models/det_runtime/config.pbtxt
@@ -35,3 +35,18 @@ instance_group [
     gpus: [0]
   }
 ]
+
+optimization {
+  execution_accelerators {
+    # GPU inference configuration, used together with KIND_GPU
+    gpu_execution_accelerator : [
+      {
+        name : "paddle"
+        # Set the number of parallel inference threads to 4
+        parameters { key: "cpu_threads" value: "4" }
+        # Enable mkldnn acceleration; set to 0 to disable mkldnn
+        parameters { key: "use_mkldnn" value: "1" }
+      }
+    ]
+  }
+}
\ No newline at end of file
diff --git a/examples/vision/ocr/PP-OCRv3/serving/models/rec_postprocess/1/model.py b/examples/vision/ocr/PP-OCRv3/serving/models/rec_postprocess/1/model.py
old mode 100644
new mode 100755
index 049173382..fe66e8c3f
--- a/examples/vision/ocr/PP-OCRv3/serving/models/rec_postprocess/1/model.py
+++ b/examples/vision/ocr/PP-OCRv3/serving/models/rec_postprocess/1/model.py
@@ -47,8 +47,6 @@ class TritonPythonModel:
           * model_version: Model version
           * model_name: Model name
         """
-        sys.stdout = codecs.getwriter("utf-8")(sys.stdout.detach())
-        print(sys.getdefaultencoding())
         # You must parse model_config. JSON string is not parsed here
         self.model_config = json.loads(args['model_config'])
         print("model_config:", self.model_config)
diff --git a/examples/vision/ocr/PP-OCRv3/serving/models/rec_runtime/config.pbtxt b/examples/vision/ocr/PP-OCRv3/serving/models/rec_runtime/config.pbtxt
old mode 100644
new mode 100755
index d4b3b1212..037d7a9f2
--- a/examples/vision/ocr/PP-OCRv3/serving/models/rec_runtime/config.pbtxt
+++ b/examples/vision/ocr/PP-OCRv3/serving/models/rec_runtime/config.pbtxt
@@ -35,3 +35,18 @@ instance_group [
     gpus: [0]
   }
 ]
+
+optimization {
+  execution_accelerators {
+    # GPU inference configuration, used together with KIND_GPU
+    gpu_execution_accelerator : [
+      {
+        name : "paddle"
+        # Set the number of parallel inference threads to 4
+        parameters { key: "cpu_threads" value: "4" }
+        # Enable mkldnn acceleration; set to 0 to disable mkldnn
+        parameters { key: "use_mkldnn" value: "1" }
+      }
+    ]
+  }
+}
\ No newline at end of file
diff --git a/fastdeploy/vision/ocr/ppocr/classifier.cc b/fastdeploy/vision/ocr/ppocr/classifier.cc
index 130329735..4be9a3556 100755
--- a/fastdeploy/vision/ocr/ppocr/classifier.cc
+++ b/fastdeploy/vision/ocr/ppocr/classifier.cc
@@ -50,10 +50,29 @@ bool Classifier::Initialize() {
   return true;
 }
 
+bool Classifier::Predict(cv::Mat& img, int32_t* cls_label, float* cls_score) {
+  std::vector<int32_t> cls_labels(1);
+  std::vector<float> cls_scores(1);
+  bool success = BatchPredict({img}, &cls_labels, &cls_scores);
+  if (!success) {
+    return success;
+  }
+  *cls_label = cls_labels[0];
+  *cls_score = cls_scores[0];
+  return true;
+}
+
 bool Classifier::BatchPredict(const std::vector<cv::Mat>& images,
                               std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores) {
+  return BatchPredict(images, cls_labels, cls_scores, 0, images.size());
+}
+
+bool Classifier::BatchPredict(const std::vector<cv::Mat>& images,
+                              std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores,
+                              size_t start_index, size_t end_index) {
+  size_t total_size = images.size();
   std::vector<FDMat> fd_images = WrapMat(images);
-  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
+  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, start_index, end_index)) {
     FDERROR << "Failed to preprocess the input image." << std::endl;
     return false;
   }
@@ -63,7 +82,7 @@ bool Classifier::BatchPredict(const std::vector<cv::Mat>& images,
     return false;
   }
 
-  if (!postprocessor_.Run(reused_output_tensors_, cls_labels, cls_scores)) {
+  if (!postprocessor_.Run(reused_output_tensors_, cls_labels, cls_scores, start_index, total_size)) {
     FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
     return false;
   }
diff --git a/fastdeploy/vision/ocr/ppocr/classifier.h b/fastdeploy/vision/ocr/ppocr/classifier.h
index d3430e4e0..ddc4db27a 100755
--- a/fastdeploy/vision/ocr/ppocr/classifier.h
+++ b/fastdeploy/vision/ocr/ppocr/classifier.h
@@ -43,7 +43,7 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel {
              const ModelFormat& model_format = ModelFormat::PADDLE);
   /// Get model's name
   std::string ModelName() const { return "ppocr/ocr_cls"; }
-
+  virtual bool Predict(cv::Mat& img, int32_t* cls_label, float* cls_score);
   /** \brief BatchPredict the input image and get OCR classification model cls_result.
    *
    * \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
@@ -53,6 +53,10 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel { virtual bool BatchPredict(const std::vector& images, std::vector* cls_labels, std::vector* cls_scores); + virtual bool BatchPredict(const std::vector& images, + std::vector* cls_labels, + std::vector* cls_scores, + size_t start_index, size_t end_index); ClassifierPreprocessor preprocessor_; ClassifierPostprocessor postprocessor_; diff --git a/fastdeploy/vision/ocr/ppocr/cls_postprocessor.cc b/fastdeploy/vision/ocr/ppocr/cls_postprocessor.cc index 5eb6b5d69..c19bcc3a6 100644 --- a/fastdeploy/vision/ocr/ppocr/cls_postprocessor.cc +++ b/fastdeploy/vision/ocr/ppocr/cls_postprocessor.cc @@ -20,10 +20,6 @@ namespace fastdeploy { namespace vision { namespace ocr { -ClassifierPostprocessor::ClassifierPostprocessor() { - initialized_ = true; -} - bool SingleBatchPostprocessor(const float* out_data, const size_t& length, int* cls_label, float* cls_score) { *cls_label = std::distance( @@ -37,10 +33,14 @@ bool SingleBatchPostprocessor(const float* out_data, const size_t& length, int* bool ClassifierPostprocessor::Run(const std::vector& tensors, std::vector* cls_labels, std::vector* cls_scores) { - if (!initialized_) { - FDERROR << "Postprocessor is not initialized." << std::endl; - return false; - } + size_t total_size = tensors[0].shape[0]; + return Run(tensors, cls_labels, cls_scores, 0, total_size); +} + +bool ClassifierPostprocessor::Run(const std::vector& tensors, + std::vector* cls_labels, + std::vector* cls_scores, + size_t start_index, size_t total_size) { // Classifier have only 1 output tensor. const FDTensor& tensor = tensors[0]; @@ -48,13 +48,29 @@ bool ClassifierPostprocessor::Run(const std::vector& tensors, size_t batch = tensor.shape[0]; size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies()); - cls_labels->resize(batch); - cls_scores->resize(batch); + if (batch <= 0) { + FDERROR << "The infer outputTensor.shape[0] <=0, wrong infer result." << std::endl; + return false; + } + if (start_index < 0 || total_size <= 0) { + FDERROR << "start_index or total_size error. Correct is: 0 <= start_index < total_size" << std::endl; + return false; + } + if ((start_index + batch) > total_size) { + FDERROR << "start_index or total_size error. 
Correct is: start_index + batch(outputTensor.shape[0]) <= total_size" << std::endl; + return false; + } + + cls_labels->resize(total_size); + cls_scores->resize(total_size); const float* tensor_data = reinterpret_cast(tensor.Data()); - for (int i_batch = 0; i_batch < batch; ++i_batch) { - if(!SingleBatchPostprocessor(tensor_data, length, &cls_labels->at(i_batch),&cls_scores->at(i_batch))) return false; - tensor_data = tensor_data + length; + if(!SingleBatchPostprocessor(tensor_data+ i_batch * length, + length, + &cls_labels->at(i_batch + start_index), + &cls_scores->at(i_batch + start_index))) { + return false; + } } return true; diff --git a/fastdeploy/vision/ocr/ppocr/cls_postprocessor.h b/fastdeploy/vision/ocr/ppocr/cls_postprocessor.h index 15bf098c7..a755e1294 100644 --- a/fastdeploy/vision/ocr/ppocr/cls_postprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/cls_postprocessor.h @@ -25,11 +25,6 @@ namespace ocr { */ class FASTDEPLOY_DECL ClassifierPostprocessor { public: - /** \brief Create a postprocessor instance for Classifier serials model - * - */ - ClassifierPostprocessor(); - /** \brief Process the result of runtime and fill to ClassifyResult structure * * \param[in] tensors The inference result from runtime @@ -40,10 +35,11 @@ class FASTDEPLOY_DECL ClassifierPostprocessor { bool Run(const std::vector& tensors, std::vector* cls_labels, std::vector* cls_scores); - float cls_thresh_ = 0.9; + bool Run(const std::vector& tensors, + std::vector* cls_labels, std::vector* cls_scores, + size_t start_index, size_t total_size); - private: - bool initialized_ = false; + float cls_thresh_ = 0.9; }; } // namespace ocr diff --git a/fastdeploy/vision/ocr/ppocr/cls_preprocessor.cc b/fastdeploy/vision/ocr/ppocr/cls_preprocessor.cc old mode 100644 new mode 100755 index 1f0993690..dcd76c168 --- a/fastdeploy/vision/ocr/ppocr/cls_preprocessor.cc +++ b/fastdeploy/vision/ocr/ppocr/cls_preprocessor.cc @@ -21,58 +21,53 @@ namespace fastdeploy { namespace vision { namespace ocr { -ClassifierPreprocessor::ClassifierPreprocessor() { - initialized_ = true; -} - void OcrClassifierResizeImage(FDMat* mat, const std::vector& cls_image_shape) { - int imgC = cls_image_shape[0]; - int imgH = cls_image_shape[1]; - int imgW = cls_image_shape[2]; + int img_c = cls_image_shape[0]; + int img_h = cls_image_shape[1]; + int img_w = cls_image_shape[2]; float ratio = float(mat->Width()) / float(mat->Height()); int resize_w; - if (ceilf(imgH * ratio) > imgW) - resize_w = imgW; + if (ceilf(img_h * ratio) > img_w) + resize_w = img_w; else - resize_w = int(ceilf(imgH * ratio)); + resize_w = int(ceilf(img_h * ratio)); - Resize::Run(mat, resize_w, imgH); - - std::vector value = {0, 0, 0}; - if (resize_w < imgW) { - Pad::Run(mat, 0, 0, 0, imgW - resize_w, value); - } + Resize::Run(mat, resize_w, img_h); } bool ClassifierPreprocessor::Run(std::vector* images, std::vector* outputs) { - if (!initialized_) { - FDERROR << "The preprocessor is not initialized." << std::endl; - return false; - } - if (images->size() == 0) { - FDERROR << "The size of input images should be greater than 0." << std::endl; + return Run(images, outputs, 0, images->size()); +} + +bool ClassifierPreprocessor::Run(std::vector* images, std::vector* outputs, + size_t start_index, size_t end_index) { + + if (images->size() == 0 || start_index <0 || end_index <= start_index || end_index > images->size()) { + FDERROR << "images->size() or index error. 
Correct is: 0 <= start_index < end_index <= images->size()" << std::endl; return false; } - for (size_t i = 0; i < images->size(); ++i) { + for (size_t i = start_index; i < end_index; ++i) { FDMat* mat = &(images->at(i)); OcrClassifierResizeImage(mat, cls_image_shape_); - NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_); - /* Normalize::Run(mat, mean_, scale_, is_scale_); + std::vector value = {0, 0, 0}; + if (mat->Width() < cls_image_shape_[2]) { + Pad::Run(mat, 0, 0, 0, cls_image_shape_[2] - mat->Width(), value); + } HWC2CHW::Run(mat); Cast::Run(mat, "float"); - */ } // Only have 1 output Tensor. outputs->resize(1); // Concat all the preprocessed data to a batch tensor - std::vector tensors(images->size()); - for (size_t i = 0; i < images->size(); ++i) { - (*images)[i].ShareWithTensor(&(tensors[i])); + size_t tensor_size = end_index - start_index; + std::vector tensors(tensor_size); + for (size_t i = 0; i < tensor_size; ++i) { + (*images)[i + start_index].ShareWithTensor(&(tensors[i])); tensors[i].ExpandDim(0); } if (tensors.size() == 1) { diff --git a/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h b/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h index a701e7e3a..ed75d55b2 100644 --- a/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h @@ -24,11 +24,6 @@ namespace ocr { */ class FASTDEPLOY_DECL ClassifierPreprocessor { public: - /** \brief Create a preprocessor instance for Classifier serials model - * - */ - ClassifierPreprocessor(); - /** \brief Process the input image and prepare input tensors for runtime * * \param[in] images The input image data list, all the elements are returned by cv::imread() @@ -36,14 +31,13 @@ class FASTDEPLOY_DECL ClassifierPreprocessor { * \return true if the preprocess successed, otherwise false */ bool Run(std::vector* images, std::vector* outputs); + bool Run(std::vector* images, std::vector* outputs, + size_t start_index, size_t end_index); std::vector mean_ = {0.5f, 0.5f, 0.5f}; std::vector scale_ = {0.5f, 0.5f, 0.5f}; bool is_scale_ = true; std::vector cls_image_shape_ = {3, 48, 192}; - - private: - bool initialized_ = false; }; } // namespace ocr diff --git a/fastdeploy/vision/ocr/ppocr/dbdetector.cc b/fastdeploy/vision/ocr/ppocr/dbdetector.cc index 68a994afc..b490e88d9 100755 --- a/fastdeploy/vision/ocr/ppocr/dbdetector.cc +++ b/fastdeploy/vision/ocr/ppocr/dbdetector.cc @@ -50,14 +50,6 @@ bool DBDetector::Initialize() { return true; } -bool DBDetector::Predict(cv::Mat* img, - std::vector>* boxes_result) { - if (!Predict(*img, boxes_result)) { - return false; - } - return true; -} - bool DBDetector::Predict(const cv::Mat& img, std::vector>* boxes_result) { std::vector>> det_results; diff --git a/fastdeploy/vision/ocr/ppocr/dbdetector.h b/fastdeploy/vision/ocr/ppocr/dbdetector.h index d3b99d598..d2305abd7 100755 --- a/fastdeploy/vision/ocr/ppocr/dbdetector.h +++ b/fastdeploy/vision/ocr/ppocr/dbdetector.h @@ -44,14 +44,6 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel { const ModelFormat& model_format = ModelFormat::PADDLE); /// Get model's name std::string ModelName() const { return "ppocr/ocr_det"; } - /** \brief Predict the input image and get OCR detection model result. - * - * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. - * \param[in] boxes_result The output of OCR detection model result will be writen to this structure. - * \return true if the prediction is successed, otherwise false. 
- */ - virtual bool Predict(cv::Mat* img, - std::vector>* boxes_result); /** \brief Predict the input image and get OCR detection model result. * * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. diff --git a/fastdeploy/vision/ocr/ppocr/det_postprocessor.cc b/fastdeploy/vision/ocr/ppocr/det_postprocessor.cc index 34a88c011..e83dac5e5 100644 --- a/fastdeploy/vision/ocr/ppocr/det_postprocessor.cc +++ b/fastdeploy/vision/ocr/ppocr/det_postprocessor.cc @@ -20,10 +20,6 @@ namespace fastdeploy { namespace vision { namespace ocr { -DBDetectorPostprocessor::DBDetectorPostprocessor() { - initialized_ = true; -} - bool DBDetectorPostprocessor::SingleBatchPostprocessor( const float* out_data, int n2, @@ -57,10 +53,10 @@ bool DBDetectorPostprocessor::SingleBatchPostprocessor( std::vector>> boxes; boxes = - post_processor_.BoxesFromBitmap(pred_map, bit_map, det_db_box_thresh_, + util_post_processor_.BoxesFromBitmap(pred_map, bit_map, det_db_box_thresh_, det_db_unclip_ratio_, det_db_score_mode_); - boxes = post_processor_.FilterTagDetRes(boxes, det_img_info); + boxes = util_post_processor_.FilterTagDetRes(boxes, det_img_info); // boxes to boxes_result for (int i = 0; i < boxes.size(); i++) { @@ -80,10 +76,6 @@ bool DBDetectorPostprocessor::SingleBatchPostprocessor( bool DBDetectorPostprocessor::Run(const std::vector& tensors, std::vector>>* results, const std::vector>& batch_det_img_info) { - if (!initialized_) { - FDERROR << "Postprocessor is not initialized." << std::endl; - return false; - } // DBDetector have only 1 output tensor. const FDTensor& tensor = tensors[0]; diff --git a/fastdeploy/vision/ocr/ppocr/det_postprocessor.h b/fastdeploy/vision/ocr/ppocr/det_postprocessor.h index f98b89b02..115228843 100644 --- a/fastdeploy/vision/ocr/ppocr/det_postprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/det_postprocessor.h @@ -25,11 +25,6 @@ namespace ocr { */ class FASTDEPLOY_DECL DBDetectorPostprocessor { public: - /** \brief Create a postprocessor instance for DBDetector serials model - * - */ - DBDetectorPostprocessor(); - /** \brief Process the result of runtime and fill to results structure * * \param[in] tensors The inference result from runtime @@ -48,8 +43,7 @@ class FASTDEPLOY_DECL DBDetectorPostprocessor { bool use_dilation_ = false; private: - bool initialized_ = false; - PostProcessor post_processor_; + PostProcessor util_post_processor_; bool SingleBatchPostprocessor(const float* out_data, int n2, int n3, diff --git a/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc b/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc index 89a8d6d39..28b7e47af 100644 --- a/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc +++ b/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc @@ -21,10 +21,6 @@ namespace fastdeploy { namespace vision { namespace ocr { -DBDetectorPreprocessor::DBDetectorPreprocessor() { - initialized_ = true; -} - std::array OcrDetectorGetInfo(FDMat* img, int max_size_len) { int w = img->Width(); int h = img->Height(); @@ -63,10 +59,6 @@ bool OcrDetectorResizeImage(FDMat* img, bool DBDetectorPreprocessor::Run(std::vector* images, std::vector* outputs, std::vector>* batch_det_img_info_ptr) { - if (!initialized_) { - FDERROR << "The preprocessor is not initialized." << std::endl; - return false; - } if (images->size() == 0) { FDERROR << "The size of input images should be greater than 0." 
<< std::endl; return false; diff --git a/fastdeploy/vision/ocr/ppocr/det_preprocessor.h b/fastdeploy/vision/ocr/ppocr/det_preprocessor.h index 39c48691d..d66e785d3 100644 --- a/fastdeploy/vision/ocr/ppocr/det_preprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/det_preprocessor.h @@ -24,11 +24,6 @@ namespace ocr { */ class FASTDEPLOY_DECL DBDetectorPreprocessor { public: - /** \brief Create a preprocessor instance for DBDetector serials model - * - */ - DBDetectorPreprocessor(); - /** \brief Process the input image and prepare input tensors for runtime * * \param[in] images The input image data list, all the elements are returned by cv::imread() @@ -44,9 +39,6 @@ class FASTDEPLOY_DECL DBDetectorPreprocessor { std::vector mean_ = {0.485f, 0.456f, 0.406f}; std::vector scale_ = {0.229f, 0.224f, 0.225f}; bool is_scale_ = true; - - private: - bool initialized_ = false; }; } // namespace ocr diff --git a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc index b7e7dcbf3..f84e99ed5 100755 --- a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc +++ b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc @@ -20,14 +20,8 @@ void BindPPOCRModel(pybind11::module& m) { vision::ocr::SortBoxes(&boxes); return boxes; }); + // DBDetector - pybind11::class_(m, "DBDetector") - .def(pybind11::init()) - .def(pybind11::init<>()) - .def_readwrite("preprocessor", &vision::ocr::DBDetector::preprocessor_) - .def_readwrite("postprocessor", &vision::ocr::DBDetector::postprocessor_); - pybind11::class_(m, "DBDetectorPreprocessor") .def(pybind11::init<>()) .def_readwrite("max_side_len", &vision::ocr::DBDetectorPreprocessor::max_side_len_) @@ -45,7 +39,7 @@ void BindPPOCRModel(pybind11::module& m) { for(size_t i = 0; i< outputs.size(); ++i){ outputs[i].StopSharing(); } - return make_pair(outputs, batch_det_img_info); + return std::make_pair(outputs, batch_det_img_info); }); pybind11::class_(m, "DBDetectorPostprocessor") @@ -77,15 +71,31 @@ void BindPPOCRModel(pybind11::module& m) { return results; }); - // Classifier - pybind11::class_(m, "Classifier") + pybind11::class_(m, "DBDetector") .def(pybind11::init()) .def(pybind11::init<>()) - .def_readwrite("preprocessor", &vision::ocr::Classifier::preprocessor_) - .def_readwrite("postprocessor", &vision::ocr::Classifier::postprocessor_); + .def_readwrite("preprocessor", &vision::ocr::DBDetector::preprocessor_) + .def_readwrite("postprocessor", &vision::ocr::DBDetector::postprocessor_) + .def("predict", [](vision::ocr::DBDetector& self, + pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + std::vector> boxes_result; + self.Predict(mat, &boxes_result); + return boxes_result; + }) + .def("batch_predict", [](vision::ocr::DBDetector& self, std::vector& data) { + std::vector images; + std::vector>> det_results; + for (size_t i = 0; i < data.size(); ++i) { + images.push_back(PyArrayToCvMat(data[i])); + } + self.BatchPredict(images, &det_results); + return det_results; + }); - pybind11::class_(m, "ClassifierPreprocessor") + // Classifier + pybind11::class_(m, "ClassifierPreprocessor") .def(pybind11::init<>()) .def_readwrite("cls_image_shape", &vision::ocr::ClassifierPreprocessor::cls_image_shape_) .def_readwrite("mean", &vision::ocr::ClassifierPreprocessor::mean_) @@ -116,7 +126,7 @@ void BindPPOCRModel(pybind11::module& m) { if (!self.Run(inputs, &cls_labels, &cls_scores)) { throw std::runtime_error("Failed to preprocess the input data in ClassifierPostprocessor."); } - return make_pair(cls_labels,cls_scores); + return 
std::make_pair(cls_labels,cls_scores); }) .def("run", [](vision::ocr::ClassifierPostprocessor& self, std::vector& input_array) { @@ -127,39 +137,56 @@ void BindPPOCRModel(pybind11::module& m) { if (!self.Run(inputs, &cls_labels, &cls_scores)) { throw std::runtime_error("Failed to preprocess the input data in ClassifierPostprocessor."); } - return make_pair(cls_labels,cls_scores); + return std::make_pair(cls_labels,cls_scores); }); - - - // Recognizer - pybind11::class_(m, "Recognizer") - .def(pybind11::init(m, "Classifier") + .def(pybind11::init()) .def(pybind11::init<>()) - .def_readwrite("preprocessor", &vision::ocr::Recognizer::preprocessor_) - .def_readwrite("postprocessor", &vision::ocr::Recognizer::postprocessor_); - - pybind11::class_(m, "RecognizerPreprocessor") - .def(pybind11::init<>()) - .def_readwrite("rec_image_shape", &vision::ocr::RecognizerPreprocessor::rec_image_shape_) - .def_readwrite("mean", &vision::ocr::RecognizerPreprocessor::mean_) - .def_readwrite("scale", &vision::ocr::RecognizerPreprocessor::scale_) - .def_readwrite("is_scale", &vision::ocr::RecognizerPreprocessor::is_scale_) - .def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector& im_list) { - std::vector images; - for (size_t i = 0; i < im_list.size(); ++i) { - images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); + .def_readwrite("preprocessor", &vision::ocr::Classifier::preprocessor_) + .def_readwrite("postprocessor", &vision::ocr::Classifier::postprocessor_) + .def("predict", [](vision::ocr::Classifier& self, + pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + int32_t cls_label; + float cls_score; + self.Predict(mat, &cls_label, &cls_score); + return std::make_pair(cls_label, cls_score); + }) + .def("batch_predict", [](vision::ocr::Classifier& self, std::vector& data) { + std::vector images; + std::vector cls_labels; + std::vector cls_scores; + for (size_t i = 0; i < data.size(); ++i) { + images.push_back(PyArrayToCvMat(data[i])); } - std::vector outputs; - if (!self.Run(&images, &outputs)) { - throw std::runtime_error("Failed to preprocess the input data in RecognizerPreprocessor."); - } - for(size_t i = 0; i< outputs.size(); ++i){ - outputs[i].StopSharing(); - } - return outputs; + self.BatchPredict(images, &cls_labels, &cls_scores); + return std::make_pair(cls_labels, cls_scores); }); + // Recognizer + pybind11::class_(m, "RecognizerPreprocessor") + .def(pybind11::init<>()) + .def_readwrite("rec_image_shape", &vision::ocr::RecognizerPreprocessor::rec_image_shape_) + .def_readwrite("mean", &vision::ocr::RecognizerPreprocessor::mean_) + .def_readwrite("scale", &vision::ocr::RecognizerPreprocessor::scale_) + .def_readwrite("is_scale", &vision::ocr::RecognizerPreprocessor::is_scale_) + .def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector& im_list) { + std::vector images; + for (size_t i = 0; i < im_list.size(); ++i) { + images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); + } + std::vector outputs; + if (!self.Run(&images, &outputs)) { + throw std::runtime_error("Failed to preprocess the input data in RecognizerPreprocessor."); + } + for(size_t i = 0; i< outputs.size(); ++i){ + outputs[i].StopSharing(); + } + return outputs; + }); + pybind11::class_(m, "RecognizerPostprocessor") .def(pybind11::init()) .def("run", [](vision::ocr::RecognizerPostprocessor& self, @@ -169,7 +196,7 @@ void BindPPOCRModel(pybind11::module& m) { if (!self.Run(inputs, &texts, &rec_scores)) { throw std::runtime_error("Failed to preprocess the input data in 
RecognizerPostprocessor."); } - return make_pair(texts, rec_scores); + return std::make_pair(texts, rec_scores); }) .def("run", [](vision::ocr::RecognizerPostprocessor& self, std::vector& input_array) { @@ -180,7 +207,32 @@ void BindPPOCRModel(pybind11::module& m) { if (!self.Run(inputs, &texts, &rec_scores)) { throw std::runtime_error("Failed to preprocess the input data in RecognizerPostprocessor."); } - return make_pair(texts, rec_scores); + return std::make_pair(texts, rec_scores); + }); + + pybind11::class_(m, "Recognizer") + .def(pybind11::init()) + .def(pybind11::init<>()) + .def_readwrite("preprocessor", &vision::ocr::Recognizer::preprocessor_) + .def_readwrite("postprocessor", &vision::ocr::Recognizer::postprocessor_) + .def("predict", [](vision::ocr::Recognizer& self, + pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + std::string text; + float rec_score; + self.Predict(mat, &text, &rec_score); + return std::make_pair(text, rec_score); + }) + .def("batch_predict", [](vision::ocr::Recognizer& self, std::vector& data) { + std::vector images; + std::vector texts; + std::vector rec_scores; + for (size_t i = 0; i < data.size(); ++i) { + images.push_back(PyArrayToCvMat(data[i])); + } + self.BatchPredict(images, &texts, &rec_scores); + return std::make_pair(texts, rec_scores); }); } } // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/ppocr_pybind.cc b/fastdeploy/vision/ocr/ppocr/ppocr_pybind.cc index fcbbe0224..7542793fe 100755 --- a/fastdeploy/vision/ocr/ppocr/ppocr_pybind.cc +++ b/fastdeploy/vision/ocr/ppocr/ppocr_pybind.cc @@ -25,6 +25,8 @@ void BindPPOCRv3(pybind11::module& m) { fastdeploy::vision::ocr::Recognizer*>()) .def(pybind11::init()) + .def_property("cls_batch_size", &pipeline::PPOCRv3::GetClsBatchSize, &pipeline::PPOCRv3::SetClsBatchSize) + .def_property("rec_batch_size", &pipeline::PPOCRv3::GetRecBatchSize, &pipeline::PPOCRv3::SetRecBatchSize) .def("predict", [](pipeline::PPOCRv3& self, pybind11::array& data) { auto mat = PyArrayToCvMat(data); @@ -52,6 +54,8 @@ void BindPPOCRv2(pybind11::module& m) { fastdeploy::vision::ocr::Recognizer*>()) .def(pybind11::init()) + .def_property("cls_batch_size", &pipeline::PPOCRv2::GetClsBatchSize, &pipeline::PPOCRv2::SetClsBatchSize) + .def_property("rec_batch_size", &pipeline::PPOCRv2::GetRecBatchSize, &pipeline::PPOCRv2::SetRecBatchSize) .def("predict", [](pipeline::PPOCRv2& self, pybind11::array& data) { auto mat = PyArrayToCvMat(data); diff --git a/fastdeploy/vision/ocr/ppocr/ppocr_v2.cc b/fastdeploy/vision/ocr/ppocr/ppocr_v2.cc index daa40692d..2ee2f903f 100755 --- a/fastdeploy/vision/ocr/ppocr/ppocr_v2.cc +++ b/fastdeploy/vision/ocr/ppocr/ppocr_v2.cc @@ -33,6 +33,32 @@ PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model, recognizer_->preprocessor_.rec_image_shape_[1] = 32; } +bool PPOCRv2::SetClsBatchSize(int cls_batch_size) { + if (cls_batch_size < -1 || cls_batch_size == 0) { + FDERROR << "batch_size > 0 or batch_size == -1." << std::endl; + return false; + } + cls_batch_size_ = cls_batch_size; + return true; +} + +int PPOCRv2::GetClsBatchSize() { + return cls_batch_size_; +} + +bool PPOCRv2::SetRecBatchSize(int rec_batch_size) { + if (rec_batch_size < -1 || rec_batch_size == 0) { + FDERROR << "batch_size > 0 or batch_size == -1." 
<< std::endl; + return false; + } + rec_batch_size_ = rec_batch_size; + return true; +} + +int PPOCRv2::GetRecBatchSize() { + return rec_batch_size_; +} + bool PPOCRv2::Initialized() const { if (detector_ != nullptr && !detector_->Initialized()) { @@ -52,7 +78,10 @@ bool PPOCRv2::Initialized() const { bool PPOCRv2::Predict(cv::Mat* img, fastdeploy::vision::OCRResult* result) { std::vector batch_result(1); - BatchPredict({*img},&batch_result); + bool success = BatchPredict({*img},&batch_result); + if(!success){ + return success; + } *result = std::move(batch_result[0]); return true; }; @@ -67,12 +96,12 @@ bool PPOCRv2::BatchPredict(const std::vector& images, FDERROR << "There's error while detecting image in PPOCR." << std::endl; return false; } + for(int i_batch = 0; i_batch < batch_boxes.size(); ++i_batch) { vision::ocr::SortBoxes(&(batch_boxes[i_batch])); (*batch_result)[i_batch].boxes = batch_boxes[i_batch]; } - for(int i_batch = 0; i_batch < images.size(); ++i_batch) { fastdeploy::vision::OCRResult& ocr_result = (*batch_result)[i_batch]; // Get croped images by detection result @@ -93,22 +122,34 @@ bool PPOCRv2::BatchPredict(const std::vector& images, std::vector* text_ptr = &ocr_result.text; std::vector* rec_scores_ptr = &ocr_result.rec_scores; - if (nullptr != classifier_){ - if (!classifier_->BatchPredict(image_list, cls_labels_ptr, cls_scores_ptr)) { - FDERROR << "There's error while recognizing image in PPOCR." << std::endl; - return false; - }else{ - for (size_t i_img = 0; i_img < image_list.size(); ++i_img) { - if(cls_labels_ptr->at(i_img) % 2 == 1 && cls_scores_ptr->at(i_img) > classifier_->postprocessor_.cls_thresh_) { - cv::rotate(image_list[i_img], image_list[i_img], 1); + if (nullptr != classifier_) { + for(size_t start_index = 0; start_index < image_list.size(); start_index+=cls_batch_size_) { + size_t end_index = std::min(start_index + cls_batch_size_, image_list.size()); + if (!classifier_->BatchPredict(image_list, cls_labels_ptr, cls_scores_ptr, start_index, end_index)) { + FDERROR << "There's error while recognizing image in PPOCR." << std::endl; + return false; + }else{ + for (size_t i_img = start_index; i_img < end_index; ++i_img) { + if(cls_labels_ptr->at(i_img) % 2 == 1 && cls_scores_ptr->at(i_img) > classifier_->postprocessor_.cls_thresh_) { + cv::rotate(image_list[i_img], image_list[i_img], 1); + } } } } } - - if (!recognizer_->BatchPredict(image_list, text_ptr, rec_scores_ptr)) { - FDERROR << "There's error while recognizing image in PPOCR." << std::endl; - return false; + + std::vector width_list; + for (int i = 0; i < image_list.size(); i++) { + width_list.push_back(float(image_list[i].cols) / image_list[i].rows); + } + std::vector indices = vision::ocr::ArgSort(width_list); + + for(size_t start_index = 0; start_index < image_list.size(); start_index+=rec_batch_size_) { + size_t end_index = std::min(start_index + rec_batch_size_, image_list.size()); + if (!recognizer_->BatchPredict(image_list, text_ptr, rec_scores_ptr, start_index, end_index, indices)) { + FDERROR << "There's error while recognizing image in PPOCR." 
<< std::endl; + return false; + } } } return true; diff --git a/fastdeploy/vision/ocr/ppocr/ppocr_v2.h b/fastdeploy/vision/ocr/ppocr/ppocr_v2.h index d021d6c32..05f2b9309 100755 --- a/fastdeploy/vision/ocr/ppocr/ppocr_v2.h +++ b/fastdeploy/vision/ocr/ppocr/ppocr_v2.h @@ -68,11 +68,19 @@ class FASTDEPLOY_DECL PPOCRv2 : public FastDeployModel { virtual bool BatchPredict(const std::vector& images, std::vector* batch_result); bool Initialized() const override; + bool SetClsBatchSize(int cls_batch_size); + int GetClsBatchSize(); + bool SetRecBatchSize(int rec_batch_size); + int GetRecBatchSize(); protected: fastdeploy::vision::ocr::DBDetector* detector_ = nullptr; fastdeploy::vision::ocr::Classifier* classifier_ = nullptr; fastdeploy::vision::ocr::Recognizer* recognizer_ = nullptr; + + private: + int cls_batch_size_ = 1; + int rec_batch_size_ = 6; /// Launch the detection process in OCR. }; diff --git a/fastdeploy/vision/ocr/ppocr/rec_postprocessor.cc b/fastdeploy/vision/ocr/ppocr/rec_postprocessor.cc index cdc302e28..d93c16907 100644 --- a/fastdeploy/vision/ocr/ppocr/rec_postprocessor.cc +++ b/fastdeploy/vision/ocr/ppocr/rec_postprocessor.cc @@ -34,7 +34,7 @@ std::vector ReadDict(const std::string& path) { } RecognizerPostprocessor::RecognizerPostprocessor(){ - initialized_ = true; + initialized_ = false; } RecognizerPostprocessor::RecognizerPostprocessor(const std::string& label_path) { @@ -84,24 +84,53 @@ bool RecognizerPostprocessor::SingleBatchPostprocessor(const float* out_data, bool RecognizerPostprocessor::Run(const std::vector& tensors, std::vector* texts, std::vector* rec_scores) { + // Recognizer have only 1 output tensor. + // For Recognizer, the output tensor shape = [batch, ?, 6625] + size_t total_size = tensors[0].shape[0]; + return Run(tensors, texts, rec_scores, 0, total_size, {}); +} + +bool RecognizerPostprocessor::Run(const std::vector& tensors, + std::vector* texts, std::vector* rec_scores, + size_t start_index, size_t total_size, const std::vector& indices) { if (!initialized_) { FDERROR << "Postprocessor is not initialized." << std::endl; return false; } + // Recognizer have only 1 output tensor. const FDTensor& tensor = tensors[0]; // For Recognizer, the output tensor shape = [batch, ?, 6625] size_t batch = tensor.shape[0]; size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies()); - texts->resize(batch); - rec_scores->resize(batch); + if (batch <= 0) { + FDERROR << "The infer outputTensor.shape[0] <=0, wrong infer result." << std::endl; + return false; + } + if (start_index < 0 || total_size <= 0) { + FDERROR << "start_index or total_size error. Correct is: 0 <= start_index < total_size" << std::endl; + return false; + } + if ((start_index + batch) > total_size) { + FDERROR << "start_index or total_size error. 
Correct is: start_index + batch(outputTensor.shape[0]) <= total_size" << std::endl; + return false; + } + texts->resize(total_size); + rec_scores->resize(total_size); + const float* tensor_data = reinterpret_cast(tensor.Data()); - for (int i_batch = 0; i_batch < batch; ++i_batch) { - if(!SingleBatchPostprocessor(tensor_data, tensor.shape, &texts->at(i_batch), &rec_scores->at(i_batch))) { + for (int i_batch = 0; i_batch < batch; ++i_batch) { + size_t real_index = i_batch+start_index; + if (indices.size() != 0) { + real_index = indices[i_batch+start_index]; + } + if(!SingleBatchPostprocessor(tensor_data + i_batch * length, + tensor.shape, + &texts->at(real_index), + &rec_scores->at(real_index))) { return false; } - tensor_data = tensor_data + length; } return true; diff --git a/fastdeploy/vision/ocr/ppocr/rec_postprocessor.h b/fastdeploy/vision/ocr/ppocr/rec_postprocessor.h index d1aa0124b..711ae3a01 100644 --- a/fastdeploy/vision/ocr/ppocr/rec_postprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/rec_postprocessor.h @@ -32,7 +32,7 @@ class FASTDEPLOY_DECL RecognizerPostprocessor { */ explicit RecognizerPostprocessor(const std::string& label_path); - /** \brief Process the result of runtime and fill to ClassifyResult structure + /** \brief Process the result of runtime and fill to RecognizerResult * * \param[in] tensors The inference result from runtime * \param[in] texts The output result of recognizer @@ -42,6 +42,11 @@ class FASTDEPLOY_DECL RecognizerPostprocessor { bool Run(const std::vector& tensors, std::vector* texts, std::vector* rec_scores); + bool Run(const std::vector& tensors, + std::vector* texts, std::vector* rec_scores, + size_t start_index, size_t total_size, + const std::vector& indices); + private: bool SingleBatchPostprocessor(const float* out_data, const std::vector& output_shape, diff --git a/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc index 858578d69..a965eb762 100644 --- a/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc +++ b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc @@ -21,69 +21,75 @@ namespace fastdeploy { namespace vision { namespace ocr { -RecognizerPreprocessor::RecognizerPreprocessor() { - initialized_ = true; -} - void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio, const std::vector& rec_image_shape) { - int imgC, imgH, imgW; - imgC = rec_image_shape[0]; - imgH = rec_image_shape[1]; - imgW = rec_image_shape[2]; + int img_c, img_h, img_w; + img_c = rec_image_shape[0]; + img_h = rec_image_shape[1]; + img_w = rec_image_shape[2]; - imgW = int(imgH * max_wh_ratio); + img_w = int(img_h * max_wh_ratio); float ratio = float(mat->Width()) / float(mat->Height()); int resize_w; - if (ceilf(imgH * ratio) > imgW) { - resize_w = imgW; + if (ceilf(img_h * ratio) > img_w) { + resize_w = img_w; }else{ - resize_w = int(ceilf(imgH * ratio)); + resize_w = int(ceilf(img_h * ratio)); } - Resize::Run(mat, resize_w, imgH); + Resize::Run(mat, resize_w, img_h); std::vector value = {0, 0, 0}; - Pad::Run(mat, 0, 0, 0, int(imgW - mat->Width()), value); + Pad::Run(mat, 0, 0, 0, int(img_w - mat->Width()), value); } bool RecognizerPreprocessor::Run(std::vector* images, std::vector* outputs) { - if (!initialized_) { - FDERROR << "The preprocessor is not initialized." << std::endl; - return false; - } - if (images->size() == 0) { - FDERROR << "The size of input images should be greater than 0." 
<< std::endl; + return Run(images, outputs, 0, images->size(), {}); +} + +bool RecognizerPreprocessor::Run(std::vector* images, std::vector* outputs, + size_t start_index, size_t end_index, const std::vector& indices) { + if (images->size() == 0 || end_index <= start_index || end_index > images->size()) { + FDERROR << "images->size() or index error. Correct is: 0 <= start_index < end_index <= images->size()" << std::endl; return false; } - int imgH = rec_image_shape_[1]; - int imgW = rec_image_shape_[2]; - float max_wh_ratio = imgW * 1.0 / imgH; + int img_h = rec_image_shape_[1]; + int img_w = rec_image_shape_[2]; + float max_wh_ratio = img_w * 1.0 / img_h; float ori_wh_ratio; - for (size_t i = 0; i < images->size(); ++i) { - FDMat* mat = &(images->at(i)); + for (size_t i = start_index; i < end_index; ++i) { + size_t real_index = i; + if (indices.size() != 0) { + real_index = indices[i]; + } + FDMat* mat = &(images->at(real_index)); ori_wh_ratio = mat->Width() * 1.0 / mat->Height(); max_wh_ratio = std::max(max_wh_ratio, ori_wh_ratio); } - for (size_t i = 0; i < images->size(); ++i) { - FDMat* mat = &(images->at(i)); + for (size_t i = start_index; i < end_index; ++i) { + size_t real_index = i; + if (indices.size() != 0) { + real_index = indices[i]; + } + FDMat* mat = &(images->at(real_index)); OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_); NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_); - /* - Normalize::Run(mat, mean_, scale_, is_scale_); - HWC2CHW::Run(mat); - Cast::Run(mat, "float"); - */ } // Only have 1 output Tensor. outputs->resize(1); + size_t tensor_size = end_index-start_index; // Concat all the preprocessed data to a batch tensor - std::vector tensors(images->size()); - for (size_t i = 0; i < images->size(); ++i) { - (*images)[i].ShareWithTensor(&(tensors[i])); + std::vector tensors(tensor_size); + for (size_t i = 0; i < tensor_size; ++i) { + size_t real_index = i + start_index; + if (indices.size() != 0) { + real_index = indices[i + start_index]; + } + + (*images)[real_index].ShareWithTensor(&(tensors[i])); tensors[i].ExpandDim(0); } if (tensors.size() == 1) { diff --git a/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h index 3e5c7de82..1dad75870 100644 --- a/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h @@ -24,12 +24,6 @@ namespace ocr { */ class FASTDEPLOY_DECL RecognizerPreprocessor { public: - /** \brief Create a preprocessor instance for PaddleClas serials model - * - * \param[in] config_file Path of configuration file for deployment, e.g resnet/infer_cfg.yml - */ - RecognizerPreprocessor(); - /** \brief Process the input image and prepare input tensors for runtime * * \param[in] images The input image data list, all the elements are returned by cv::imread() @@ -37,14 +31,14 @@ class FASTDEPLOY_DECL RecognizerPreprocessor { * \return true if the preprocess successed, otherwise false */ bool Run(std::vector* images, std::vector* outputs); + bool Run(std::vector* images, std::vector* outputs, + size_t start_index, size_t end_index, + const std::vector& indices); std::vector rec_image_shape_ = {3, 48, 320}; std::vector mean_ = {0.5f, 0.5f, 0.5f}; std::vector scale_ = {0.5f, 0.5f, 0.5f}; bool is_scale_ = true; - - private: - bool initialized_ = false; }; } // namespace ocr diff --git a/fastdeploy/vision/ocr/ppocr/recognizer.cc b/fastdeploy/vision/ocr/ppocr/recognizer.cc index 59d64a6c4..a20f312c2 100755 --- a/fastdeploy/vision/ocr/ppocr/recognizer.cc 
+++ b/fastdeploy/vision/ocr/ppocr/recognizer.cc
@@ -53,10 +53,33 @@ bool Recognizer::Initialize() {
   return true;
 }
 
+bool Recognizer::Predict(cv::Mat& img, std::string* text, float* rec_score) {
+  std::vector<std::string> texts(1);
+  std::vector<float> rec_scores(1);
+  bool success = BatchPredict({img}, &texts, &rec_scores);
+  if (!success) {
+    return success;
+  }
+  *text = std::move(texts[0]);
+  *rec_score = rec_scores[0];
+  return true;
+}
+
 bool Recognizer::BatchPredict(const std::vector<cv::Mat>& images,
                               std::vector<std::string>* texts, std::vector<float>* rec_scores) {
+  return BatchPredict(images, texts, rec_scores, 0, images.size(), {});
+}
+
+bool Recognizer::BatchPredict(const std::vector<cv::Mat>& images,
+                              std::vector<std::string>* texts, std::vector<float>* rec_scores,
+                              size_t start_index, size_t end_index, const std::vector<int>& indices) {
+  size_t total_size = images.size();
+  if (indices.size() != 0 && indices.size() != total_size) {
+    FDERROR << "indices.size() should be 0 or images.size()." << std::endl;
+    return false;
+  }
   std::vector<FDMat> fd_images = WrapMat(images);
-  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
+  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, start_index, end_index, indices)) {
     FDERROR << "Failed to preprocess the input image." << std::endl;
     return false;
   }
@@ -66,7 +89,7 @@ bool Recognizer::BatchPredict(const std::vector<cv::Mat>& images,
     return false;
   }
 
-  if (!postprocessor_.Run(reused_output_tensors_, texts, rec_scores)) {
+  if (!postprocessor_.Run(reused_output_tensors_, texts, rec_scores, start_index, total_size, indices)) {
     FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
     return false;
   }
diff --git a/fastdeploy/vision/ocr/ppocr/recognizer.h b/fastdeploy/vision/ocr/ppocr/recognizer.h
index 1cd841eb4..4ee12bb6a 100755
--- a/fastdeploy/vision/ocr/ppocr/recognizer.h
+++ b/fastdeploy/vision/ocr/ppocr/recognizer.h
@@ -45,6 +45,7 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
              const ModelFormat& model_format = ModelFormat::PADDLE);
   /// Get model's name
   std::string ModelName() const { return "ppocr/ocr_rec"; }
+  virtual bool Predict(cv::Mat& img, std::string* text, float* rec_score);
   /** \brief BatchPredict the input image and get OCR recognition model result.
    *
    * \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
@@ -53,6 +54,10 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel { */ virtual bool BatchPredict(const std::vector& images, std::vector* texts, std::vector* rec_scores); + virtual bool BatchPredict(const std::vector& images, + std::vector* texts, std::vector* rec_scores, + size_t start_index, size_t end_index, + const std::vector& indices); RecognizerPreprocessor preprocessor_; RecognizerPostprocessor postprocessor_; diff --git a/fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h b/fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h old mode 100644 new mode 100755 index 0e5c040eb..f12f40f71 --- a/fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h +++ b/fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h @@ -33,6 +33,8 @@ FASTDEPLOY_DECL cv::Mat GetRotateCropImage(const cv::Mat& srcimage, FASTDEPLOY_DECL void SortBoxes(std::vector>* boxes); +FASTDEPLOY_DECL std::vector ArgSort(const std::vector &array); + } // namespace ocr } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/utils/sorted_boxes.cc b/fastdeploy/vision/ocr/ppocr/utils/sorted_boxes.cc old mode 100644 new mode 100755 index b55d2a9eb..b0a909f9a --- a/fastdeploy/vision/ocr/ppocr/utils/sorted_boxes.cc +++ b/fastdeploy/vision/ocr/ppocr/utils/sorted_boxes.cc @@ -44,6 +44,19 @@ void SortBoxes(std::vector>* boxes) { } } +std::vector ArgSort(const std::vector &array) { + const int array_len(array.size()); + std::vector array_index(array_len, 0); + for (int i = 0; i < array_len; ++i) + array_index[i] = i; + + std::sort( + array_index.begin(), array_index.end(), + [&array](int pos1, int pos2) { return (array[pos1] < array[pos2]); }); + + return array_index; +} + } // namespace ocr } // namespace vision } // namespace fastdeploy diff --git a/python/fastdeploy/vision/ocr/ppocr/__init__.py b/python/fastdeploy/vision/ocr/ppocr/__init__.py index b8f5c81d1..dde3f807b 100755 --- a/python/fastdeploy/vision/ocr/ppocr/__init__.py +++ b/python/fastdeploy/vision/ocr/ppocr/__init__.py @@ -18,6 +18,134 @@ from .... import FastDeployModel, ModelFormat from .... import c_lib_wrap as C +def sort_boxes(boxes): + return C.vision.ocr.sort_boxes(boxes) + + +class DBDetectorPreprocessor: + def __init__(self): + """Create a preprocessor for DBDetectorModel + """ + self._preprocessor = C.vision.ocr.DBDetectorPreprocessor() + + def run(self, input_ims): + """Preprocess input images for DBDetectorModel + :param: input_ims: (list of numpy.ndarray) The input image + :return: pair(list of FDTensor, list of std::array) + """ + return self._preprocessor.run(input_ims) + + @property + def max_side_len(self): + return self._preprocessor.max_side_len + + @max_side_len.setter + def max_side_len(self, value): + assert isinstance( + value, int), "The value to set `max_side_len` must be type of int." + self._preprocessor.max_side_len = value + + @property + def is_scale(self): + return self._preprocessor.is_scale + + @is_scale.setter + def is_scale(self, value): + assert isinstance( + value, bool), "The value to set `is_scale` must be type of bool." + self._preprocessor.is_scale = value + + @property + def scale(self): + return self._preprocessor.scale + + @scale.setter + def scale(self, value): + assert isinstance( + value, list), "The value to set `scale` must be type of list." + self._preprocessor.scale = value + + @property + def mean(self): + return self._preprocessor.mean + + @mean.setter + def mean(self, value): + assert isinstance( + value, list), "The value to set `mean` must be type of list." 
+ self._preprocessor.mean = value + + +class DBDetectorPostprocessor: + def __init__(self): + """Create a postprocessor for DBDetectorModel + """ + self._postprocessor = C.vision.ocr.DBDetectorPostprocessor() + + def run(self, runtime_results, batch_det_img_info): + """Postprocess the runtime results for DBDetectorModel + :param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime + :param: batch_det_img_info: (list of std::array)The output of det_preprocessor + :return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size) + """ + return self._postprocessor.run(runtime_results, batch_det_img_info) + + @property + def det_db_thresh(self): + return self._postprocessor.det_db_thresh + + @det_db_thresh.setter + def det_db_thresh(self, value): + assert isinstance( + value, + float), "The value to set `det_db_thresh` must be type of float." + self._postprocessor.det_db_thresh = value + + @property + def det_db_box_thresh(self): + return self._postprocessor.det_db_box_thresh + + @det_db_box_thresh.setter + def det_db_box_thresh(self, value): + assert isinstance( + value, float + ), "The value to set `det_db_box_thresh` must be type of float." + self._postprocessor.det_db_box_thresh = value + + @property + def det_db_unclip_ratio(self): + return self._postprocessor.det_db_unclip_ratio + + @det_db_unclip_ratio.setter + def det_db_unclip_ratio(self, value): + assert isinstance( + value, float + ), "The value to set `det_db_unclip_ratio` must be type of float." + self._postprocessor.det_db_unclip_ratio = value + + @property + def det_db_score_mode(self): + return self._postprocessor.det_db_score_mode + + @det_db_score_mode.setter + def det_db_score_mode(self, value): + assert isinstance( + value, + str), "The value to set `det_db_score_mode` must be type of str." + self._postprocessor.det_db_score_mode = value + + @property + def use_dilation(self): + return self._postprocessor.use_dilation + + @use_dilation.setter + def use_dilation(self, value): + assert isinstance( + value, + bool), "The value to set `use_dilation` must be type of bool." + self._postprocessor.use_dilation = value + + class DBDetector(FastDeployModel): def __init__(self, model_file="", @@ -35,88 +163,223 @@ class DBDetector(FastDeployModel): if (len(model_file) == 0): self._model = C.vision.ocr.DBDetector() + self._runnable = False else: self._model = C.vision.ocr.DBDetector( model_file, params_file, self._runtime_option, model_format) assert self.initialized, "DBDetector initialize failed." 
+ self._runnable = True - # 一些跟DBDetector模型有关的属性封装 - ''' + def predict(self, input_image): + """Predict an input image + :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format + :return: boxes + """ + if self._runnable: + return self._model.predict(input_image) + return False + + def batch_predict(self, images): + """Predict a batch of input image + :param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format + :return: batch_boxes + """ + if self._runnable: + return self._model.batch_predict(images) + return False + + @property + def preprocessor(self): + return self._model.preprocessor + + @preprocessor.setter + def preprocessor(self, value): + self._model.preprocessor = value + + @property + def postprocessor(self): + return self._model.postprocessor + + @postprocessor.setter + def postprocessor(self, value): + self._model.postprocessor = value + + # Det Preprocessor Property + @property + def max_side_len(self): + return self._model.preprocessor.max_side_len + + @max_side_len.setter + def max_side_len(self, value): + assert isinstance( + value, int), "The value to set `max_side_len` must be type of int." + self._model.preprocessor.max_side_len = value + + @property + def is_scale(self): + return self._model.preprocessor.is_scale + + @is_scale.setter + def is_scale(self, value): + assert isinstance( + value, bool), "The value to set `is_scale` must be type of bool." + self._model.preprocessor.is_scale = value + + @property + def scale(self): + return self._model.preprocessor.scale + + @scale.setter + def scale(self, value): + assert isinstance( + value, list), "The value to set `scale` must be type of list." + self._model.preprocessor.scale = value + + @property + def mean(self): + return self._model.preprocessor.mean + + @mean.setter + def mean(self, value): + assert isinstance( + value, list), "The value to set `mean` must be type of list." + self._model.preprocessor.mean = value + + # Det Ppstprocessor Property @property def det_db_thresh(self): - return self._model.det_db_thresh + return self._model.postprocessor.det_db_thresh @det_db_thresh.setter def det_db_thresh(self, value): assert isinstance( value, float), "The value to set `det_db_thresh` must be type of float." - self._model.det_db_thresh = value + self._model.postprocessor.det_db_thresh = value @property def det_db_box_thresh(self): - return self._model.det_db_box_thresh + return self._model.postprocessor.det_db_box_thresh @det_db_box_thresh.setter def det_db_box_thresh(self, value): assert isinstance( value, float ), "The value to set `det_db_box_thresh` must be type of float." - self._model.det_db_box_thresh = value + self._model.postprocessor.det_db_box_thresh = value @property def det_db_unclip_ratio(self): - return self._model.det_db_unclip_ratio + return self._model.postprocessor.det_db_unclip_ratio @det_db_unclip_ratio.setter def det_db_unclip_ratio(self, value): assert isinstance( value, float ), "The value to set `det_db_unclip_ratio` must be type of float." - self._model.det_db_unclip_ratio = value + self._model.postprocessor.det_db_unclip_ratio = value @property def det_db_score_mode(self): - return self._model.det_db_score_mode + return self._model.postprocessor.det_db_score_mode @det_db_score_mode.setter def det_db_score_mode(self, value): assert isinstance( value, str), "The value to set `det_db_score_mode` must be type of str." 
- self._model.det_db_score_mode = value + self._model.postprocessor.det_db_score_mode = value @property def use_dilation(self): - return self._model.use_dilation + return self._model.postprocessor.use_dilation @use_dilation.setter def use_dilation(self, value): assert isinstance( value, bool), "The value to set `use_dilation` must be type of bool." - self._model.use_dilation = value + self._model.postprocessor.use_dilation = value - @property - def max_side_len(self): - return self._model.max_side_len - @max_side_len.setter - def max_side_len(self, value): - assert isinstance( - value, int), "The value to set `max_side_len` must be type of int." - self._model.max_side_len = value +class ClassifierPreprocessor: + def __init__(self): + """Create a preprocessor for ClassifierModel + """ + self._preprocessor = C.vision.ocr.ClassifierPreprocessor() + + def run(self, input_ims): + """Preprocess input images for ClassifierModel + :param: input_ims: (list of numpy.ndarray)The input image + :return: list of FDTensor + """ + return self._preprocessor.run(input_ims) @property def is_scale(self): - return self._model.max_wh + return self._preprocessor.is_scale @is_scale.setter def is_scale(self, value): assert isinstance( value, bool), "The value to set `is_scale` must be type of bool." - self._model.is_scale = value - ''' + self._preprocessor.is_scale = value + + @property + def scale(self): + return self._preprocessor.scale + + @scale.setter + def scale(self, value): + assert isinstance( + value, list), "The value to set `scale` must be type of list." + self._preprocessor.scale = value + + @property + def mean(self): + return self._preprocessor.mean + + @mean.setter + def mean(self, value): + assert isinstance( + value, list), "The value to set `mean` must be type of list." + self._preprocessor.mean = value + + @property + def cls_image_shape(self): + return self._preprocessor.cls_image_shape + + @cls_image_shape.setter + def cls_image_shape(self, value): + assert isinstance( + value, + list), "The value to set `cls_image_shape` must be type of list." + self._preprocessor.cls_image_shape = value + + +class ClassifierPostprocessor: + def __init__(self): + """Create a postprocessor for ClassifierModel + """ + self._postprocessor = C.vision.ocr.ClassifierPostprocessor() + + def run(self, runtime_results): + """Postprocess the runtime results for ClassifierModel + :param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime + :return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size) + """ + return self._postprocessor.run(runtime_results) + + @property + def cls_thresh(self): + return self._postprocessor.cls_thresh + + @cls_thresh.setter + def cls_thresh(self, value): + assert isinstance( + value, + float), "The value to set `cls_thresh` must be type of float." + self._postprocessor.cls_thresh = value class Classifier(FastDeployModel): @@ -136,44 +399,170 @@ class Classifier(FastDeployModel): if (len(model_file) == 0): self._model = C.vision.ocr.Classifier() + self._runnable = False else: self._model = C.vision.ocr.Classifier( model_file, params_file, self._runtime_option, model_format) assert self.initialized, "Classifier initialize failed." 
+ self._runnable = True + + def predict(self, input_image): + """Predict an input image + :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format + :return: cls_label, cls_score + """ + if self._runnable: + return self._model.predict(input_image) + return False + + def batch_predict(self, images): + """Predict a batch of input image + :param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format + :return: list of cls_label, list of cls_score + """ + if self._runnable: + return self._model.batch_predict(images) + return False - ''' @property - def cls_thresh(self): - return self._model.cls_thresh + def preprocessor(self): + return self._model.preprocessor + + @preprocessor.setter + def preprocessor(self, value): + self._model.preprocessor = value + + @property + def postprocessor(self): + return self._model.postprocessor + + @postprocessor.setter + def postprocessor(self, value): + self._model.postprocessor = value + + # Cls Preprocessor Property + @property + def is_scale(self): + return self._model.preprocessor.is_scale + + @is_scale.setter + def is_scale(self, value): + assert isinstance( + value, bool), "The value to set `is_scale` must be type of bool." + self._model.preprocessor.is_scale = value + + @property + def scale(self): + return self._model.preprocessor.scale + + @scale.setter + def scale(self, value): + assert isinstance( + value, list), "The value to set `scale` must be type of list." + self._model.preprocessor.scale = value + + @property + def mean(self): + return self._model.preprocessor.mean + + @mean.setter + def mean(self, value): + assert isinstance( + value, list), "The value to set `mean` must be type of list." + self._model.preprocessor.mean = value @property def cls_image_shape(self): - return self._model.cls_image_shape + return self._model.preprocessor.cls_image_shape + @cls_image_shape.setter + def cls_image_shape(self, value): + assert isinstance( + value, + list), "The value to set `cls_image_shape` must be type of list." + self._model.preprocessor.cls_image_shape = value + + # Cls Postprocessor Property @property - def cls_batch_num(self): - return self._model.cls_batch_num + def cls_thresh(self): + return self._model.postprocessor.cls_thresh @cls_thresh.setter def cls_thresh(self, value): assert isinstance( value, float), "The value to set `cls_thresh` must be type of float." - self._model.cls_thresh = value + self._model.postprocessor.cls_thresh = value - @cls_image_shape.setter - def cls_image_shape(self, value): + +class RecognizerPreprocessor: + def __init__(self): + """Create a preprocessor for RecognizerModel + """ + self._preprocessor = C.vision.ocr.RecognizerPreprocessor() + + def run(self, input_ims): + """Preprocess input images for RecognizerModel + :param: input_ims: (list of numpy.ndarray)The input image + :return: list of FDTensor + """ + return self._preprocessor.run(input_ims) + + @property + def is_scale(self): + return self._preprocessor.is_scale + + @is_scale.setter + def is_scale(self, value): assert isinstance( - value, list), "The value to set `cls_thresh` must be type of list." - self._model.cls_image_shape = value + value, bool), "The value to set `is_scale` must be type of bool." 
+ self._preprocessor.is_scale = value - @cls_batch_num.setter - def cls_batch_num(self, value): + @property + def scale(self): + return self._preprocessor.scale + + @scale.setter + def scale(self, value): + assert isinstance( + value, list), "The value to set `scale` must be type of list." + self._preprocessor.scale = value + + @property + def mean(self): + return self._preprocessor.mean + + @mean.setter + def mean(self, value): + assert isinstance( + value, list), "The value to set `mean` must be type of list." + self._preprocessor.mean = value + + @property + def rec_image_shape(self): + return self._preprocessor.rec_image_shape + + @rec_image_shape.setter + def rec_image_shape(self, value): assert isinstance( value, - int), "The value to set `cls_batch_num` must be type of int." - self._model.cls_batch_num = value - ''' + list), "The value to set `rec_image_shape` must be type of list." + self._preprocessor.rec_image_shape = value + + +class RecognizerPostprocessor: + def __init__(self, label_path): + """Create a postprocessor for RecognizerModel + :param label_path: (str)Path of label file + """ + self._postprocessor = C.vision.ocr.RecognizerPostprocessor(label_path) + + def run(self, runtime_results): + """Postprocess the runtime results for RecognizerModel + :param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime + :return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size) + """ + return self._postprocessor.run(runtime_results) class Recognizer(FastDeployModel): @@ -195,44 +584,88 @@ class Recognizer(FastDeployModel): if (len(model_file) == 0): self._model = C.vision.ocr.Recognizer() + self._runnable = False else: self._model = C.vision.ocr.Recognizer( model_file, params_file, label_path, self._runtime_option, model_format) assert self.initialized, "Recognizer initialize failed." + self._runnable = True - ''' - @property - def rec_img_h(self): - return self._model.rec_img_h + def predict(self, input_image): + """Predict an input image + :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format + :return: rec_text, rec_score + """ + if self._runnable: + return self._model.predict(input_image) + return False + + def batch_predict(self, images): + """Predict a batch of input image + :param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format + :return: list of rec_text, list of rec_score + """ + if self._runnable: + return self._model.batch_predict(images) + return False @property - def rec_img_w(self): - return self._model.rec_img_w + def preprocessor(self): + return self._model.preprocessor + + @preprocessor.setter + def preprocessor(self, value): + self._model.preprocessor = value @property - def rec_batch_num(self): - return self._model.rec_batch_num + def postprocessor(self): + return self._model.postprocessor - @rec_img_h.setter - def rec_img_h(self, value): + @postprocessor.setter + def postprocessor(self, value): + self._model.postprocessor = value + + @property + def is_scale(self): + return self._model.preprocessor.is_scale + + @is_scale.setter + def is_scale(self, value): assert isinstance( - value, int), "The value to set `rec_img_h` must be type of int." - self._model.rec_img_h = value + value, bool), "The value to set `is_scale` must be type of bool." 
+ self._model.preprocessor.is_scale = value - @rec_img_w.setter - def rec_img_w(self, value): + @property + def scale(self): + return self._model.preprocessor.scale + + @scale.setter + def scale(self, value): assert isinstance( - value, int), "The value to set `rec_img_w` must be type of int." - self._model.rec_img_w = value + value, list), "The value to set `scale` must be type of list." + self._model.preprocessor.scale = value - @rec_batch_num.setter - def rec_batch_num(self, value): + @property + def mean(self): + return self._model.preprocessor.mean + + @mean.setter + def mean(self, value): + assert isinstance( + value, list), "The value to set `mean` must be type of list." + self._model.preprocessor.mean = value + + @property + def rec_image_shape(self): + return self._model.preprocessor.rec_image_shape + + @rec_image_shape.setter + def rec_image_shape(self, value): assert isinstance( value, - int), "The value to set `rec_batch_num` must be type of int." - self._model.rec_batch_num = value - ''' + list), "The value to set `rec_image_shape` must be type of list." + self._model.preprocessor.rec_image_shape = value class PPOCRv3(FastDeployModel): @@ -253,7 +686,6 @@ class PPOCRv3(FastDeployModel): def predict(self, input_image): """Predict an input image - :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format :return: OCRResult """ @@ -264,9 +696,30 @@ class PPOCRv3(FastDeployModel): :param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format :return: OCRBatchResult """ - return self.system.batch_predict(images) + @property + def cls_batch_size(self): + return self.system.cls_batch_size + + @cls_batch_size.setter + def cls_batch_size(self, value): + assert isinstance( + value, + int), "The value to set `cls_batch_size` must be type of int." + self.system.cls_batch_size = value + + @property + def rec_batch_size(self): + return self.system.rec_batch_size + + @rec_batch_size.setter + def rec_batch_size(self, value): + assert isinstance( + value, + int), "The value to set `rec_batch_size` must be type of int." + self.system.rec_batch_size = value + class PPOCRSystemv3(PPOCRv3): def __init__(self, det_model=None, cls_model=None, rec_model=None): @@ -311,6 +764,28 @@ class PPOCRv2(FastDeployModel): return self.system.batch_predict(images) + @property + def cls_batch_size(self): + return self.system.cls_batch_size + + @cls_batch_size.setter + def cls_batch_size(self, value): + assert isinstance( + value, + int), "The value to set `cls_batch_size` must be type of int." + self.system.cls_batch_size = value + + @property + def rec_batch_size(self): + return self.system.rec_batch_size + + @rec_batch_size.setter + def rec_batch_size(self, value): + assert isinstance( + value, + int), "The value to set `rec_batch_size` must be type of int." 
+ self.system.rec_batch_size = value + class PPOCRSystemv2(PPOCRv2): def __init__(self, det_model=None, cls_model=None, rec_model=None): @@ -321,93 +796,3 @@ class PPOCRSystemv2(PPOCRv2): def predict(self, input_image): return super(PPOCRSystemv2, self).predict(input_image) - - -class DBDetectorPreprocessor: - def __init__(self): - """Create a preprocessor for DBDetectorModel - """ - self._preprocessor = C.vision.ocr.DBDetectorPreprocessor() - - def run(self, input_ims): - """Preprocess input images for DBDetectorModel - :param: input_ims: (list of numpy.ndarray) The input image - :return: pair(list of FDTensor, list of std::array) - """ - return self._preprocessor.run(input_ims) - - -class DBDetectorPostprocessor: - def __init__(self): - """Create a postprocessor for DBDetectorModel - """ - self._postprocessor = C.vision.ocr.DBDetectorPostprocessor() - - def run(self, runtime_results, batch_det_img_info): - """Postprocess the runtime results for DBDetectorModel - :param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime - :param: batch_det_img_info: (list of std::array)The output of det_preprocessor - :return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size) - """ - return self._postprocessor.run(runtime_results, batch_det_img_info) - - -class RecognizerPreprocessor: - def __init__(self): - """Create a preprocessor for RecognizerModel - """ - self._preprocessor = C.vision.ocr.RecognizerPreprocessor() - - def run(self, input_ims): - """Preprocess input images for RecognizerModel - :param: input_ims: (list of numpy.ndarray)The input image - :return: list of FDTensor - """ - return self._preprocessor.run(input_ims) - - -class RecognizerPostprocessor: - def __init__(self, label_path): - """Create a postprocessor for RecognizerModel - :param label_path: (str)Path of label file - """ - self._postprocessor = C.vision.ocr.RecognizerPostprocessor(label_path) - - def run(self, runtime_results): - """Postprocess the runtime results for RecognizerModel - :param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime - :return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size) - """ - return self._postprocessor.run(runtime_results) - - -class ClassifierPreprocessor: - def __init__(self): - """Create a preprocessor for ClassifierModel - """ - self._preprocessor = C.vision.ocr.ClassifierPreprocessor() - - def run(self, input_ims): - """Preprocess input images for ClassifierModel - :param: input_ims: (list of numpy.ndarray)The input image - :return: list of FDTensor - """ - return self._preprocessor.run(input_ims) - - -class ClassifierPostprocessor: - def __init__(self): - """Create a postprocessor for ClassifierModel - """ - self._postprocessor = C.vision.ocr.ClassifierPostprocessor() - - def run(self, runtime_results): - """Postprocess the runtime results for ClassifierModel - :param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime - :return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size) - """ - return self._postprocessor.run(runtime_results) - - -def sort_boxes(boxes): - return C.vision.ocr.sort_boxes(boxes) diff --git a/tests/models/test_ppocrv3.py b/tests/models/test_ppocrv3.py new file mode 100755 index 000000000..8440e8627 --- /dev/null +++ 
b/tests/models/test_ppocrv3.py @@ -0,0 +1,256 @@ +import fastdeploy as fd +import cv2 +import os +import runtime_config as rc +import numpy as np +import math +import pickle + +det_model_url = "https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar" +cls_model_url = "https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar" +rec_model_url = "https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar" +img_url = "https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg" +label_url = "https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt" +result_url = "https://bj.bcebos.com/fastdeploy/tests/data/ocr_result.pickle" +fd.download_and_decompress(det_model_url, "resources") +fd.download_and_decompress(cls_model_url, "resources") +fd.download_and_decompress(rec_model_url, "resources") +fd.download(img_url, "resources") +fd.download(result_url, "resources") +fd.download(label_url, "resources") + + +def get_rotate_crop_image(img, box): + points = [] + for i in range(4): + points.append([box[2 * i], box[2 * i + 1]]) + points = np.array(points, dtype=np.float32) + img = img.astype(np.float32) + assert len(points) == 4, "shape of points must be 4*2" + img_crop_width = int( + max( + np.linalg.norm(points[0] - points[1]), + np.linalg.norm(points[2] - points[3]))) + img_crop_height = int( + max( + np.linalg.norm(points[0] - points[3]), + np.linalg.norm(points[1] - points[2]))) + pts_std = np.float32([[0, 0], [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE, + flags=cv2.INTER_CUBIC) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + + +option = fd.RuntimeOption() + +# det_model +det_model_path = "resources/ch_PP-OCRv3_det_infer/" +det_model_file = det_model_path + "inference.pdmodel" +det_params_file = det_model_path + "inference.pdiparams" + +det_preprocessor = fd.vision.ocr.DBDetectorPreprocessor() + +rc.test_option.set_model_path(det_model_file, det_params_file) +det_runtime = fd.Runtime(rc.test_option) + +det_postprocessor = fd.vision.ocr.DBDetectorPostprocessor() + +det_model = fd.vision.ocr.DBDetector( + det_model_file, det_params_file, runtime_option=option) + +# cls_model +cls_model_path = "resources/ch_ppocr_mobile_v2.0_cls_infer/" +cls_model_file = cls_model_path + "inference.pdmodel" +cls_params_file = cls_model_path + "inference.pdiparams" + +cls_preprocessor = fd.vision.ocr.ClassifierPreprocessor() + +rc.test_option.set_model_path(cls_model_file, cls_params_file) +cls_runtime = fd.Runtime(rc.test_option) + +cls_postprocessor = fd.vision.ocr.ClassifierPostprocessor() + +cls_model = fd.vision.ocr.Classifier( + cls_model_file, cls_params_file, runtime_option=option) + +#rec_model +rec_model_path = "resources/ch_PP-OCRv3_rec_infer/" +rec_model_file = rec_model_path + "inference.pdmodel" +rec_params_file = rec_model_path + "inference.pdiparams" +rec_label_file = "resources/ppocr_keys_v1.txt" + +rec_preprocessor = fd.vision.ocr.RecognizerPreprocessor() + +rc.test_option.set_model_path(rec_model_file, rec_params_file) +rec_runtime = fd.Runtime(rc.test_option) + +rec_postprocessor = fd.vision.ocr.RecognizerPostprocessor(rec_label_file) + +rec_model = fd.vision.ocr.Recognizer( + rec_model_file, 
rec_params_file, rec_label_file, runtime_option=option) + +#pp_ocrv3 +ppocr_v3 = fd.vision.ocr.PPOCRv3( + det_model=det_model, cls_model=cls_model, rec_model=rec_model) + +#pp_ocrv3_no_cls +ppocr_v3_no_cls = fd.vision.ocr.PPOCRv3( + det_model=det_model, rec_model=rec_model) + +#input image +img_file = "resources/12.jpg" +im = [] +im.append(cv2.imread(img_file)) +im.append(cv2.imread(img_file)) + +result_file = "resources/ocr_result.pickle" +with open(result_file, 'rb') as f: + boxes, cls_labels, cls_scores, text, rec_scores = pickle.load(f) + base_boxes = np.array(boxes) + base_cls_labels = np.array(cls_labels) + base_cls_scores = np.array(cls_scores) + base_text = text + base_rec_scores = np.array(rec_scores) + + +def compare_result(pred_boxes, pred_cls_labels, pred_cls_scores, pred_text, + pred_rec_scores): + pred_boxes = np.array(pred_boxes) + pred_cls_labels = np.array(pred_cls_labels) + pred_cls_scores = np.array(pred_cls_scores) + pred_text = pred_text + pred_rec_scores = np.array(pred_rec_scores) + + diff_boxes = np.fabs(base_boxes - pred_boxes).max() + diff_cls_labels = np.fabs(base_cls_labels - pred_cls_labels).max() + diff_cls_scores = np.fabs(base_cls_scores - pred_cls_scores).max() + diff_text = (base_text != pred_text) + diff_rec_scores = np.fabs(base_rec_scores - pred_rec_scores).max() + + print('diff:', diff_boxes, diff_cls_labels, diff_cls_scores, diff_text, + diff_rec_scores) + diff_threshold = 1e-01 + assert diff_boxes < diff_threshold, "There is diff in boxes" + assert diff_cls_labels < diff_threshold, "There is diff in cls_label" + assert diff_cls_scores < diff_threshold, "There is diff in cls_scores" + assert diff_text < diff_threshold, "There is diff in text" + assert diff_rec_scores < diff_threshold, "There is diff in rec_scores" + + +def compare_result_no_cls(pred_boxes, pred_text, pred_rec_scores): + pred_boxes = np.array(pred_boxes) + pred_text = pred_text + pred_rec_scores = np.array(pred_rec_scores) + + diff_boxes = np.fabs(base_boxes - pred_boxes).max() + diff_text = (base_text != pred_text) + diff_rec_scores = np.fabs(base_rec_scores - pred_rec_scores).max() + + print('diff:', diff_boxes, diff_text, diff_rec_scores) + diff_threshold = 1e-01 + assert diff_boxes < diff_threshold, "There is diff in boxes" + assert diff_text < diff_threshold, "There is diff in text" + assert diff_rec_scores < diff_threshold, "There is diff in rec_scores" + + +def test_ppocr_v3(): + ppocr_v3.cls_batch_size = -1 + ppocr_v3.rec_batch_size = -1 + ocr_result = ppocr_v3.predict(im[0]) + compare_result(ocr_result.boxes, ocr_result.cls_labels, + ocr_result.cls_scores, ocr_result.text, + ocr_result.rec_scores) + + ppocr_v3.cls_batch_size = 2 + ppocr_v3.rec_batch_size = 2 + ocr_result = ppocr_v3.predict(im[0]) + compare_result(ocr_result.boxes, ocr_result.cls_labels, + ocr_result.cls_scores, ocr_result.text, + ocr_result.rec_scores) + + +def test_ppocr_v3_1(): + ppocr_v3_no_cls.cls_batch_size = -1 + ppocr_v3_no_cls.rec_batch_size = -1 + ocr_result = ppocr_v3_no_cls.predict(im[0]) + compare_result_no_cls(ocr_result.boxes, ocr_result.text, + ocr_result.rec_scores) + + ppocr_v3_no_cls.cls_batch_size = 2 + ppocr_v3_no_cls.rec_batch_size = 2 + ocr_result = ppocr_v3_no_cls.predict(im[0]) + compare_result_no_cls(ocr_result.boxes, ocr_result.text, + ocr_result.rec_scores) + + +def test_ppocr_v3_2(): + det_input_tensors, batch_det_img_info = det_preprocessor.run(im) + det_output_tensors = det_runtime.infer({"x": det_input_tensors[0]}) + det_results = det_postprocessor.run(det_output_tensors, 
+                                        batch_det_img_info)
+
+    batch_boxes = []
+
+    batch_cls_labels = []
+    batch_cls_scores = []
+
+    batch_rec_texts = []
+    batch_rec_scores = []
+
+    for i_batch in range(len(det_results)):
+        cls_labels = []
+        cls_scores = []
+        rec_texts = []
+        rec_scores = []
+        box_list = fd.vision.ocr.sort_boxes(det_results[i_batch])
+        batch_boxes.append(box_list)
+        image_list = []
+        if len(box_list) == 0:
+            image_list.append(im[i_batch])
+        else:
+            for box in box_list:
+                crop_img = get_rotate_crop_image(im[i_batch], box)
+                image_list.append(crop_img)
+
+        cls_input_tensors = cls_preprocessor.run(image_list)
+        cls_output_tensors = cls_runtime.infer({"x": cls_input_tensors[0]})
+        cls_labels, cls_scores = cls_postprocessor.run(cls_output_tensors)
+
+        batch_cls_labels.append(cls_labels)
+        batch_cls_scores.append(cls_scores)
+
+        for index in range(len(image_list)):
+            if cls_labels[index] == 1 and cls_scores[
+                    index] > cls_postprocessor.cls_thresh:
+                image_list[index] = cv2.rotate(
+                    image_list[index].astype(np.float32), 1)
+                # Cast the rotated image back to uint8 before recognition
+                image_list[index] = image_list[index].astype(np.uint8)
+
+        rec_input_tensors = rec_preprocessor.run(image_list)
+        rec_output_tensors = rec_runtime.infer({"x": rec_input_tensors[0]})
+        rec_texts, rec_scores = rec_postprocessor.run(rec_output_tensors)
+
+        batch_rec_texts.append(rec_texts)
+        batch_rec_scores.append(rec_scores)
+
+        compare_result(box_list, cls_labels, cls_scores, rec_texts,
+                       rec_scores)
+
+
+if __name__ == "__main__":
+    print("test test_ppocr_v3")
+    test_ppocr_v3()
+    test_ppocr_v3()
+    print("test test_ppocr_v3_1")
+    test_ppocr_v3_1()
+    test_ppocr_v3_1()
+    print("test test_ppocr_v3_2")
+    test_ppocr_v3_2()
+    test_ppocr_v3_2()
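
For reference, a minimal usage sketch (Python) of the high-level pipeline with the cls_batch_size / rec_batch_size properties added in this change. It assumes the PP-OCRv3 model directories, the label file, and the sample image have already been downloaded; the paths below are placeholders, not part of the patch.

import cv2
import fastdeploy as fd

option = fd.RuntimeOption()

# Placeholder paths: point these at locally downloaded PP-OCRv3 models.
det_dir = "ch_PP-OCRv3_det_infer/"
cls_dir = "ch_ppocr_mobile_v2.0_cls_infer/"
rec_dir = "ch_PP-OCRv3_rec_infer/"

det_model = fd.vision.ocr.DBDetector(
    det_dir + "inference.pdmodel",
    det_dir + "inference.pdiparams",
    runtime_option=option)
cls_model = fd.vision.ocr.Classifier(
    cls_dir + "inference.pdmodel",
    cls_dir + "inference.pdiparams",
    runtime_option=option)
rec_model = fd.vision.ocr.Recognizer(
    rec_dir + "inference.pdmodel",
    rec_dir + "inference.pdiparams",
    "ppocr_keys_v1.txt",
    runtime_option=option)

ppocr_v3 = fd.vision.ocr.PPOCRv3(
    det_model=det_model, cls_model=cls_model, rec_model=rec_model)

# New knobs: how many cropped text boxes are classified / recognized per
# inference call (-1 runs all boxes in a single call).
ppocr_v3.cls_batch_size = 2
ppocr_v3.rec_batch_size = 2

result = ppocr_v3.predict(cv2.imread("12.jpg"))
print(result.boxes, result.text, result.rec_scores)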