Add PPOCR example for ascend deploy

This commit is contained in:
yunyaoXYY
2022-12-28 12:56:53 +00:00
parent 471f0f62c8
commit b6903b0aa4
21 changed files with 235 additions and 74 deletions

View File

@@ -43,6 +43,8 @@ wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_
./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 3
# 昆仑芯XPU推理
./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 4
# 华为昇腾推理
./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 5
```
以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考:

View File

@@ -55,6 +55,10 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model
auto cls_model = fastdeploy::vision::ocr::Classifier(cls_model_file, cls_params_file, cls_option);
auto rec_model = fastdeploy::vision::ocr::Recognizer(rec_model_file, rec_params_file, rec_label_file, rec_option);
// Users can enable static shape inference for the rec model when deploying PP-OCR on hardware
// that does not support dynamic shape inference well, such as the Huawei Ascend series.
// rec_model.GetPreprocessor().SetStaticShapeInfer(true);
assert(det_model.Initialized());
assert(cls_model.Initialized());
assert(rec_model.Initialized());
@@ -66,6 +70,9 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model
// Set inference batch size for cls model and rec model, the value could be -1 and 1 to positive infinity.
// When inference batch size is set to -1, it means that the inference batch size
// of the cls and rec models will be the same as the number of boxes detected by the det model.
// When users enable static shape infer for rec model, the batch size of cls and rec model needs to be set to 1.
// ppocr_v2.SetClsBatchSize(1);
// ppocr_v2.SetRecBatchSize(1);
ppocr_v2.SetClsBatchSize(cls_batch_size);
ppocr_v2.SetRecBatchSize(rec_batch_size);
@@ -122,6 +129,8 @@ int main(int argc, char* argv[]) {
option.EnablePaddleToTrt();
} else if (flag == 4) {
option.UseKunlunXin();
} else if (flag == 5) {
option.UseAscend();
}
std::string det_model_dir = argv[1];

View File

@@ -36,6 +36,8 @@ python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2
python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu --backend trt
# 昆仑芯XPU推理
python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device kunlunxin
# 华为昇腾推理
python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device ascend
```
运行完成可视化结果如下图所示

View File

