From b6903b0aa42b4c74f727f1cca04cf79a3b80c02d Mon Sep 17 00:00:00 2001 From: yunyaoXYY Date: Wed, 28 Dec 2022 12:56:53 +0000 Subject: [PATCH] Add PPOCR example for ascend deploy --- examples/vision/ocr/PP-OCRv2/cpp/README.md | 2 + examples/vision/ocr/PP-OCRv2/cpp/infer.cc | 13 ++++- examples/vision/ocr/PP-OCRv2/python/README.md | 2 + examples/vision/ocr/PP-OCRv2/python/infer.py | 10 ++++ examples/vision/ocr/PP-OCRv3/cpp/README.md | 2 + examples/vision/ocr/PP-OCRv3/cpp/infer.cc | 15 ++++-- examples/vision/ocr/PP-OCRv3/python/README.md | 2 + examples/vision/ocr/PP-OCRv3/python/infer.py | 10 ++++ fastdeploy/vision/ocr/ppocr/classifier.h | 13 ++++- .../vision/ocr/ppocr/cls_postprocessor.h | 6 +++ .../vision/ocr/ppocr/cls_preprocessor.h | 21 ++++++++ fastdeploy/vision/ocr/ppocr/dbdetector.h | 13 ++++- .../vision/ocr/ppocr/det_postprocessor.h | 28 ++++++++++ .../vision/ocr/ppocr/det_preprocessor.h | 20 ++++++++ .../vision/ocr/ppocr/ocrmodel_pybind.cc | 51 ++++++++++--------- fastdeploy/vision/ocr/ppocr/ppocr_v2.cc | 6 +-- fastdeploy/vision/ocr/ppocr/ppocr_v3.h | 4 +- .../vision/ocr/ppocr/rec_preprocessor.cc | 23 ++------- .../vision/ocr/ppocr/rec_preprocessor.h | 31 ++++++++++- fastdeploy/vision/ocr/ppocr/recognizer.h | 13 ++++- .../fastdeploy/vision/ocr/ppocr/__init__.py | 24 ++++----- 21 files changed, 235 insertions(+), 74 deletions(-) diff --git a/examples/vision/ocr/PP-OCRv2/cpp/README.md b/examples/vision/ocr/PP-OCRv2/cpp/README.md index fbde53fff..e30d886d1 100755 --- a/examples/vision/ocr/PP-OCRv2/cpp/README.md +++ b/examples/vision/ocr/PP-OCRv2/cpp/README.md @@ -43,6 +43,8 @@ wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_ ./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 3 # 昆仑芯XPU推理 ./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 4 +# 华为昇腾推理 +./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 5 ``` 以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考: diff --git a/examples/vision/ocr/PP-OCRv2/cpp/infer.cc b/examples/vision/ocr/PP-OCRv2/cpp/infer.cc index 3406246aa..0248367cc 100755 --- a/examples/vision/ocr/PP-OCRv2/cpp/infer.cc +++ b/examples/vision/ocr/PP-OCRv2/cpp/infer.cc @@ -55,6 +55,10 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model auto cls_model = fastdeploy::vision::ocr::Classifier(cls_model_file, cls_params_file, cls_option); auto rec_model = fastdeploy::vision::ocr::Recognizer(rec_model_file, rec_params_file, rec_label_file, rec_option); + // Users could enable static shape infer for rec model when deploy PP-OCR on hardware + // which can not support dynamic shape infer well, like Huawei Ascend series. + // rec_model.GetPreprocessor().SetStaticShapeInfer(true); + assert(det_model.Initialized()); assert(cls_model.Initialized()); assert(rec_model.Initialized()); @@ -65,9 +69,12 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model // Set inference batch size for cls model and rec model, the value could be -1 and 1 to positive infinity. // When inference batch size is set to -1, it means that the inference batch size - // of the cls and rec models will be the same as the number of boxes detected by the det model. + // of the cls and rec models will be the same as the number of boxes detected by the det model. + // When users enable static shape infer for rec model, the batch size of cls and rec model needs to be set to 1. + // ppocr_v2.SetClsBatchSize(1); + // ppocr_v2.SetRecBatchSize(1); ppocr_v2.SetClsBatchSize(cls_batch_size); - ppocr_v2.SetRecBatchSize(rec_batch_size); + ppocr_v2.SetRecBatchSize(rec_batch_size); if(!ppocr_v2.Initialized()){ std::cerr << "Failed to initialize PP-OCR." << std::endl; @@ -122,6 +129,8 @@ int main(int argc, char* argv[]) { option.EnablePaddleToTrt(); } else if (flag == 4) { option.UseKunlunXin(); + } else if (flag == 5) { + option.UseAscend(); } std::string det_model_dir = argv[1]; diff --git a/examples/vision/ocr/PP-OCRv2/python/README.md b/examples/vision/ocr/PP-OCRv2/python/README.md index 66bba9e5b..270225ab7 100755 --- a/examples/vision/ocr/PP-OCRv2/python/README.md +++ b/examples/vision/ocr/PP-OCRv2/python/README.md @@ -36,6 +36,8 @@ python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2 python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu --backend trt # 昆仑芯XPU推理 python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device kunlunxin +# 华为昇腾推理 +python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device ascend ``` 运行完成可视化结果如下图所示 diff --git a/examples/vision/ocr/PP-OCRv2/python/infer.py b/examples/vision/ocr/PP-OCRv2/python/infer.py index b8c731ef3..f7373b4c2 100755 --- a/examples/vision/ocr/PP-OCRv2/python/infer.py +++ b/examples/vision/ocr/PP-OCRv2/python/infer.py @@ -72,6 +72,10 @@ def build_option(args): option.use_kunlunxin() return option + if args.device.lower() == "ascend": + option.use_ascend() + return option + if args.backend.lower() == "trt": assert args.device.lower( ) == "gpu", "TensorRT backend require inference on device GPU." @@ -112,6 +116,8 @@ runtime_option = build_option(args) # PPOCR的cls和rec模型现在已经支持推理一个Batch的数据 # 定义下面两个变量后, 可用于设置trt输入shape, 并在PPOCR模型初始化后, 完成Batch推理设置 +# 当用户要把PP-OCR部署在对动态shape推理支持有限的设备上时,(例如华为昇腾) +# 需要把cls_batch_size和rec_batch_size都设置为1. cls_batch_size = 1 rec_batch_size = 6 @@ -144,6 +150,10 @@ rec_option.set_trt_input_shape("x", [1, 3, 32, 10], rec_model = fd.vision.ocr.Recognizer( rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option) +# 当用户要把PP-OCR部署在对动态shape推理支持有限的设备上时,(例如华为昇腾) +# 需要使用下行代码, 来启用rec模型的静态shape推理. +# rec_model.preprocessor.static_shape_infer = True + # 创建PP-OCR,串联3个模型,其中cls_model可选,如无需求,可设置为None ppocr_v2 = fd.vision.ocr.PPOCRv2( det_model=det_model, cls_model=cls_model, rec_model=rec_model) diff --git a/examples/vision/ocr/PP-OCRv3/cpp/README.md b/examples/vision/ocr/PP-OCRv3/cpp/README.md index 9c5eff4ef..6f48a69ac 100755 --- a/examples/vision/ocr/PP-OCRv3/cpp/README.md +++ b/examples/vision/ocr/PP-OCRv3/cpp/README.md @@ -43,6 +43,8 @@ wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_ ./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 3 # 昆仑芯XPU推理 ./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 4 +# 华为昇腾推理, 请用户在代码里正确开启Rec模型的静态shape推理,并设置分类模型和识别模型的推理batch size为1. +./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 5 ``` 以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考: diff --git a/examples/vision/ocr/PP-OCRv3/cpp/infer.cc b/examples/vision/ocr/PP-OCRv3/cpp/infer.cc index fd25eca7e..7fbcf835e 100755 --- a/examples/vision/ocr/PP-OCRv3/cpp/infer.cc +++ b/examples/vision/ocr/PP-OCRv3/cpp/infer.cc @@ -34,7 +34,7 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model auto rec_option = option; // The cls and rec model can inference a batch of images now. - // User could initialize the inference batch size and set them after create PPOCR model. + // User could initialize the inference batch size and set them after create PP-OCR model. int cls_batch_size = 1; int rec_batch_size = 6; @@ -56,6 +56,10 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model auto cls_model = fastdeploy::vision::ocr::Classifier(cls_model_file, cls_params_file, cls_option); auto rec_model = fastdeploy::vision::ocr::Recognizer(rec_model_file, rec_params_file, rec_label_file, rec_option); + // Users could enable static shape infer for rec model when deploy PP-OCR on hardware + // which can not support dynamic shape infer well, like Huawei Ascend series. + // rec_model.GetPreprocessor().SetStaticShapeInfer(true); + assert(det_model.Initialized()); assert(cls_model.Initialized()); assert(rec_model.Initialized()); @@ -66,9 +70,12 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model // Set inference batch size for cls model and rec model, the value could be -1 and 1 to positive infinity. // When inference batch size is set to -1, it means that the inference batch size - // of the cls and rec models will be the same as the number of boxes detected by the det model. + // of the cls and rec models will be the same as the number of boxes detected by the det model. + // When users enable static shape infer for rec model, the batch size of cls and rec model needs to be set to 1. + // ppocr_v3.SetClsBatchSize(1); + // ppocr_v3.SetRecBatchSize(1); ppocr_v3.SetClsBatchSize(cls_batch_size); - ppocr_v3.SetRecBatchSize(rec_batch_size); + ppocr_v3.SetRecBatchSize(rec_batch_size); if(!ppocr_v3.Initialized()){ std::cerr << "Failed to initialize PP-OCR." << std::endl; @@ -123,6 +130,8 @@ int main(int argc, char* argv[]) { option.EnablePaddleToTrt(); } else if (flag == 4) { option.UseKunlunXin(); + } else if (flag == 5) { + option.UseAscend(); } std::string det_model_dir = argv[1]; diff --git a/examples/vision/ocr/PP-OCRv3/python/README.md b/examples/vision/ocr/PP-OCRv3/python/README.md index e87729353..dd5965d33 100755 --- a/examples/vision/ocr/PP-OCRv3/python/README.md +++ b/examples/vision/ocr/PP-OCRv3/python/README.md @@ -35,6 +35,8 @@ python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2 python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu --backend trt # 昆仑芯XPU推理 python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device kunlunxin +# 华为昇腾推理 +python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device ascend ``` 运行完成可视化结果如下图所示 diff --git a/examples/vision/ocr/PP-OCRv3/python/infer.py b/examples/vision/ocr/PP-OCRv3/python/infer.py index 97ee1d070..f6da98bdb 100755 --- a/examples/vision/ocr/PP-OCRv3/python/infer.py +++ b/examples/vision/ocr/PP-OCRv3/python/infer.py @@ -72,6 +72,10 @@ def build_option(args): option.use_kunlunxin() return option + if args.device.lower() == "ascend": + option.use_ascend() + return option + if args.backend.lower() == "trt": assert args.device.lower( ) == "gpu", "TensorRT backend require inference on device GPU." @@ -112,6 +116,8 @@ runtime_option = build_option(args) # PPOCR的cls和rec模型现在已经支持推理一个Batch的数据 # 定义下面两个变量后, 可用于设置trt输入shape, 并在PPOCR模型初始化后, 完成Batch推理设置 +# 当用户要把PP-OCR部署在对动态shape推理支持有限的设备上时,(例如华为昇腾) +# 需要把cls_batch_size和rec_batch_size都设置为1. cls_batch_size = 1 rec_batch_size = 6 @@ -144,6 +150,10 @@ rec_option.set_trt_input_shape("x", [1, 3, 48, 10], rec_model = fd.vision.ocr.Recognizer( rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option) +# 当用户要把PP-OCR部署在对动态shape推理支持有限的设备上时,(例如华为昇腾) +# 需要使用下行代码, 来启用rec模型的静态shape推理. +# rec_model.preprocessor.static_shape_infer = True + # 创建PP-OCR,串联3个模型,其中cls_model可选,如无需求,可设置为None ppocr_v3 = fd.vision.ocr.PPOCRv3( det_model=det_model, cls_model=cls_model, rec_model=rec_model) diff --git a/fastdeploy/vision/ocr/ppocr/classifier.h b/fastdeploy/vision/ocr/ppocr/classifier.h index cd035e269..824d9c3be 100755 --- a/fastdeploy/vision/ocr/ppocr/classifier.h +++ b/fastdeploy/vision/ocr/ppocr/classifier.h @@ -68,11 +68,20 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel { std::vector* cls_scores, size_t start_index, size_t end_index); - ClassifierPreprocessor preprocessor_; - ClassifierPostprocessor postprocessor_; + /// Get preprocessor reference of ClassifierPreprocessor + virtual ClassifierPreprocessor& GetPreprocessor() { + return preprocessor_; + } + + /// Get postprocessor reference of ClassifierPostprocessor + virtual ClassifierPostprocessor& GetPostprocessor() { + return postprocessor_; + } private: bool Initialize(); + ClassifierPreprocessor preprocessor_; + ClassifierPostprocessor postprocessor_; }; } // namespace ocr diff --git a/fastdeploy/vision/ocr/ppocr/cls_postprocessor.h b/fastdeploy/vision/ocr/ppocr/cls_postprocessor.h index d9702e1a1..e596db71d 100644 --- a/fastdeploy/vision/ocr/ppocr/cls_postprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/cls_postprocessor.h @@ -39,6 +39,12 @@ class FASTDEPLOY_DECL ClassifierPostprocessor { std::vector* cls_labels, std::vector* cls_scores, size_t start_index, size_t total_size); + /// Set threshold for the classification postprocess, default is 0.9 + void SetClsThresh(float cls_thresh) { cls_thresh_ = cls_thresh; } + + /// Get threshold value of the classification postprocess. + float GetClsThresh() const { return cls_thresh_; } + float cls_thresh_ = 0.9; }; diff --git a/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h b/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h index 8c1c81611..8d42f3d31 100644 --- a/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h @@ -34,6 +34,27 @@ class FASTDEPLOY_DECL ClassifierPreprocessor { bool Run(std::vector* images, std::vector* outputs, size_t start_index, size_t end_index); + /// Set mean value for the image normalization in classification preprocess + void SetMean(std::vector mean) { mean_ = mean; } + /// Get mean value of the image normalization in classification preprocess + std::vector GetMean() const { return mean_; } + + /// Set scale value for the image normalization in classification preprocess + void SetScale(std::vector scale) { scale_ = scale; } + /// Get scale value of the image normalization in classification preprocess + std::vector GetScale() const { return scale_; } + + /// Set is_scale for the image normalization in classification preprocess + void SetIsScale(bool is_scale) { is_scale_ = is_scale; } + /// Get is_scale of the image normalization in classification preprocess + bool GetIsScale() const { return is_scale_; } + + /// Set cls_image_shape for the classification preprocess + void SetClsImageShape(std::vector cls_image_shape) + { cls_image_shape_ = cls_image_shape; } + /// Get cls_image_shape for the classification preprocess + std::vector GetClsImageShape() const { return cls_image_shape_; } + std::vector mean_ = {0.5f, 0.5f, 0.5f}; std::vector scale_ = {0.5f, 0.5f, 0.5f}; bool is_scale_ = true; diff --git a/fastdeploy/vision/ocr/ppocr/dbdetector.h b/fastdeploy/vision/ocr/ppocr/dbdetector.h index d2305abd7..ec1ef028d 100755 --- a/fastdeploy/vision/ocr/ppocr/dbdetector.h +++ b/fastdeploy/vision/ocr/ppocr/dbdetector.h @@ -61,11 +61,20 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel { virtual bool BatchPredict(const std::vector& images, std::vector>>* det_results); - DBDetectorPreprocessor preprocessor_; - DBDetectorPostprocessor postprocessor_; + /// Get preprocessor reference of DBDetectorPreprocessor + virtual DBDetectorPreprocessor& GetPreprocessor() { + return preprocessor_; + } + + /// Get postprocessor reference of DBDetectorPostprocessor + virtual DBDetectorPostprocessor& GetPostprocessor() { + return postprocessor_; + } private: bool Initialize(); + DBDetectorPreprocessor preprocessor_; + DBDetectorPostprocessor postprocessor_; }; } // namespace ocr diff --git a/fastdeploy/vision/ocr/ppocr/det_postprocessor.h b/fastdeploy/vision/ocr/ppocr/det_postprocessor.h index 115228843..129ca6258 100644 --- a/fastdeploy/vision/ocr/ppocr/det_postprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/det_postprocessor.h @@ -36,6 +36,34 @@ class FASTDEPLOY_DECL DBDetectorPostprocessor { std::vector>>* results, const std::vector>& batch_det_img_info); + /// Set det_db_thresh for the detection postprocess, default is 0.3 + void SetDetDBThresh(double det_db_thresh) { det_db_thresh_ = det_db_thresh; } + /// Get det_db_thresh of the detection postprocess + double GetDetDBThresh() const { return det_db_thresh_; } + + /// Set det_db_box_thresh for the detection postprocess, default is 0.6 + void SetDetDBBoxThresh(double det_db_box_thresh) + { det_db_box_thresh_ = det_db_box_thresh; } + /// Get det_db_box_thresh of the detection postprocess + double GetDetDBBoxThresh() const { return det_db_box_thresh_; } + + /// Set det_db_unclip_ratio for the detection postprocess, default is 1.5 + void SetDetDBUnclipRatio(double det_db_unclip_ratio) + { det_db_unclip_ratio_ = det_db_unclip_ratio; } + /// Get det_db_unclip_ratio_ of the detection postprocess + double GetDetDBUnclipRatio() const { return det_db_unclip_ratio_; } + + /// Set det_db_score_mode for the detection postprocess, default is 'slow' + void SetDetDBScoreMode(std::string det_db_score_mode) + { det_db_score_mode_ = det_db_score_mode; } + /// Get det_db_score_mode_ of the detection postprocess + std::string GetDetDBScoreMode() const { return det_db_score_mode_; } + + /// Set use_dilation for the detection postprocess, default is fasle + void SetUseDilation(int use_dilation) { use_dilation_ = use_dilation; } + /// Get use_dilation of the detection postprocess + int GetUseDilation() const { return use_dilation_; } + double det_db_thresh_ = 0.3; double det_db_box_thresh_ = 0.6; double det_db_unclip_ratio_ = 1.5; diff --git a/fastdeploy/vision/ocr/ppocr/det_preprocessor.h b/fastdeploy/vision/ocr/ppocr/det_preprocessor.h index 705f19c7b..bf496079f 100644 --- a/fastdeploy/vision/ocr/ppocr/det_preprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/det_preprocessor.h @@ -35,6 +35,26 @@ class FASTDEPLOY_DECL DBDetectorPreprocessor { std::vector* outputs, std::vector>* batch_det_img_info_ptr); + /// Set max_side_len for the detection preprocess, default is 960 + void SetMaxSideLen(int max_side_len) { max_side_len_ = max_side_len; } + /// Get max_side_len of the detection preprocess + int GetMaxSideLen() const { return max_side_len_; } + + /// Set mean value for the image normalization in detection preprocess + void SetMean(std::vector mean) { mean_ = mean; } + /// Get mean value of the image normalization in detection preprocess + std::vector GetMean() const { return mean_; } + + /// Set scale value for the image normalization in detection preprocess + void SetScale(std::vector scale) { scale_ = scale; } + /// Get scale value of the image normalization in detection preprocess + std::vector GetScale() const { return scale_; } + + /// Set is_scale for the image normalization in detection preprocess + void SetIsScale(bool is_scale) { is_scale_ = is_scale; } + /// Get is_scale of the image normalization in detection preprocess + bool GetIsScale() const { return is_scale_; } + int max_side_len_ = 960; std::vector mean_ = {0.485f, 0.456f, 0.406f}; std::vector scale_ = {0.229f, 0.224f, 0.225f}; diff --git a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc index acc73c57d..2bcb697a8 100755 --- a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc +++ b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc @@ -24,10 +24,10 @@ void BindPPOCRModel(pybind11::module& m) { // DBDetector pybind11::class_(m, "DBDetectorPreprocessor") .def(pybind11::init<>()) - .def_readwrite("max_side_len", &vision::ocr::DBDetectorPreprocessor::max_side_len_) - .def_readwrite("mean", &vision::ocr::DBDetectorPreprocessor::mean_) - .def_readwrite("scale", &vision::ocr::DBDetectorPreprocessor::scale_) - .def_readwrite("is_scale", &vision::ocr::DBDetectorPreprocessor::is_scale_) + .def_property("max_side_len", &vision::ocr::DBDetectorPreprocessor::GetMaxSideLen, &vision::ocr::DBDetectorPreprocessor::SetMaxSideLen) + .def_property("mean", &vision::ocr::DBDetectorPreprocessor::GetMean, &vision::ocr::DBDetectorPreprocessor::SetMean) + .def_property("scale", &vision::ocr::DBDetectorPreprocessor::GetScale, &vision::ocr::DBDetectorPreprocessor::SetScale) + .def_property("is_scale", &vision::ocr::DBDetectorPreprocessor::GetIsScale, &vision::ocr::DBDetectorPreprocessor::SetIsScale) .def("run", [](vision::ocr::DBDetectorPreprocessor& self, std::vector& im_list) { std::vector images; for (size_t i = 0; i < im_list.size(); ++i) { @@ -44,11 +44,12 @@ void BindPPOCRModel(pybind11::module& m) { pybind11::class_(m, "DBDetectorPostprocessor") .def(pybind11::init<>()) - .def_readwrite("det_db_thresh", &vision::ocr::DBDetectorPostprocessor::det_db_thresh_) - .def_readwrite("det_db_box_thresh", &vision::ocr::DBDetectorPostprocessor::det_db_box_thresh_) - .def_readwrite("det_db_unclip_ratio", &vision::ocr::DBDetectorPostprocessor::det_db_unclip_ratio_) - .def_readwrite("det_db_score_mode", &vision::ocr::DBDetectorPostprocessor::det_db_score_mode_) - .def_readwrite("use_dilation", &vision::ocr::DBDetectorPostprocessor::use_dilation_) + .def_property("det_db_thresh", &vision::ocr::DBDetectorPostprocessor::GetDetDBThresh, &vision::ocr::DBDetectorPostprocessor::SetDetDBThresh) + .def_property("det_db_box_thresh", &vision::ocr::DBDetectorPostprocessor::GetDetDBBoxThresh, &vision::ocr::DBDetectorPostprocessor::SetDetDBBoxThresh) + .def_property("det_db_unclip_ratio", &vision::ocr::DBDetectorPostprocessor::GetDetDBUnclipRatio, &vision::ocr::DBDetectorPostprocessor::SetDetDBUnclipRatio) + .def_property("det_db_score_mode", &vision::ocr::DBDetectorPostprocessor::GetDetDBScoreMode, &vision::ocr::DBDetectorPostprocessor::SetDetDBScoreMode) + .def_property("use_dilation", &vision::ocr::DBDetectorPostprocessor::GetUseDilation, &vision::ocr::DBDetectorPostprocessor::SetUseDilation) + .def("run", [](vision::ocr::DBDetectorPostprocessor& self, std::vector& inputs, const std::vector>& batch_det_img_info) { @@ -75,8 +76,8 @@ void BindPPOCRModel(pybind11::module& m) { .def(pybind11::init()) .def(pybind11::init<>()) - .def_readwrite("preprocessor", &vision::ocr::DBDetector::preprocessor_) - .def_readwrite("postprocessor", &vision::ocr::DBDetector::postprocessor_) + .def_property_readonly("preprocessor", &vision::ocr::DBDetector::GetPreprocessor) + .def_property_readonly("postprocessor", &vision::ocr::DBDetector::GetPostprocessor) .def("predict", [](vision::ocr::DBDetector& self, pybind11::array& data) { auto mat = PyArrayToCvMat(data); @@ -97,10 +98,10 @@ void BindPPOCRModel(pybind11::module& m) { // Classifier pybind11::class_(m, "ClassifierPreprocessor") .def(pybind11::init<>()) - .def_readwrite("cls_image_shape", &vision::ocr::ClassifierPreprocessor::cls_image_shape_) - .def_readwrite("mean", &vision::ocr::ClassifierPreprocessor::mean_) - .def_readwrite("scale", &vision::ocr::ClassifierPreprocessor::scale_) - .def_readwrite("is_scale", &vision::ocr::ClassifierPreprocessor::is_scale_) + .def_property("cls_image_shape", &vision::ocr::ClassifierPreprocessor::GetClsImageShape, &vision::ocr::ClassifierPreprocessor::SetClsImageShape) + .def_property("mean", &vision::ocr::ClassifierPreprocessor::GetMean, &vision::ocr::ClassifierPreprocessor::SetMean) + .def_property("scale", &vision::ocr::ClassifierPreprocessor::GetScale, &vision::ocr::ClassifierPreprocessor::SetScale) + .def_property("is_scale", &vision::ocr::ClassifierPreprocessor::GetIsScale, &vision::ocr::ClassifierPreprocessor::SetIsScale) .def("run", [](vision::ocr::ClassifierPreprocessor& self, std::vector& im_list) { std::vector images; for (size_t i = 0; i < im_list.size(); ++i) { @@ -118,7 +119,7 @@ void BindPPOCRModel(pybind11::module& m) { pybind11::class_(m, "ClassifierPostprocessor") .def(pybind11::init<>()) - .def_readwrite("cls_thresh", &vision::ocr::ClassifierPostprocessor::cls_thresh_) + .def_property("cls_thresh", &vision::ocr::ClassifierPostprocessor::GetClsThresh, &vision::ocr::ClassifierPostprocessor::SetClsThresh) .def("run", [](vision::ocr::ClassifierPostprocessor& self, std::vector& inputs) { std::vector cls_labels; @@ -144,8 +145,8 @@ void BindPPOCRModel(pybind11::module& m) { .def(pybind11::init()) .def(pybind11::init<>()) - .def_readwrite("preprocessor", &vision::ocr::Classifier::preprocessor_) - .def_readwrite("postprocessor", &vision::ocr::Classifier::postprocessor_) + .def_property_readonly("preprocessor", &vision::ocr::Classifier::GetPreprocessor) + .def_property_readonly("postprocessor", &vision::ocr::Classifier::GetPostprocessor) .def("predict", [](vision::ocr::Classifier& self, pybind11::array& data) { auto mat = PyArrayToCvMat(data); @@ -168,11 +169,11 @@ void BindPPOCRModel(pybind11::module& m) { // Recognizer pybind11::class_(m, "RecognizerPreprocessor") .def(pybind11::init<>()) - .def_readwrite("rec_image_shape", &vision::ocr::RecognizerPreprocessor::rec_image_shape_) - .def_readwrite("mean", &vision::ocr::RecognizerPreprocessor::mean_) - .def_readwrite("scale", &vision::ocr::RecognizerPreprocessor::scale_) - .def_readwrite("is_scale", &vision::ocr::RecognizerPreprocessor::is_scale_) - .def_readwrite("static_shape", &vision::ocr::RecognizerPreprocessor::static_shape_) + .def_property("static_shape_infer", &vision::ocr::RecognizerPreprocessor::GetStaticShapeInfer, &vision::ocr::RecognizerPreprocessor::SetStaticShapeInfer) + .def_property("rec_image_shape", &vision::ocr::RecognizerPreprocessor::GetRecImageShape, &vision::ocr::RecognizerPreprocessor::SetRecImageShape) + .def_property("mean", &vision::ocr::RecognizerPreprocessor::GetMean, &vision::ocr::RecognizerPreprocessor::SetMean) + .def_property("scale", &vision::ocr::RecognizerPreprocessor::GetScale, &vision::ocr::RecognizerPreprocessor::SetScale) + .def_property("is_scale", &vision::ocr::RecognizerPreprocessor::GetIsScale, &vision::ocr::RecognizerPreprocessor::SetIsScale) .def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector& im_list) { std::vector images; for (size_t i = 0; i < im_list.size(); ++i) { @@ -215,8 +216,8 @@ void BindPPOCRModel(pybind11::module& m) { .def(pybind11::init()) .def(pybind11::init<>()) - .def_readwrite("preprocessor", &vision::ocr::Recognizer::preprocessor_) - .def_readwrite("postprocessor", &vision::ocr::Recognizer::postprocessor_) + .def_property_readonly("preprocessor", &vision::ocr::Recognizer::GetPreprocessor) + .def_property_readonly("postprocessor", &vision::ocr::Recognizer::GetPostprocessor) .def("predict", [](vision::ocr::Recognizer& self, pybind11::array& data) { auto mat = PyArrayToCvMat(data); diff --git a/fastdeploy/vision/ocr/ppocr/ppocr_v2.cc b/fastdeploy/vision/ocr/ppocr/ppocr_v2.cc index 756604dde..622fe41c0 100755 --- a/fastdeploy/vision/ocr/ppocr/ppocr_v2.cc +++ b/fastdeploy/vision/ocr/ppocr/ppocr_v2.cc @@ -23,14 +23,14 @@ PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model, fastdeploy::vision::ocr::Recognizer* rec_model) : detector_(det_model), classifier_(cls_model), recognizer_(rec_model) { Initialized(); - recognizer_->preprocessor_.rec_image_shape_[1] = 32; + recognizer_->GetPreprocessor().rec_image_shape_[1] = 32; } PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model, fastdeploy::vision::ocr::Recognizer* rec_model) : detector_(det_model), recognizer_(rec_model) { Initialized(); - recognizer_->preprocessor_.rec_image_shape_[1] = 32; + recognizer_->GetPreprocessor().rec_image_shape_[1] = 32; } bool PPOCRv2::SetClsBatchSize(int cls_batch_size) { @@ -134,7 +134,7 @@ bool PPOCRv2::BatchPredict(const std::vector& images, return false; }else{ for (size_t i_img = start_index; i_img < end_index; ++i_img) { - if(cls_labels_ptr->at(i_img) % 2 == 1 && cls_scores_ptr->at(i_img) > classifier_->postprocessor_.cls_thresh_) { + if(cls_labels_ptr->at(i_img) % 2 == 1 && cls_scores_ptr->at(i_img) > classifier_->GetPostprocessor().cls_thresh_) { cv::rotate(image_list[i_img], image_list[i_img], 1); } } diff --git a/fastdeploy/vision/ocr/ppocr/ppocr_v3.h b/fastdeploy/vision/ocr/ppocr/ppocr_v3.h index ed9177d92..fa46fdb2c 100755 --- a/fastdeploy/vision/ocr/ppocr/ppocr_v3.h +++ b/fastdeploy/vision/ocr/ppocr/ppocr_v3.h @@ -36,7 +36,7 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 { fastdeploy::vision::ocr::Recognizer* rec_model) : PPOCRv2(det_model, cls_model, rec_model) { // The only difference between v2 and v3 - recognizer_->preprocessor_.rec_image_shape_[1] = 48; + recognizer_->GetPreprocessor().rec_image_shape_[1] = 48; } /** \brief Classification model is optional, so this function is set up the detection model path and recognition model path respectively. * @@ -47,7 +47,7 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 { fastdeploy::vision::ocr::Recognizer* rec_model) : PPOCRv2(det_model, rec_model) { // The only difference between v2 and v3 - recognizer_->preprocessor_.rec_image_shape_[1] = 48; + recognizer_->GetPreprocessor().rec_image_shape_[1] = 48; } }; diff --git a/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc index 8ed4e0c53..ad049fdce 100644 --- a/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc +++ b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc @@ -22,12 +22,12 @@ namespace vision { namespace ocr { void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio, - const std::vector& rec_image_shape, bool static_shape) { + const std::vector& rec_image_shape, bool static_shape_infer) { int img_h, img_w; img_h = rec_image_shape[1]; img_w = rec_image_shape[2]; - if (!static_shape) { + if (!static_shape_infer) { img_w = int(img_h * max_wh_ratio); float ratio = float(mat->Width()) / float(mat->Height()); @@ -52,23 +52,6 @@ void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio, } } -void OcrRecognizerResizeImageOnAscend(FDMat* mat, - const std::vector& rec_image_shape) { - - int img_h, img_w; - img_h = rec_image_shape[1]; - img_w = rec_image_shape[2]; - - if (mat->Width() >= img_w) { - Resize::Run(mat, img_w, img_h); // Reszie W to 320 - } else { - Resize::Run(mat, mat->Width(), img_h); - Pad::Run(mat, 0, 0, 0, int(img_w - mat->Width()), {0,0,0}); - // Pad to 320 - } -} - - bool RecognizerPreprocessor::Run(std::vector* images, std::vector* outputs) { return Run(images, outputs, 0, images->size(), {}); } @@ -101,7 +84,7 @@ bool RecognizerPreprocessor::Run(std::vector* images, std::vectorat(real_index)); - OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_, static_shape_); + OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_, static_shape_infer_); NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_); } // Only have 1 output Tensor. diff --git a/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h index ee21c7362..c50711588 100644 --- a/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h +++ b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h @@ -35,11 +35,40 @@ class FASTDEPLOY_DECL RecognizerPreprocessor { size_t start_index, size_t end_index, const std::vector& indices); + /// Set static_shape_infer is true or not. When deploy PP-OCR + /// on hardware which can not support dynamic input shape very well, + /// like Huawei Ascned, static_shape_infer needs to to be true. + void SetStaticShapeInfer(bool static_shape_infer) + { static_shape_infer_ = static_shape_infer; } + /// Get static_shape_infer of the recognition preprocess + bool GetStaticShapeInfer() const { return static_shape_infer_; } + + /// Set mean value for the image normalization in recognition preprocess + void SetMean(std::vector mean) { mean_ = mean; } + /// Get mean value of the image normalization in recognition preprocess + std::vector GetMean() const { return mean_; } + + /// Set scale value for the image normalization in recognition preprocess + void SetScale(std::vector scale) { scale_ = scale; } + /// Get scale value of the image normalization in recognition preprocess + std::vector GetScale() const { return scale_; } + + /// Set is_scale for the image normalization in recognition preprocess + void SetIsScale(bool is_scale) { is_scale_ = is_scale; } + /// Get is_scale of the image normalization in recognition preprocess + bool GetIsScale() const { return is_scale_; } + + /// Set rec_image_shape for the recognition preprocess + void SetRecImageShape(std::vector rec_image_shape) + { rec_image_shape_ = rec_image_shape; } + /// Get rec_image_shape for the recognition preprocess + std::vector GetRecImageShape() const { return rec_image_shape_; } + std::vector rec_image_shape_ = {3, 48, 320}; std::vector mean_ = {0.5f, 0.5f, 0.5f}; std::vector scale_ = {0.5f, 0.5f, 0.5f}; bool is_scale_ = true; - bool static_shape_ = false; + bool static_shape_infer_ = false; }; } // namespace ocr diff --git a/fastdeploy/vision/ocr/ppocr/recognizer.h b/fastdeploy/vision/ocr/ppocr/recognizer.h index bba8a4447..60ffdcd10 100755 --- a/fastdeploy/vision/ocr/ppocr/recognizer.h +++ b/fastdeploy/vision/ocr/ppocr/recognizer.h @@ -67,11 +67,20 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel { size_t start_index, size_t end_index, const std::vector& indices); - RecognizerPreprocessor preprocessor_; - RecognizerPostprocessor postprocessor_; + /// Get preprocessor reference of DBDetectorPreprocessor + virtual RecognizerPreprocessor& GetPreprocessor() { + return preprocessor_; + } + + /// Get postprocessor reference of DBDetectorPostprocessor + virtual RecognizerPostprocessor& GetPostprocessor() { + return postprocessor_; + } private: bool Initialize(); + RecognizerPreprocessor preprocessor_; + RecognizerPostprocessor postprocessor_; }; } // namespace ocr diff --git a/python/fastdeploy/vision/ocr/ppocr/__init__.py b/python/fastdeploy/vision/ocr/ppocr/__init__.py index a357547fd..41bb279a5 100755 --- a/python/fastdeploy/vision/ocr/ppocr/__init__.py +++ b/python/fastdeploy/vision/ocr/ppocr/__init__.py @@ -509,15 +509,15 @@ class RecognizerPreprocessor: return self._preprocessor.run(input_ims) @property - def static_shape(self): - return self._preprocessor.static_shape + def static_shape_infer(self): + return self._preprocessor.static_shape_infer - @static_shape.setter - def static_shape(self, value): + @static_shape_infer.setter + def static_shape_infer(self, value): assert isinstance( value, - bool), "The value to set `static_shape` must be type of bool." - self._preprocessor.static_shape = value + bool), "The value to set `static_shape_infer` must be type of bool." + self._preprocessor.static_shape_infer = value @property def is_scale(self): @@ -638,15 +638,15 @@ class Recognizer(FastDeployModel): self._model.postprocessor = value @property - def static_shape(self): - return self._model.preprocessor.static_shape + def static_shape_infer(self): + return self._model.preprocessor.static_shape_infer - @static_shape.setter - def static_shape(self, value): + @static_shape_infer.setter + def static_shape_infer(self, value): assert isinstance( value, - bool), "The value to set `static_shape` must be type of bool." - self._model.preprocessor.static_shape = value + bool), "The value to set `static_shape_infer` must be type of bool." + self._model.preprocessor.static_shape_infer = value @property def is_scale(self):