From d09fbd861ec1d7820c29f3eadfdc8adeab2ecbcb Mon Sep 17 00:00:00 2001
From: yunyaoXYY
Date: Tue, 28 Feb 2023 07:51:22 +0000
Subject: [PATCH] Improve interface

---
 fastdeploy/vision/ocr/ppocr/classifier.cc |  47 ++-
 fastdeploy/vision/ocr/ppocr/classifier.h  |  21 +-
 fastdeploy/vision/ocr/ppocr/dbdetector.cc |  43 ++-
 fastdeploy/vision/ocr/ppocr/dbdetector.h  |  17 +
 .../vision/ocr/ppocr/ocrmodel_pybind.cc   | 332 +++++++++++-------
 fastdeploy/vision/ocr/ppocr/recognizer.cc |  51 ++-
 fastdeploy/vision/ocr/ppocr/recognizer.h  |  21 +-
 7 files changed, 359 insertions(+), 173 deletions(-)
 mode change 100755 => 100644 fastdeploy/vision/ocr/ppocr/classifier.cc
 mode change 100755 => 100644 fastdeploy/vision/ocr/ppocr/dbdetector.cc
 mode change 100755 => 100644 fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc
 mode change 100755 => 100644 fastdeploy/vision/ocr/ppocr/recognizer.cc

diff --git a/fastdeploy/vision/ocr/ppocr/classifier.cc b/fastdeploy/vision/ocr/ppocr/classifier.cc
old mode 100755
new mode 100644
index 55f355db2..7da751bc8
--- a/fastdeploy/vision/ocr/ppocr/classifier.cc
+++ b/fastdeploy/vision/ocr/ppocr/classifier.cc
@@ -26,11 +26,11 @@ Classifier::Classifier(const std::string& model_file,
                        const RuntimeOption& custom_option,
                        const ModelFormat& model_format) {
   if (model_format == ModelFormat::ONNX) {
-    valid_cpu_backends = {Backend::ORT,
-                          Backend::OPENVINO};
-    valid_gpu_backends = {Backend::ORT, Backend::TRT};
+    valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
+    valid_gpu_backends = {Backend::ORT, Backend::TRT};
   } else {
-    valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO, Backend::LITE};
+    valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO,
+                          Backend::LITE};
     valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
     valid_kunlunxin_backends = {Backend::LITE};
     valid_ascend_backends = {Backend::LITE};
@@ -54,16 +54,18 @@ bool Classifier::Initialize() {
 }
 
 std::unique_ptr<Classifier> Classifier::Clone() const {
-  std::unique_ptr<Classifier> clone_model = utils::make_unique<Classifier>(Classifier(*this));
+  std::unique_ptr<Classifier> clone_model =
+      utils::make_unique<Classifier>(Classifier(*this));
   clone_model->SetRuntime(clone_model->CloneRuntime());
   return clone_model;
 }
 
-bool Classifier::Predict(const cv::Mat& img, int32_t* cls_label, float* cls_score) {
+bool Classifier::Predict(const cv::Mat& img, int32_t* cls_label,
+                         float* cls_score) {
   std::vector<int32_t> cls_labels(1);
   std::vector<float> cls_scores(1);
   bool success = BatchPredict({img}, &cls_labels, &cls_scores);
-  if(!success){
+  if (!success) {
     return success;
   }
   *cls_label = cls_labels[0];
@@ -71,17 +73,36 @@ bool Classifier::Predict(const cv::Mat& img, int32_t* cls_label, float* cls_scor
   return true;
 }
 
+bool Classifier::Predict(const cv::Mat& img, vision::OCRResult* ocr_result) {
+  ocr_result->cls_labels.resize(1);
+  ocr_result->cls_scores.resize(1);
+  if (!Predict(img, &(ocr_result->cls_labels[0]),
+               &(ocr_result->cls_scores[0]))) {
+    return false;
+  }
+  return true;
+}
+
 bool Classifier::BatchPredict(const std::vector<cv::Mat>& images,
-                              std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores) {
+                              vision::OCRResult* ocr_result) {
+  return BatchPredict(images, &(ocr_result->cls_labels),
+                      &(ocr_result->cls_scores));
+}
+
+bool Classifier::BatchPredict(const std::vector<cv::Mat>& images,
+                              std::vector<int32_t>* cls_labels,
+                              std::vector<float>* cls_scores) {
   return BatchPredict(images, cls_labels, cls_scores, 0, images.size());
 }
 
 bool Classifier::BatchPredict(const std::vector<cv::Mat>& images,
-                              std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores,
+                              std::vector<int32_t>* cls_labels,
+                              std::vector<float>* cls_scores,
                               size_t start_index, size_t end_index) {
   size_t total_size = images.size();
   std::vector<FDMat> fd_images = WrapMat(images);
-  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, start_index, end_index)) {
+  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, start_index,
+                         end_index)) {
     FDERROR << "Failed to preprocess the input image." << std::endl;
     return false;
   }
@@ -91,8 +112,10 @@ bool Classifier::BatchPredict(const std::vector<cv::Mat>& images,
     return false;
   }
 
-  if (!postprocessor_.Run(reused_output_tensors_, cls_labels, cls_scores, start_index, total_size)) {
-    FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
+  if (!postprocessor_.Run(reused_output_tensors_, cls_labels, cls_scores,
+                          start_index, total_size)) {
+    FDERROR << "Failed to postprocess the inference cls_results by runtime."
+            << std::endl;
     return false;
   }
   return true;
diff --git a/fastdeploy/vision/ocr/ppocr/classifier.h b/fastdeploy/vision/ocr/ppocr/classifier.h
index 324da828b..f6d2bd526 100755
--- a/fastdeploy/vision/ocr/ppocr/classifier.h
+++ b/fastdeploy/vision/ocr/ppocr/classifier.h
@@ -42,7 +42,7 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel {
   Classifier(const std::string& model_file, const std::string& params_file = "",
              const RuntimeOption& custom_option = RuntimeOption(),
              const ModelFormat& model_format = ModelFormat::PADDLE);
-  
+
   /** \brief Clone a new Classifier with less memory usage when multiple instances of the same model are created
    *
    * \return new Classifier* type unique pointer
   */
@@ -61,7 +61,24 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel {
   */
   virtual bool Predict(const cv::Mat& img, int32_t* cls_label,
                        float* cls_score);
-  
+
+  /** \brief Predict the input image and get OCR classification model result.
+   *
+   * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
+   * \param[in] ocr_result The output of the OCR classification model will be written to this structure.
+   * \return true if the prediction is successful, otherwise false.
+   */
+  virtual bool Predict(const cv::Mat& img, vision::OCRResult* ocr_result);
+
+  /** \brief BatchPredict the input images and get OCR classification model results.
+   *
+   * \param[in] images The list of input image data; each element comes from cv::imread() and is a 3-D array with layout HWC, BGR format.
+   * \param[in] ocr_result The output of the OCR classification model will be written to this structure.
+   * \return true if the prediction is successful, otherwise false.
+   */
+  virtual bool BatchPredict(const std::vector<cv::Mat>& images,
+                            vision::OCRResult* ocr_result);
+
   /** \brief BatchPredict the input image and get OCR classification model cls_result.
    *
    * \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
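For reference, a minimal usage sketch of the OCRResult-based Classifier overload added above. This is not part of the patch; the model, params, and image paths are placeholders, and it assumes the umbrella header fastdeploy/vision.h and the cls_labels/cls_scores fields referenced by this patch.

// Sketch only: exercises the new Classifier::Predict(img, OCRResult*) overload.
#include <iostream>
#include <opencv2/opencv.hpp>
#include "fastdeploy/vision.h"

int main() {
  // Placeholder model files for a PP-OCR text direction classifier.
  fastdeploy::vision::ocr::Classifier cls("cls.pdmodel", "cls.pdiparams");
  cv::Mat img = cv::imread("text_crop.jpg");  // placeholder image path
  fastdeploy::vision::OCRResult result;
  if (!cls.Predict(img, &result)) {
    std::cerr << "Classifier::Predict failed." << std::endl;
    return -1;
  }
  // For a single input image, cls_labels and cls_scores each hold one entry.
  std::cout << "label: " << result.cls_labels[0]
            << " score: " << result.cls_scores[0] << std::endl;
  return 0;
}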
diff --git a/fastdeploy/vision/ocr/ppocr/dbdetector.cc b/fastdeploy/vision/ocr/ppocr/dbdetector.cc
old mode 100755
new mode 100644
index cd07cc262..0fb80fc44
--- a/fastdeploy/vision/ocr/ppocr/dbdetector.cc
+++ b/fastdeploy/vision/ocr/ppocr/dbdetector.cc
@@ -26,11 +26,11 @@ DBDetector::DBDetector(const std::string& model_file,
                        const RuntimeOption& custom_option,
                        const ModelFormat& model_format) {
   if (model_format == ModelFormat::ONNX) {
-    valid_cpu_backends = {Backend::ORT,
-                          Backend::OPENVINO};
-    valid_gpu_backends = {Backend::ORT, Backend::TRT};
+    valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
+    valid_gpu_backends = {Backend::ORT, Backend::TRT};
   } else {
-    valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO, Backend::LITE};
+    valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO,
+                          Backend::LITE};
     valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
     valid_kunlunxin_backends = {Backend::LITE};
     valid_ascend_backends = {Backend::LITE};
@@ -54,7 +54,8 @@ bool DBDetector::Initialize() {
 }
 
 std::unique_ptr<DBDetector> DBDetector::Clone() const {
-  std::unique_ptr<DBDetector> clone_model = utils::make_unique<DBDetector>(DBDetector(*this));
+  std::unique_ptr<DBDetector> clone_model =
+      utils::make_unique<DBDetector>(DBDetector(*this));
   clone_model->SetRuntime(clone_model->CloneRuntime());
   return clone_model;
 }
@@ -69,11 +70,33 @@ bool DBDetector::Predict(const cv::Mat& img,
   return true;
 }
 
+bool DBDetector::Predict(const cv::Mat& img, vision::OCRResult* ocr_result) {
+  if (!Predict(img, &(ocr_result->boxes))) {
+    return false;
+  }
+  return true;
+}
+
 bool DBDetector::BatchPredict(const std::vector<cv::Mat>& images,
-                              std::vector<std::vector<std::array<int, 8>>>* det_results) {
+                              std::vector<vision::OCRResult>* ocr_results) {
+  std::vector<std::vector<std::array<int, 8>>> det_results;
+  if (!BatchPredict(images, &det_results)) {
+    return false;
+  }
+  ocr_results->resize(det_results.size());
+  for (int i = 0; i < det_results.size(); i++) {
+    (*ocr_results)[i].boxes = std::move(det_results[i]);
+  }
+  return true;
+}
+
+bool DBDetector::BatchPredict(
+    const std::vector<cv::Mat>& images,
+    std::vector<std::vector<std::array<int, 8>>>* det_results) {
   std::vector<FDMat> fd_images = WrapMat(images);
   std::vector<std::array<int, 4>> batch_det_img_info;
-  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &batch_det_img_info)) {
+  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_,
+                         &batch_det_img_info)) {
     FDERROR << "Failed to preprocess input image." << std::endl;
     return false;
   }
@@ -84,8 +107,10 @@ bool DBDetector::BatchPredict(const std::vector<cv::Mat>& images,
     return false;
   }
 
-  if (!postprocessor_.Run(reused_output_tensors_, det_results, batch_det_img_info)) {
-    FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
+  if (!postprocessor_.Run(reused_output_tensors_, det_results,
+                          batch_det_img_info)) {
+    FDERROR << "Failed to postprocess the inference cls_results by runtime."
+            << std::endl;
     return false;
   }
   return true;
diff --git a/fastdeploy/vision/ocr/ppocr/dbdetector.h b/fastdeploy/vision/ocr/ppocr/dbdetector.h
index cab3a1d39..60c47016f 100755
--- a/fastdeploy/vision/ocr/ppocr/dbdetector.h
+++ b/fastdeploy/vision/ocr/ppocr/dbdetector.h
@@ -62,6 +62,14 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel {
   virtual bool Predict(const cv::Mat& img,
                        std::vector<std::array<int, 8>>* boxes_result);
 
+  /** \brief Predict the input image and get OCR detection model result.
+   *
+   * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
+   * \param[in] ocr_result The output of the OCR detection model will be written to this structure.
+   * \return true if the prediction is successful, otherwise false.
+   */
+  virtual bool Predict(const cv::Mat& img, vision::OCRResult* ocr_result);
+
   /** \brief BatchPredict the input image and get OCR detection model result.
    *
    * \param[in] images The list input of image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
@@ -71,6 +79,15 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel {
   virtual bool BatchPredict(const std::vector<cv::Mat>& images,
                             std::vector<std::vector<std::array<int, 8>>>* det_results);
 
+  /** \brief BatchPredict the input images and get OCR detection model results.
+   *
+   * \param[in] images The list of input image data; each element comes from cv::imread() and is a 3-D array with layout HWC, BGR format.
+   * \param[in] ocr_results The output of the OCR detection model will be written to this structure.
+   * \return true if the prediction is successful, otherwise false.
+   */
+  virtual bool BatchPredict(const std::vector<cv::Mat>& images,
+                            std::vector<vision::OCRResult>* ocr_results);
+
   /// Get preprocessor reference of DBDetectorPreprocessor
   virtual DBDetectorPreprocessor& GetPreprocessor() {
     return preprocessor_;
diff --git a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc
old mode 100755
new mode 100644
index 2bcb697a8..1d5609113
--- a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc
+++ b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc
@@ -17,18 +17,26 @@ namespace fastdeploy {
 void BindPPOCRModel(pybind11::module& m) {
   m.def("sort_boxes", [](std::vector<std::array<int, 8>>& boxes) {
-        vision::ocr::SortBoxes(&boxes);
-        return boxes;
+    vision::ocr::SortBoxes(&boxes);
+    return boxes;
   });
-  
+
   // DBDetector
-  pybind11::class_<vision::ocr::DBDetectorPreprocessor>(m, "DBDetectorPreprocessor")
+  pybind11::class_<vision::ocr::DBDetectorPreprocessor>(
+      m, "DBDetectorPreprocessor")
       .def(pybind11::init<>())
-      .def_property("max_side_len", &vision::ocr::DBDetectorPreprocessor::GetMaxSideLen, &vision::ocr::DBDetectorPreprocessor::SetMaxSideLen)
-      .def_property("mean", &vision::ocr::DBDetectorPreprocessor::GetMean, &vision::ocr::DBDetectorPreprocessor::SetMean)
-      .def_property("scale", &vision::ocr::DBDetectorPreprocessor::GetScale, &vision::ocr::DBDetectorPreprocessor::SetScale)
-      .def_property("is_scale", &vision::ocr::DBDetectorPreprocessor::GetIsScale, &vision::ocr::DBDetectorPreprocessor::SetIsScale)
-      .def("run", [](vision::ocr::DBDetectorPreprocessor& self, std::vector<pybind11::array>& im_list) {
+      .def_property("max_side_len",
+                    &vision::ocr::DBDetectorPreprocessor::GetMaxSideLen,
+                    &vision::ocr::DBDetectorPreprocessor::SetMaxSideLen)
+      .def_property("mean", &vision::ocr::DBDetectorPreprocessor::GetMean,
+                    &vision::ocr::DBDetectorPreprocessor::SetMean)
+      .def_property("scale", &vision::ocr::DBDetectorPreprocessor::GetScale,
+                    &vision::ocr::DBDetectorPreprocessor::SetScale)
+      .def_property("is_scale",
+                    &vision::ocr::DBDetectorPreprocessor::GetIsScale,
+                    &vision::ocr::DBDetectorPreprocessor::SetIsScale)
+      .def("run", [](vision::ocr::DBDetectorPreprocessor& self,
+                     std::vector<pybind11::array>& im_list) {
         std::vector<vision::FDMat> images;
         for (size_t i = 0; i < im_list.size(); ++i) {
          images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
@@ -36,99 +44,134 @@ void BindPPOCRModel(pybind11::module& m) {
        std::vector<FDTensor> outputs;
        std::vector<std::array<int, 4>> batch_det_img_info;
        self.Run(&images, &outputs, &batch_det_img_info);
-       for(size_t i = 0; i< outputs.size(); ++i){
+       for (size_t i = 0; i < outputs.size(); ++i) {
          outputs[i].StopSharing();
        }
        return std::make_pair(outputs, batch_det_img_info);
      });
 
-  pybind11::class_<vision::ocr::DBDetectorPostprocessor>(m, "DBDetectorPostprocessor")
+  pybind11::class_<vision::ocr::DBDetectorPostprocessor>(
+      m, "DBDetectorPostprocessor")
      .def(pybind11::init<>())
-
.def_property("det_db_thresh", &vision::ocr::DBDetectorPostprocessor::GetDetDBThresh, &vision::ocr::DBDetectorPostprocessor::SetDetDBThresh) - .def_property("det_db_box_thresh", &vision::ocr::DBDetectorPostprocessor::GetDetDBBoxThresh, &vision::ocr::DBDetectorPostprocessor::SetDetDBBoxThresh) - .def_property("det_db_unclip_ratio", &vision::ocr::DBDetectorPostprocessor::GetDetDBUnclipRatio, &vision::ocr::DBDetectorPostprocessor::SetDetDBUnclipRatio) - .def_property("det_db_score_mode", &vision::ocr::DBDetectorPostprocessor::GetDetDBScoreMode, &vision::ocr::DBDetectorPostprocessor::SetDetDBScoreMode) - .def_property("use_dilation", &vision::ocr::DBDetectorPostprocessor::GetUseDilation, &vision::ocr::DBDetectorPostprocessor::SetUseDilation) + .def_property("det_db_thresh", + &vision::ocr::DBDetectorPostprocessor::GetDetDBThresh, + &vision::ocr::DBDetectorPostprocessor::SetDetDBThresh) + .def_property("det_db_box_thresh", + &vision::ocr::DBDetectorPostprocessor::GetDetDBBoxThresh, + &vision::ocr::DBDetectorPostprocessor::SetDetDBBoxThresh) + .def_property("det_db_unclip_ratio", + &vision::ocr::DBDetectorPostprocessor::GetDetDBUnclipRatio, + &vision::ocr::DBDetectorPostprocessor::SetDetDBUnclipRatio) + .def_property("det_db_score_mode", + &vision::ocr::DBDetectorPostprocessor::GetDetDBScoreMode, + &vision::ocr::DBDetectorPostprocessor::SetDetDBScoreMode) + .def_property("use_dilation", + &vision::ocr::DBDetectorPostprocessor::GetUseDilation, + &vision::ocr::DBDetectorPostprocessor::SetUseDilation) - .def("run", [](vision::ocr::DBDetectorPostprocessor& self, - std::vector& inputs, - const std::vector>& batch_det_img_info) { - std::vector>> results; + .def("run", + [](vision::ocr::DBDetectorPostprocessor& self, + std::vector& inputs, + const std::vector>& batch_det_img_info) { + std::vector>> results; - if (!self.Run(inputs, &results, batch_det_img_info)) { - throw std::runtime_error("Failed to preprocess the input data in DBDetectorPostprocessor."); - } - return results; - }) - .def("run", [](vision::ocr::DBDetectorPostprocessor& self, - std::vector& input_array, - const std::vector>& batch_det_img_info) { - std::vector>> results; - std::vector inputs; - PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); - if (!self.Run(inputs, &results, batch_det_img_info)) { - throw std::runtime_error("Failed to preprocess the input data in DBDetectorPostprocessor."); - } - return results; - }); + if (!self.Run(inputs, &results, batch_det_img_info)) { + throw std::runtime_error( + "Failed to preprocess the input data in " + "DBDetectorPostprocessor."); + } + return results; + }) + .def("run", + [](vision::ocr::DBDetectorPostprocessor& self, + std::vector& input_array, + const std::vector>& batch_det_img_info) { + std::vector>> results; + std::vector inputs; + PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); + if (!self.Run(inputs, &results, batch_det_img_info)) { + throw std::runtime_error( + "Failed to preprocess the input data in " + "DBDetectorPostprocessor."); + } + return results; + }); pybind11::class_(m, "DBDetector") .def(pybind11::init()) .def(pybind11::init<>()) - .def_property_readonly("preprocessor", &vision::ocr::DBDetector::GetPreprocessor) - .def_property_readonly("postprocessor", &vision::ocr::DBDetector::GetPostprocessor) - .def("predict", [](vision::ocr::DBDetector& self, - pybind11::array& data) { - auto mat = PyArrayToCvMat(data); - std::vector> boxes_result; - self.Predict(mat, &boxes_result); - return boxes_result; - }) - .def("batch_predict", 
[](vision::ocr::DBDetector& self, std::vector& data) { + .def_property_readonly("preprocessor", + &vision::ocr::DBDetector::GetPreprocessor) + .def_property_readonly("postprocessor", + &vision::ocr::DBDetector::GetPostprocessor) + .def("predict", + [](vision::ocr::DBDetector& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + vision::OCRResult ocr_result; + self.Predict(mat, &ocr_result); + return ocr_result; + }) + .def("batch_predict", [](vision::ocr::DBDetector& self, + std::vector& data) { std::vector images; - std::vector>> det_results; for (size_t i = 0; i < data.size(); ++i) { images.push_back(PyArrayToCvMat(data[i])); } - self.BatchPredict(images, &det_results); - return det_results; + std::vector ocr_results; + self.BatchPredict(images, &ocr_results); + return ocr_results; }); // Classifier - pybind11::class_(m, "ClassifierPreprocessor") + pybind11::class_( + m, "ClassifierPreprocessor") .def(pybind11::init<>()) - .def_property("cls_image_shape", &vision::ocr::ClassifierPreprocessor::GetClsImageShape, &vision::ocr::ClassifierPreprocessor::SetClsImageShape) - .def_property("mean", &vision::ocr::ClassifierPreprocessor::GetMean, &vision::ocr::ClassifierPreprocessor::SetMean) - .def_property("scale", &vision::ocr::ClassifierPreprocessor::GetScale, &vision::ocr::ClassifierPreprocessor::SetScale) - .def_property("is_scale", &vision::ocr::ClassifierPreprocessor::GetIsScale, &vision::ocr::ClassifierPreprocessor::SetIsScale) - .def("run", [](vision::ocr::ClassifierPreprocessor& self, std::vector& im_list) { + .def_property("cls_image_shape", + &vision::ocr::ClassifierPreprocessor::GetClsImageShape, + &vision::ocr::ClassifierPreprocessor::SetClsImageShape) + .def_property("mean", &vision::ocr::ClassifierPreprocessor::GetMean, + &vision::ocr::ClassifierPreprocessor::SetMean) + .def_property("scale", &vision::ocr::ClassifierPreprocessor::GetScale, + &vision::ocr::ClassifierPreprocessor::SetScale) + .def_property("is_scale", + &vision::ocr::ClassifierPreprocessor::GetIsScale, + &vision::ocr::ClassifierPreprocessor::SetIsScale) + .def("run", [](vision::ocr::ClassifierPreprocessor& self, + std::vector& im_list) { std::vector images; for (size_t i = 0; i < im_list.size(); ++i) { images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); } std::vector outputs; if (!self.Run(&images, &outputs)) { - throw std::runtime_error("Failed to preprocess the input data in ClassifierPreprocessor."); + throw std::runtime_error( + "Failed to preprocess the input data in ClassifierPreprocessor."); } - for(size_t i = 0; i< outputs.size(); ++i){ + for (size_t i = 0; i < outputs.size(); ++i) { outputs[i].StopSharing(); } return outputs; }); - pybind11::class_(m, "ClassifierPostprocessor") + pybind11::class_( + m, "ClassifierPostprocessor") .def(pybind11::init<>()) - .def_property("cls_thresh", &vision::ocr::ClassifierPostprocessor::GetClsThresh, &vision::ocr::ClassifierPostprocessor::SetClsThresh) - .def("run", [](vision::ocr::ClassifierPostprocessor& self, - std::vector& inputs) { - std::vector cls_labels; - std::vector cls_scores; - if (!self.Run(inputs, &cls_labels, &cls_scores)) { - throw std::runtime_error("Failed to preprocess the input data in ClassifierPostprocessor."); - } - return std::make_pair(cls_labels,cls_scores); - }) + .def_property("cls_thresh", + &vision::ocr::ClassifierPostprocessor::GetClsThresh, + &vision::ocr::ClassifierPostprocessor::SetClsThresh) + .def("run", + [](vision::ocr::ClassifierPostprocessor& self, + std::vector& inputs) { + std::vector cls_labels; + 
std::vector cls_scores; + if (!self.Run(inputs, &cls_labels, &cls_scores)) { + throw std::runtime_error( + "Failed to preprocess the input data in " + "ClassifierPostprocessor."); + } + return std::make_pair(cls_labels, cls_scores); + }) .def("run", [](vision::ocr::ClassifierPostprocessor& self, std::vector& input_array) { std::vector inputs; @@ -136,70 +179,88 @@ void BindPPOCRModel(pybind11::module& m) { std::vector cls_labels; std::vector cls_scores; if (!self.Run(inputs, &cls_labels, &cls_scores)) { - throw std::runtime_error("Failed to preprocess the input data in ClassifierPostprocessor."); + throw std::runtime_error( + "Failed to preprocess the input data in " + "ClassifierPostprocessor."); } - return std::make_pair(cls_labels,cls_scores); + return std::make_pair(cls_labels, cls_scores); }); - + pybind11::class_(m, "Classifier") .def(pybind11::init()) .def(pybind11::init<>()) - .def_property_readonly("preprocessor", &vision::ocr::Classifier::GetPreprocessor) - .def_property_readonly("postprocessor", &vision::ocr::Classifier::GetPostprocessor) - .def("predict", [](vision::ocr::Classifier& self, - pybind11::array& data) { - auto mat = PyArrayToCvMat(data); - int32_t cls_label; - float cls_score; - self.Predict(mat, &cls_label, &cls_score); - return std::make_pair(cls_label, cls_score); - }) - .def("batch_predict", [](vision::ocr::Classifier& self, std::vector& data) { + .def_property_readonly("preprocessor", + &vision::ocr::Classifier::GetPreprocessor) + .def_property_readonly("postprocessor", + &vision::ocr::Classifier::GetPostprocessor) + .def("predict", + [](vision::ocr::Classifier& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + vision::OCRResult ocr_result; + self.Predict(mat, &ocr_result); + return ocr_result; + }) + .def("batch_predict", [](vision::ocr::Classifier& self, + std::vector& data) { std::vector images; - std::vector cls_labels; - std::vector cls_scores; for (size_t i = 0; i < data.size(); ++i) { images.push_back(PyArrayToCvMat(data[i])); } - self.BatchPredict(images, &cls_labels, &cls_scores); - return std::make_pair(cls_labels, cls_scores); + vision::OCRResult ocr_result; + self.BatchPredict(images, &ocr_result); + return ocr_result; }); // Recognizer - pybind11::class_(m, "RecognizerPreprocessor") - .def(pybind11::init<>()) - .def_property("static_shape_infer", &vision::ocr::RecognizerPreprocessor::GetStaticShapeInfer, &vision::ocr::RecognizerPreprocessor::SetStaticShapeInfer) - .def_property("rec_image_shape", &vision::ocr::RecognizerPreprocessor::GetRecImageShape, &vision::ocr::RecognizerPreprocessor::SetRecImageShape) - .def_property("mean", &vision::ocr::RecognizerPreprocessor::GetMean, &vision::ocr::RecognizerPreprocessor::SetMean) - .def_property("scale", &vision::ocr::RecognizerPreprocessor::GetScale, &vision::ocr::RecognizerPreprocessor::SetScale) - .def_property("is_scale", &vision::ocr::RecognizerPreprocessor::GetIsScale, &vision::ocr::RecognizerPreprocessor::SetIsScale) - .def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector& im_list) { - std::vector images; - for (size_t i = 0; i < im_list.size(); ++i) { - images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); - } - std::vector outputs; - if (!self.Run(&images, &outputs)) { - throw std::runtime_error("Failed to preprocess the input data in RecognizerPreprocessor."); - } - for(size_t i = 0; i< outputs.size(); ++i){ - outputs[i].StopSharing(); - } - return outputs; - }); - - pybind11::class_(m, "RecognizerPostprocessor") - .def(pybind11::init()) - .def("run", 
[](vision::ocr::RecognizerPostprocessor& self, - std::vector& inputs) { - std::vector texts; - std::vector rec_scores; - if (!self.Run(inputs, &texts, &rec_scores)) { - throw std::runtime_error("Failed to preprocess the input data in RecognizerPostprocessor."); + pybind11::class_( + m, "RecognizerPreprocessor") + .def(pybind11::init<>()) + .def_property("static_shape_infer", + &vision::ocr::RecognizerPreprocessor::GetStaticShapeInfer, + &vision::ocr::RecognizerPreprocessor::SetStaticShapeInfer) + .def_property("rec_image_shape", + &vision::ocr::RecognizerPreprocessor::GetRecImageShape, + &vision::ocr::RecognizerPreprocessor::SetRecImageShape) + .def_property("mean", &vision::ocr::RecognizerPreprocessor::GetMean, + &vision::ocr::RecognizerPreprocessor::SetMean) + .def_property("scale", &vision::ocr::RecognizerPreprocessor::GetScale, + &vision::ocr::RecognizerPreprocessor::SetScale) + .def_property("is_scale", + &vision::ocr::RecognizerPreprocessor::GetIsScale, + &vision::ocr::RecognizerPreprocessor::SetIsScale) + .def("run", [](vision::ocr::RecognizerPreprocessor& self, + std::vector& im_list) { + std::vector images; + for (size_t i = 0; i < im_list.size(); ++i) { + images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); } - return std::make_pair(texts, rec_scores); - }) + std::vector outputs; + if (!self.Run(&images, &outputs)) { + throw std::runtime_error( + "Failed to preprocess the input data in RecognizerPreprocessor."); + } + for (size_t i = 0; i < outputs.size(); ++i) { + outputs[i].StopSharing(); + } + return outputs; + }); + + pybind11::class_( + m, "RecognizerPostprocessor") + .def(pybind11::init()) + .def("run", + [](vision::ocr::RecognizerPostprocessor& self, + std::vector& inputs) { + std::vector texts; + std::vector rec_scores; + if (!self.Run(inputs, &texts, &rec_scores)) { + throw std::runtime_error( + "Failed to preprocess the input data in " + "RecognizerPostprocessor."); + } + return std::make_pair(texts, rec_scores); + }) .def("run", [](vision::ocr::RecognizerPostprocessor& self, std::vector& input_array) { std::vector inputs; @@ -207,7 +268,9 @@ void BindPPOCRModel(pybind11::module& m) { std::vector texts; std::vector rec_scores; if (!self.Run(inputs, &texts, &rec_scores)) { - throw std::runtime_error("Failed to preprocess the input data in RecognizerPostprocessor."); + throw std::runtime_error( + "Failed to preprocess the input data in " + "RecognizerPostprocessor."); } return std::make_pair(texts, rec_scores); }); @@ -216,25 +279,26 @@ void BindPPOCRModel(pybind11::module& m) { .def(pybind11::init()) .def(pybind11::init<>()) - .def_property_readonly("preprocessor", &vision::ocr::Recognizer::GetPreprocessor) - .def_property_readonly("postprocessor", &vision::ocr::Recognizer::GetPostprocessor) - .def("predict", [](vision::ocr::Recognizer& self, - pybind11::array& data) { - auto mat = PyArrayToCvMat(data); - std::string text; - float rec_score; - self.Predict(mat, &text, &rec_score); - return std::make_pair(text, rec_score); - }) - .def("batch_predict", [](vision::ocr::Recognizer& self, std::vector& data) { + .def_property_readonly("preprocessor", + &vision::ocr::Recognizer::GetPreprocessor) + .def_property_readonly("postprocessor", + &vision::ocr::Recognizer::GetPostprocessor) + .def("predict", + [](vision::ocr::Recognizer& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + vision::OCRResult ocr_result; + self.Predict(mat, &ocr_result); + return ocr_result; + }) + .def("batch_predict", [](vision::ocr::Recognizer& self, + std::vector& data) { 
          std::vector<cv::Mat> images;
-         std::vector<std::string> texts;
-         std::vector<float> rec_scores;
          for (size_t i = 0; i < data.size(); ++i) {
            images.push_back(PyArrayToCvMat(data[i]));
          }
-         self.BatchPredict(images, &texts, &rec_scores);
-         return std::make_pair(texts, rec_scores);
+         vision::OCRResult ocr_result;
+         self.BatchPredict(images, &ocr_result);
+         return ocr_result;
       });
 }
 }  // namespace fastdeploy
diff --git a/fastdeploy/vision/ocr/ppocr/recognizer.cc b/fastdeploy/vision/ocr/ppocr/recognizer.cc
old mode 100755
new mode 100644
index 69e75b281..fe8045b3a
--- a/fastdeploy/vision/ocr/ppocr/recognizer.cc
+++ b/fastdeploy/vision/ocr/ppocr/recognizer.cc
@@ -26,16 +26,17 @@ Recognizer::Recognizer(const std::string& model_file,
                        const std::string& params_file,
                        const std::string& label_path,
                        const RuntimeOption& custom_option,
-                       const ModelFormat& model_format):postprocessor_(label_path) {
+                       const ModelFormat& model_format)
+    : postprocessor_(label_path) {
   if (model_format == ModelFormat::ONNX) {
-    valid_cpu_backends = {Backend::ORT,
-                          Backend::OPENVINO};
-    valid_gpu_backends = {Backend::ORT, Backend::TRT};
+    valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
+    valid_gpu_backends = {Backend::ORT, Backend::TRT};
   } else {
-    valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO, Backend::LITE};
+    valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO,
+                          Backend::LITE};
     valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
     valid_kunlunxin_backends = {Backend::LITE};
-    valid_ascend_backends = {Backend::LITE}; 
+    valid_ascend_backends = {Backend::LITE};
     valid_sophgonpu_backends = {Backend::SOPHGOTPU};
   }
 
@@ -57,12 +58,14 @@ bool Recognizer::Initialize() {
 }
 
 std::unique_ptr<Recognizer> Recognizer::Clone() const {
-  std::unique_ptr<Recognizer> clone_model = utils::make_unique<Recognizer>(Recognizer(*this));
+  std::unique_ptr<Recognizer> clone_model =
+      utils::make_unique<Recognizer>(Recognizer(*this));
   clone_model->SetRuntime(clone_model->CloneRuntime());
   return clone_model;
 }
 
-bool Recognizer::Predict(const cv::Mat& img, std::string* text, float* rec_score) {
+bool Recognizer::Predict(const cv::Mat& img, std::string* text,
+                         float* rec_score) {
   std::vector<std::string> texts(1);
   std::vector<float> rec_scores(1);
   bool success = BatchPredict({img}, &texts, &rec_scores);
@@ -74,21 +77,39 @@ bool Recognizer::Predict(const cv::Mat& img, std::string* text, float* rec_score
   return true;
 }
 
+bool Recognizer::Predict(const cv::Mat& img, vision::OCRResult* ocr_result) {
+  ocr_result->text.resize(1);
+  ocr_result->rec_scores.resize(1);
+  if (!Predict(img, &(ocr_result->text[0]), &(ocr_result->rec_scores[0]))) {
+    return false;
+  }
+  return true;
+}
+
 bool Recognizer::BatchPredict(const std::vector<cv::Mat>& images,
-                              std::vector<std::string>* texts, std::vector<float>* rec_scores) {
+                              std::vector<std::string>* texts,
+                              std::vector<float>* rec_scores) {
   return BatchPredict(images, texts, rec_scores, 0, images.size(), {});
 }
 
 bool Recognizer::BatchPredict(const std::vector<cv::Mat>& images,
-                              std::vector<std::string>* texts, std::vector<float>* rec_scores,
-                              size_t start_index, size_t end_index, const std::vector<int>& indices) {
+                              vision::OCRResult* ocr_result) {
+  return BatchPredict(images, &(ocr_result->text), &(ocr_result->rec_scores));
+}
+
+bool Recognizer::BatchPredict(const std::vector<cv::Mat>& images,
+                              std::vector<std::string>* texts,
+                              std::vector<float>* rec_scores,
+                              size_t start_index, size_t end_index,
+                              const std::vector<int>& indices) {
   size_t total_size = images.size();
   if (indices.size() != 0 && indices.size() != total_size) {
     FDERROR << "indices.size() should be 0 or images.size()."
            << std::endl;
    return false;
  }
  std::vector<FDMat> fd_images = WrapMat(images);
- if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, start_index, end_index, indices)) {
+ if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, start_index,
+                        end_index, indices)) {
    FDERROR << "Failed to preprocess the input image." << std::endl;
    return false;
  }
@@ -99,8 +120,10 @@ bool Recognizer::BatchPredict(const std::vector<cv::Mat>& images,
    return false;
  }
 
- if (!postprocessor_.Run(reused_output_tensors_, texts, rec_scores, start_index, total_size, indices)) {
-   FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
+ if (!postprocessor_.Run(reused_output_tensors_, texts, rec_scores,
+                         start_index, total_size, indices)) {
+   FDERROR << "Failed to postprocess the inference cls_results by runtime."
+           << std::endl;
    return false;
  }
  return true;
diff --git a/fastdeploy/vision/ocr/ppocr/recognizer.h b/fastdeploy/vision/ocr/ppocr/recognizer.h
index 072c19129..5cafb6852 100755
--- a/fastdeploy/vision/ocr/ppocr/recognizer.h
+++ b/fastdeploy/vision/ocr/ppocr/recognizer.h
@@ -44,7 +44,7 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
             const std::string& label_path = "",
             const RuntimeOption& custom_option = RuntimeOption(),
             const ModelFormat& model_format = ModelFormat::PADDLE);
-  
+
  /// Get model's name
  std::string ModelName() const { return "ppocr/ocr_rec"; }
 
@@ -63,6 +63,23 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
   */
  virtual bool Predict(const cv::Mat& img, std::string* text, float* rec_score);
 
+ /** \brief Predict the input image and get OCR recognition model result.
+  *
+  * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
+  * \param[in] ocr_result The output of the OCR recognition model will be written to this structure.
+  * \return true if the prediction is successful, otherwise false.
+  */
+ virtual bool Predict(const cv::Mat& img, vision::OCRResult* ocr_result);
+
+ /** \brief BatchPredict the input images and get OCR recognition model results.
+  *
+  * \param[in] images The list of input image data; each element comes from cv::imread() and is a 3-D array with layout HWC, BGR format.
+  * \param[in] ocr_result The output of the OCR recognition model will be written to this structure.
+  * \return true if the prediction is successful, otherwise false.
+  */
+ virtual bool BatchPredict(const std::vector<cv::Mat>& images,
+                           vision::OCRResult* ocr_result);
+
  /** \brief BatchPredict the input image and get OCR recognition model result.
   *
   * \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
@@ -72,7 +89,7 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
  virtual bool BatchPredict(const std::vector<cv::Mat>& images,
                            std::vector<std::string>* texts, std::vector<float>* rec_scores);
-  
+
  virtual bool BatchPredict(const std::vector<cv::Mat>& images,
                            std::vector<std::string>* texts, std::vector<float>* rec_scores,
                            size_t start_index, size_t end_index,
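A matching usage sketch for the OCRResult-based Recognizer overload added above. This is not part of the patch; the model, params, label-file, and image paths are placeholders, and it assumes fastdeploy/vision.h plus the text/rec_scores fields referenced by this patch.

// Sketch only: exercises the new Recognizer::Predict(img, OCRResult*) overload.
#include <iostream>
#include <opencv2/opencv.hpp>
#include "fastdeploy/vision.h"

int main() {
  // Placeholder model files and character dictionary for a PP-OCR recognizer.
  fastdeploy::vision::ocr::Recognizer rec("rec.pdmodel", "rec.pdiparams",
                                          "ppocr_keys_v1.txt");
  cv::Mat crop = cv::imread("word_crop.jpg");  // placeholder image path
  fastdeploy::vision::OCRResult result;
  if (!rec.Predict(crop, &result)) {
    std::cerr << "Recognizer::Predict failed." << std::endl;
    return -1;
  }
  // For a single input image, text and rec_scores each hold one entry.
  std::cout << "text: " << result.text[0]
            << " score: " << result.rec_scores[0] << std::endl;
  return 0;
}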