@@ -72,6 +72,10 @@ def build_option(args):
option.use_kunlunxin()
return option
if args.device.lower() == "ascend":
option.use_ascend()
return option
if args.backend.lower() == "trt":
assert args.device.lower(
) == "gpu", "TensorRT backend require inference on device GPU."
@@ -112,6 +116,8 @@ runtime_option = build_option(args)
# PPOCR的cls和rec模型现在已经支持推理一个Batch的数据
# 定义下面两个变量后, 可用于设置trt输入shape, 并在PPOCR模型初始化后, 完成Batch推理设置
# 当用户要把PP-OCR部署在对动态shape推理支持有限的设备上时,(例如华为昇腾)
# 需要把cls_batch_size和rec_batch_size都设置为1.
cls_batch_size = 1
rec_batch_size = 6
@@ -144,6 +150,10 @@ rec_option.set_trt_input_shape("x", [1, 3, 32, 10],
rec_model = fd.vision.ocr.Recognizer(
rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option)
# 当用户要把PP-OCR部署在对动态shape推理支持有限的设备上时,(例如华为昇腾)
# 需要使用下行代码, 来启用rec模型的静态shape推理.
# rec_model.preprocessor.static_shape_infer = True
# 创建PP-OCR串联3个模型其中cls_model可选如无需求可设置为None
ppocr_v2 = fd.vision.ocr.PPOCRv2(
det_model=det_model, cls_model=cls_model, rec_model=rec_model)

View File

@@ -43,6 +43,8 @@ wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_
./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 3
# 昆仑芯XPU推理
./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 4
# 华为昇腾推理, 请用户在代码里正确开启Rec模型的静态shape推理并设置分类模型和识别模型的推理batch size为1.
./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 5
```
以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考:

View File

@@ -34,7 +34,7 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model
auto rec_option = option;
// The cls and rec model can inference a batch of images now.
// User could initialize the inference batch size and set them after create PPOCR model.
// User could initialize the inference batch size and set them after create PP-OCR model.
int cls_batch_size = 1;
int rec_batch_size = 6;
@@ -56,6 +56,10 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model
auto cls_model = fastdeploy::vision::ocr::Classifier(cls_model_file, cls_params_file, cls_option);
auto rec_model = fastdeploy::vision::ocr::Recognizer(rec_model_file, rec_params_file, rec_label_file, rec_option);
// Users can enable static shape inference for the rec model when deploying PP-OCR on hardware
// that does not support dynamic shape inference well, such as the Huawei Ascend series.
// rec_model.GetPreprocessor().SetStaticShapeInfer(true);
assert(det_model.Initialized());
assert(cls_model.Initialized());
assert(rec_model.Initialized());
@@ -67,6 +71,9 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model
// Set inference batch size for cls model and rec model, the value could be -1 and 1 to positive infinity.
// When inference batch size is set to -1, it means that the inference batch size
// of the cls and rec models will be the same as the number of boxes detected by the det model.
// When users enable static shape infer for rec model, the batch size of cls and rec model needs to be set to 1.
// ppocr_v3.SetClsBatchSize(1);
// ppocr_v3.SetRecBatchSize(1);
ppocr_v3.SetClsBatchSize(cls_batch_size);
ppocr_v3.SetRecBatchSize(rec_batch_size);
@@ -123,6 +130,8 @@ int main(int argc, char* argv[]) {
option.EnablePaddleToTrt();
} else if (flag == 4) {
option.UseKunlunXin();
} else if (flag == 5) {
option.UseAscend();
}
std::string det_model_dir = argv[1];

View File

@@ -35,6 +35,8 @@ python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu --backend trt
# 昆仑芯XPU推理
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device kunlunxin
# 华为昇腾推理
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device ascend
```
运行完成可视化结果如下图所示

View File

@@ -72,6 +72,10 @@ def build_option(args):
option.use_kunlunxin()
return option
if args.device.lower() == "ascend":
option.use_ascend()
return option
if args.backend.lower() == "trt":
assert args.device.lower(
) == "gpu", "TensorRT backend require inference on device GPU."
@@ -112,6 +116,8 @@ runtime_option = build_option(args)
# PPOCR的cls和rec模型现在已经支持推理一个Batch的数据
# 定义下面两个变量后, 可用于设置trt输入shape, 并在PPOCR模型初始化后, 完成Batch推理设置
# 当用户要把PP-OCR部署在对动态shape推理支持有限的设备上时,(例如华为昇腾)
# 需要把cls_batch_size和rec_batch_size都设置为1.
cls_batch_size = 1
rec_batch_size = 6
@@ -144,6 +150,10 @@ rec_option.set_trt_input_shape("x", [1, 3, 48, 10],
rec_model = fd.vision.ocr.Recognizer(
rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option)
# 当用户要把PP-OCR部署在对动态shape推理支持有限的设备上时,(例如华为昇腾)
# 需要使用下行代码, 来启用rec模型的静态shape推理.
# rec_model.preprocessor.static_shape_infer = True
# 创建PP-OCR串联3个模型其中cls_model可选如无需求可设置为None
ppocr_v3 = fd.vision.ocr.PPOCRv3(
det_model=det_model, cls_model=cls_model, rec_model=rec_model)

View File

@@ -68,11 +68,20 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel {
std::vector<float>* cls_scores,
size_t start_index, size_t end_index);
ClassifierPreprocessor preprocessor_;
ClassifierPostprocessor postprocessor_;
/// Get preprocessor reference of ClassifierPreprocessor
virtual ClassifierPreprocessor& GetPreprocessor() {
return preprocessor_;
}
/// Get postprocessor reference of ClassifierPostprocessor
virtual ClassifierPostprocessor& GetPostprocessor() {
return postprocessor_;
}
private:
bool Initialize();
ClassifierPreprocessor preprocessor_;
ClassifierPostprocessor postprocessor_;
};
} // namespace ocr

View File

@@ -39,6 +39,12 @@ class FASTDEPLOY_DECL ClassifierPostprocessor {
std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores,
size_t start_index, size_t total_size);
/// Set threshold for the classification postprocess, default is 0.9
void SetClsThresh(float cls_thresh) { cls_thresh_ = cls_thresh; }
/// Get threshold value of the classification postprocess.
float GetClsThresh() const { return cls_thresh_; }
float cls_thresh_ = 0.9;
};

View File

@@ -34,6 +34,27 @@ class FASTDEPLOY_DECL ClassifierPreprocessor {
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
size_t start_index, size_t end_index);
/// Set mean value for the image normalization in classification preprocess
void SetMean(std::vector<float> mean) { mean_ = mean; }
/// Get mean value of the image normalization in classification preprocess
std::vector<float> GetMean() const { return mean_; }
/// Set scale value for the image normalization in classification preprocess
void SetScale(std::vector<float> scale) { scale_ = scale; }
/// Get scale value of the image normalization in classification preprocess
std::vector<float> GetScale() const { return scale_; }
/// Set is_scale for the image normalization in classification preprocess
void SetIsScale(bool is_scale) { is_scale_ = is_scale; }
/// Get is_scale of the image normalization in classification preprocess
bool GetIsScale() const { return is_scale_; }
/// Set cls_image_shape for the classification preprocess
void SetClsImageShape(std::vector<int> cls_image_shape)
{ cls_image_shape_ = cls_image_shape; }
/// Get cls_image_shape for the classification preprocess
std::vector<int> GetClsImageShape() const { return cls_image_shape_; }
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
bool is_scale_ = true;

View File

@@ -61,11 +61,20 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel {
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
std::vector<std::vector<std::array<int, 8>>>* det_results);
DBDetectorPreprocessor preprocessor_;
DBDetectorPostprocessor postprocessor_;
/// Get preprocessor reference of DBDetectorPreprocessor
virtual DBDetectorPreprocessor& GetPreprocessor() {
return preprocessor_;
}
/// Get postprocessor reference of DBDetectorPostprocessor
virtual DBDetectorPostprocessor& GetPostprocessor() {
return postprocessor_;
}
private:
bool Initialize();
DBDetectorPreprocessor preprocessor_;
DBDetectorPostprocessor postprocessor_;
};
} // namespace ocr

View File

@@ -36,6 +36,34 @@ class FASTDEPLOY_DECL DBDetectorPostprocessor {
std::vector<std::vector<std::array<int, 8>>>* results,
const std::vector<std::array<int, 4>>& batch_det_img_info);
/// Set det_db_thresh for the detection postprocess, default is 0.3
void SetDetDBThresh(double det_db_thresh) { det_db_thresh_ = det_db_thresh; }
/// Get det_db_thresh of the detection postprocess
double GetDetDBThresh() const { return det_db_thresh_; }
/// Set det_db_box_thresh for the detection postprocess, default is 0.6
void SetDetDBBoxThresh(double det_db_box_thresh)
{ det_db_box_thresh_ = det_db_box_thresh; }
/// Get det_db_box_thresh of the detection postprocess
double GetDetDBBoxThresh() const { return det_db_box_thresh_; }
/// Set det_db_unclip_ratio for the detection postprocess, default is 1.5
void SetDetDBUnclipRatio(double det_db_unclip_ratio)
{ det_db_unclip_ratio_ = det_db_unclip_ratio; }
/// Get det_db_unclip_ratio of the detection postprocess
double GetDetDBUnclipRatio() const { return det_db_unclip_ratio_; }
/// Set det_db_score_mode for the detection postprocess, default is 'slow'
void SetDetDBScoreMode(std::string det_db_score_mode)
{ det_db_score_mode_ = det_db_score_mode; }
/// Get det_db_score_mode of the detection postprocess
std::string GetDetDBScoreMode() const { return det_db_score_mode_; }
/// Set use_dilation for the detection postprocess, default is false
void SetUseDilation(int use_dilation) { use_dilation_ = use_dilation; }
/// Get use_dilation of the detection postprocess
int GetUseDilation() const { return use_dilation_; }
double det_db_thresh_ = 0.3;
double det_db_box_thresh_ = 0.6;
double det_db_unclip_ratio_ = 1.5;

View File

@@ -35,6 +35,26 @@ class FASTDEPLOY_DECL DBDetectorPreprocessor {
std::vector<FDTensor>* outputs,
std::vector<std::array<int, 4>>* batch_det_img_info_ptr);
/// Set max_side_len for the detection preprocess, default is 960
void SetMaxSideLen(int max_side_len) { max_side_len_ = max_side_len; }
/// Get max_side_len of the detection preprocess
int GetMaxSideLen() const { return max_side_len_; }
/// Set mean value for the image normalization in detection preprocess
void SetMean(std::vector<float> mean) { mean_ = mean; }
/// Get mean value of the image normalization in detection preprocess
std::vector<float> GetMean() const { return mean_; }
/// Set scale value for the image normalization in detection preprocess
void SetScale(std::vector<float> scale) { scale_ = scale; }
/// Get scale value of the image normalization in detection preprocess
std::vector<float> GetScale() const { return scale_; }
/// Set is_scale for the image normalization in detection preprocess
void SetIsScale(bool is_scale) { is_scale_ = is_scale; }
/// Get is_scale of the image normalization in detection preprocess
bool GetIsScale() const { return is_scale_; }
int max_side_len_ = 960;
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
std::vector<float> scale_ = {0.229f, 0.224f, 0.225f};

View File

@@ -24,10 +24,10 @@ void BindPPOCRModel(pybind11::module& m) {
// DBDetector
pybind11::class_<vision::ocr::DBDetectorPreprocessor>(m, "DBDetectorPreprocessor")
.def(pybind11::init<>())
.def_readwrite("max_side_len", &vision::ocr::DBDetectorPreprocessor::max_side_len_)
.def_readwrite("mean", &vision::ocr::DBDetectorPreprocessor::mean_)
.def_readwrite("scale", &vision::ocr::DBDetectorPreprocessor::scale_)
.def_readwrite("is_scale", &vision::ocr::DBDetectorPreprocessor::is_scale_)
.def_property("max_side_len", &vision::ocr::DBDetectorPreprocessor::GetMaxSideLen, &vision::ocr::DBDetectorPreprocessor::SetMaxSideLen)
.def_property("mean", &vision::ocr::DBDetectorPreprocessor::GetMean, &vision::ocr::DBDetectorPreprocessor::SetMean)
.def_property("scale", &vision::ocr::DBDetectorPreprocessor::GetScale, &vision::ocr::DBDetectorPreprocessor::SetScale)
.def_property("is_scale", &vision::ocr::DBDetectorPreprocessor::GetIsScale, &vision::ocr::DBDetectorPreprocessor::SetIsScale)
.def("run", [](vision::ocr::DBDetectorPreprocessor& self, std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
@@ -44,11 +44,12 @@ void BindPPOCRModel(pybind11::module& m) {
pybind11::class_<vision::ocr::DBDetectorPostprocessor>(m, "DBDetectorPostprocessor")
.def(pybind11::init<>())
.def_readwrite("det_db_thresh", &vision::ocr::DBDetectorPostprocessor::det_db_thresh_)
.def_readwrite("det_db_box_thresh", &vision::ocr::DBDetectorPostprocessor::det_db_box_thresh_)
.def_readwrite("det_db_unclip_ratio", &vision::ocr::DBDetectorPostprocessor::det_db_unclip_ratio_)
.def_readwrite("det_db_score_mode", &vision::ocr::DBDetectorPostprocessor::det_db_score_mode_)
.def_readwrite("use_dilation", &vision::ocr::DBDetectorPostprocessor::use_dilation_)
.def_property("det_db_thresh", &vision::ocr::DBDetectorPostprocessor::GetDetDBThresh, &vision::ocr::DBDetectorPostprocessor::SetDetDBThresh)
.def_property("det_db_box_thresh", &vision::ocr::DBDetectorPostprocessor::GetDetDBBoxThresh, &vision::ocr::DBDetectorPostprocessor::SetDetDBBoxThresh)
.def_property("det_db_unclip_ratio", &vision::ocr::DBDetectorPostprocessor::GetDetDBUnclipRatio, &vision::ocr::DBDetectorPostprocessor::SetDetDBUnclipRatio)
.def_property("det_db_score_mode", &vision::ocr::DBDetectorPostprocessor::GetDetDBScoreMode, &vision::ocr::DBDetectorPostprocessor::SetDetDBScoreMode)
.def_property("use_dilation", &vision::ocr::DBDetectorPostprocessor::GetUseDilation, &vision::ocr::DBDetectorPostprocessor::SetUseDilation)
.def("run", [](vision::ocr::DBDetectorPostprocessor& self,
std::vector<FDTensor>& inputs,
const std::vector<std::array<int, 4>>& batch_det_img_info) {
@@ -75,8 +76,8 @@ void BindPPOCRModel(pybind11::module& m) {
.def(pybind11::init<std::string, std::string, RuntimeOption,
ModelFormat>())
.def(pybind11::init<>())
.def_readwrite("preprocessor", &vision::ocr::DBDetector::preprocessor_)
.def_readwrite("postprocessor", &vision::ocr::DBDetector::postprocessor_)
.def_property_readonly("preprocessor", &vision::ocr::DBDetector::GetPreprocessor)
.def_property_readonly("postprocessor", &vision::ocr::DBDetector::GetPostprocessor)
.def("predict", [](vision::ocr::DBDetector& self,
pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
@@ -97,10 +98,10 @@ void BindPPOCRModel(pybind11::module& m) {
// Classifier
pybind11::class_<vision::ocr::ClassifierPreprocessor>(m, "ClassifierPreprocessor")
.def(pybind11::init<>())
.def_readwrite("cls_image_shape", &vision::ocr::ClassifierPreprocessor::cls_image_shape_)
.def_readwrite("mean", &vision::ocr::ClassifierPreprocessor::mean_)
.def_readwrite("scale", &vision::ocr::ClassifierPreprocessor::scale_)
.def_readwrite("is_scale", &vision::ocr::ClassifierPreprocessor::is_scale_)
.def_property("cls_image_shape", &vision::ocr::ClassifierPreprocessor::GetClsImageShape, &vision::ocr::ClassifierPreprocessor::SetClsImageShape)
.def_property("mean", &vision::ocr::ClassifierPreprocessor::GetMean, &vision::ocr::ClassifierPreprocessor::SetMean)
.def_property("scale", &vision::ocr::ClassifierPreprocessor::GetScale, &vision::ocr::ClassifierPreprocessor::SetScale)
.def_property("is_scale", &vision::ocr::ClassifierPreprocessor::GetIsScale, &vision::ocr::ClassifierPreprocessor::SetIsScale)
.def("run", [](vision::ocr::ClassifierPreprocessor& self, std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
@@ -118,7 +119,7 @@ void BindPPOCRModel(pybind11::module& m) {
pybind11::class_<vision::ocr::ClassifierPostprocessor>(m, "ClassifierPostprocessor")
.def(pybind11::init<>())
.def_readwrite("cls_thresh", &vision::ocr::ClassifierPostprocessor::cls_thresh_)
.def_property("cls_thresh", &vision::ocr::ClassifierPostprocessor::GetClsThresh, &vision::ocr::ClassifierPostprocessor::SetClsThresh)
.def("run", [](vision::ocr::ClassifierPostprocessor& self,
std::vector<FDTensor>& inputs) {
std::vector<int> cls_labels;
@@ -144,8 +145,8 @@ void BindPPOCRModel(pybind11::module& m) {
.def(pybind11::init<std::string, std::string, RuntimeOption,
ModelFormat>())
.def(pybind11::init<>())
.def_readwrite("preprocessor", &vision::ocr::Classifier::preprocessor_)
.def_readwrite("postprocessor", &vision::ocr::Classifier::postprocessor_)
.def_property_readonly("preprocessor", &vision::ocr::Classifier::GetPreprocessor)
.def_property_readonly("postprocessor", &vision::ocr::Classifier::GetPostprocessor)
.def("predict", [](vision::ocr::Classifier& self,
pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
@@ -168,11 +169,11 @@ void BindPPOCRModel(pybind11::module& m) {
// Recognizer
pybind11::class_<vision::ocr::RecognizerPreprocessor>(m, "RecognizerPreprocessor")
.def(pybind11::init<>())
.def_readwrite("rec_image_shape", &vision::ocr::RecognizerPreprocessor::rec_image_shape_)
.def_readwrite("mean", &vision::ocr::RecognizerPreprocessor::mean_)
.def_readwrite("scale", &vision::ocr::RecognizerPreprocessor::scale_)
.def_readwrite("is_scale", &vision::ocr::RecognizerPreprocessor::is_scale_)
.def_readwrite("static_shape", &vision::ocr::RecognizerPreprocessor::static_shape_)
.def_property("static_shape_infer", &vision::ocr::RecognizerPreprocessor::GetStaticShapeInfer, &vision::ocr::RecognizerPreprocessor::SetStaticShapeInfer)
.def_property("rec_image_shape", &vision::ocr::RecognizerPreprocessor::GetRecImageShape, &vision::ocr::RecognizerPreprocessor::SetRecImageShape)
.def_property("mean", &vision::ocr::RecognizerPreprocessor::GetMean, &vision::ocr::RecognizerPreprocessor::SetMean)
.def_property("scale", &vision::ocr::RecognizerPreprocessor::GetScale, &vision::ocr::RecognizerPreprocessor::SetScale)
.def_property("is_scale", &vision::ocr::RecognizerPreprocessor::GetIsScale, &vision::ocr::RecognizerPreprocessor::SetIsScale)
.def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
@@ -215,8 +216,8 @@ void BindPPOCRModel(pybind11::module& m) {
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>())
.def(pybind11::init<>())
.def_readwrite("preprocessor", &vision::ocr::Recognizer::preprocessor_)
.def_readwrite("postprocessor", &vision::ocr::Recognizer::postprocessor_)
.def_property_readonly("preprocessor", &vision::ocr::Recognizer::GetPreprocessor)
.def_property_readonly("postprocessor", &vision::ocr::Recognizer::GetPostprocessor)
.def("predict", [](vision::ocr::Recognizer& self,
pybind11::array& data) {
auto mat = PyArrayToCvMat(data);

View File

@@ -23,14 +23,14 @@ PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model,
fastdeploy::vision::ocr::Recognizer* rec_model)
: detector_(det_model), classifier_(cls_model), recognizer_(rec_model) {
Initialized();
recognizer_->preprocessor_.rec_image_shape_[1] = 32;
recognizer_->GetPreprocessor().rec_image_shape_[1] = 32;
}
PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model,
fastdeploy::vision::ocr::Recognizer* rec_model)
: detector_(det_model), recognizer_(rec_model) {
Initialized();
recognizer_->preprocessor_.rec_image_shape_[1] = 32;
recognizer_->GetPreprocessor().rec_image_shape_[1] = 32;
}
bool PPOCRv2::SetClsBatchSize(int cls_batch_size) {
@@ -134,7 +134,7 @@ bool PPOCRv2::BatchPredict(const std::vector<cv::Mat>& images,
return false;
}else{
for (size_t i_img = start_index; i_img < end_index; ++i_img) {
if(cls_labels_ptr->at(i_img) % 2 == 1 && cls_scores_ptr->at(i_img) > classifier_->postprocessor_.cls_thresh_) {
if(cls_labels_ptr->at(i_img) % 2 == 1 && cls_scores_ptr->at(i_img) > classifier_->GetPostprocessor().cls_thresh_) {
cv::rotate(image_list[i_img], image_list[i_img], 1);
}
}

View File

@@ -36,7 +36,7 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 {
fastdeploy::vision::ocr::Recognizer* rec_model)
: PPOCRv2(det_model, cls_model, rec_model) {
// The only difference between v2 and v3
recognizer_->preprocessor_.rec_image_shape_[1] = 48;
recognizer_->GetPreprocessor().rec_image_shape_[1] = 48;
}
/** \brief Classification model is optional, so this function is set up the detection model path and recognition model path respectively.
*
@@ -47,7 +47,7 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 {
fastdeploy::vision::ocr::Recognizer* rec_model)
: PPOCRv2(det_model, rec_model) {
// The only difference between v2 and v3
recognizer_->preprocessor_.rec_image_shape_[1] = 48;
recognizer_->GetPreprocessor().rec_image_shape_[1] = 48;
}
};

View File

@@ -22,12 +22,12 @@ namespace vision {
namespace ocr {
void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio,
const std::vector<int>& rec_image_shape, bool static_shape) {
const std::vector<int>& rec_image_shape, bool static_shape_infer) {
int img_h, img_w;
img_h = rec_image_shape[1];
img_w = rec_image_shape[2];
if (!static_shape) {
if (!static_shape_infer) {
img_w = int(img_h * max_wh_ratio);
float ratio = float(mat->Width()) / float(mat->Height());
@@ -52,23 +52,6 @@ void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio,
}
}
void OcrRecognizerResizeImageOnAscend(FDMat* mat,
const std::vector<int>& rec_image_shape) {
int img_h, img_w;
img_h = rec_image_shape[1];
img_w = rec_image_shape[2];
if (mat->Width() >= img_w) {
Resize::Run(mat, img_w, img_h); // Reszie W to 320
} else {
Resize::Run(mat, mat->Width(), img_h);
Pad::Run(mat, 0, 0, 0, int(img_w - mat->Width()), {0,0,0});
// Pad to 320
}
}
bool RecognizerPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
return Run(images, outputs, 0, images->size(), {});
}
@@ -101,7 +84,7 @@ bool RecognizerPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTenso
real_index = indices[i];
}
FDMat* mat = &(images->at(real_index));
OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_, static_shape_);
OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_, static_shape_infer_);
NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
}
// Only have 1 output Tensor.

View File

@@ -35,11 +35,40 @@ class FASTDEPLOY_DECL RecognizerPreprocessor {
size_t start_index, size_t end_index,
const std::vector<int>& indices);
/// Enable or disable static shape inference. When deploying PP-OCR
/// on hardware that does not support dynamic input shapes well,
/// such as Huawei Ascend, static_shape_infer needs to be true.
void SetStaticShapeInfer(bool static_shape_infer)
{ static_shape_infer_ = static_shape_infer; }
/// Get static_shape_infer of the recognition preprocess
bool GetStaticShapeInfer() const { return static_shape_infer_; }
/// Set mean value for the image normalization in recognition preprocess
void SetMean(std::vector<float> mean) { mean_ = mean; }
/// Get mean value of the image normalization in recognition preprocess
std::vector<float> GetMean() const { return mean_; }
/// Set scale value for the image normalization in recognition preprocess
void SetScale(std::vector<float> scale) { scale_ = scale; }
/// Get scale value of the image normalization in recognition preprocess
std::vector<float> GetScale() const { return scale_; }
/// Set is_scale for the image normalization in recognition preprocess
void SetIsScale(bool is_scale) { is_scale_ = is_scale; }
/// Get is_scale of the image normalization in recognition preprocess
bool GetIsScale() const { return is_scale_; }
/// Set rec_image_shape for the recognition preprocess
void SetRecImageShape(std::vector<int> rec_image_shape)
{ rec_image_shape_ = rec_image_shape; }
/// Get rec_image_shape for the recognition preprocess
std::vector<int> GetRecImageShape() const { return rec_image_shape_; }
std::vector<int> rec_image_shape_ = {3, 48, 320};
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
bool is_scale_ = true;
bool static_shape_ = false;
bool static_shape_infer_ = false;
};
} // namespace ocr

View File

@@ -67,11 +67,20 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
size_t start_index, size_t end_index,
const std::vector<int>& indices);
RecognizerPreprocessor preprocessor_;
RecognizerPostprocessor postprocessor_;
/// Get preprocessor reference of RecognizerPreprocessor
virtual RecognizerPreprocessor& GetPreprocessor() {
return preprocessor_;
}
/// Get postprocessor reference of RecognizerPostprocessor
virtual RecognizerPostprocessor& GetPostprocessor() {
return postprocessor_;
}
private:
bool Initialize();
RecognizerPreprocessor preprocessor_;
RecognizerPostprocessor postprocessor_;
};
} // namespace ocr

View File

@@ -509,15 +509,15 @@ class RecognizerPreprocessor:
return self._preprocessor.run(input_ims)
@property
def static_shape(self):
return self._preprocessor.static_shape
def static_shape_infer(self):
return self._preprocessor.static_shape_infer
@static_shape.setter
def static_shape(self, value):
@static_shape_infer.setter
def static_shape_infer(self, value):
assert isinstance(
value,
bool), "The value to set `static_shape` must be type of bool."
self._preprocessor.static_shape = value
bool), "The value to set `static_shape_infer` must be type of bool."
self._preprocessor.static_shape_infer = value
@property
def is_scale(self):
@@ -638,15 +638,15 @@ class Recognizer(FastDeployModel):
self._model.postprocessor = value
@property
def static_shape(self):
return self._model.preprocessor.static_shape
def static_shape_infer(self):
return self._model.preprocessor.static_shape_infer
@static_shape.setter
def static_shape(self, value):
@static_shape_infer.setter
def static_shape_infer(self, value):
assert isinstance(
value,
bool), "The value to set `static_shape` must be type of bool."
self._model.preprocessor.static_shape = value
bool), "The value to set `static_shape_infer` must be type of bool."
self._model.preprocessor.static_shape_infer = value
@property
def is_scale(self):