mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-17 22:21:48 +08:00
Add PPOCR example for ascend deploy
This commit is contained in:
@@ -43,6 +43,8 @@ wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_
|
||||
./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 3
|
||||
# 昆仑芯XPU推理
|
||||
./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 4
|
||||
# 华为昇腾推理
|
||||
./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 5
|
||||
```
|
||||
|
||||
以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考:
|
||||
|
@@ -55,6 +55,10 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model
|
||||
auto cls_model = fastdeploy::vision::ocr::Classifier(cls_model_file, cls_params_file, cls_option);
|
||||
auto rec_model = fastdeploy::vision::ocr::Recognizer(rec_model_file, rec_params_file, rec_label_file, rec_option);
|
||||
|
||||
// Users could enable static shape infer for rec model when deploy PP-OCR on hardware
|
||||
// which can not support dynamic shape infer well, like Huawei Ascend series.
|
||||
// rec_model.GetPreprocessor().SetStaticShapeInfer(true);
|
||||
|
||||
assert(det_model.Initialized());
|
||||
assert(cls_model.Initialized());
|
||||
assert(rec_model.Initialized());
|
||||
@@ -65,9 +69,12 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model
|
||||
|
||||
// Set inference batch size for cls model and rec model, the value could be -1 and 1 to positive infinity.
|
||||
// When inference batch size is set to -1, it means that the inference batch size
|
||||
// of the cls and rec models will be the same as the number of boxes detected by the det model.
|
||||
// of the cls and rec models will be the same as the number of boxes detected by the det model.
|
||||
// When users enable static shape infer for rec model, the batch size of cls and rec model needs to be set to 1.
|
||||
// ppocr_v2.SetClsBatchSize(1);
|
||||
// ppocr_v2.SetRecBatchSize(1);
|
||||
ppocr_v2.SetClsBatchSize(cls_batch_size);
|
||||
ppocr_v2.SetRecBatchSize(rec_batch_size);
|
||||
ppocr_v2.SetRecBatchSize(rec_batch_size);
|
||||
|
||||
if(!ppocr_v2.Initialized()){
|
||||
std::cerr << "Failed to initialize PP-OCR." << std::endl;
|
||||
@@ -122,6 +129,8 @@ int main(int argc, char* argv[]) {
|
||||
option.EnablePaddleToTrt();
|
||||
} else if (flag == 4) {
|
||||
option.UseKunlunXin();
|
||||
} else if (flag == 5) {
|
||||
option.UseAscend();
|
||||
}
|
||||
|
||||
std::string det_model_dir = argv[1];
|
||||
|
@@ -36,6 +36,8 @@ python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2
|
||||
python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu --backend trt
|
||||
# 昆仑芯XPU推理
|
||||
python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device kunlunxin
|
||||
# 华为昇腾推理
|
||||
python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device ascend
|
||||
```
|
||||
|
||||
运行完成可视化结果如下图所示
|
||||
|
@@ -72,6 +72,10 @@ def build_option(args):
|
||||
option.use_kunlunxin()
|
||||
return option
|
||||
|
||||
if args.device.lower() == "ascend":
|
||||
option.use_ascend()
|
||||
return option
|
||||
|
||||
if args.backend.lower() == "trt":
|
||||
assert args.device.lower(
|
||||
) == "gpu", "TensorRT backend require inference on device GPU."
|
||||
@@ -112,6 +116,8 @@ runtime_option = build_option(args)
|
||||
|
||||
# PPOCR的cls和rec模型现在已经支持推理一个Batch的数据
|
||||
# 定义下面两个变量后, 可用于设置trt输入shape, 并在PPOCR模型初始化后, 完成Batch推理设置
|
||||
# 当用户要把PP-OCR部署在对动态shape推理支持有限的设备上时,(例如华为昇腾)
|
||||
# 需要把cls_batch_size和rec_batch_size都设置为1.
|
||||
cls_batch_size = 1
|
||||
rec_batch_size = 6
|
||||
|
||||
@@ -144,6 +150,10 @@ rec_option.set_trt_input_shape("x", [1, 3, 32, 10],
|
||||
rec_model = fd.vision.ocr.Recognizer(
|
||||
rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option)
|
||||
|
||||
# 当用户要把PP-OCR部署在对动态shape推理支持有限的设备上时,(例如华为昇腾)
|
||||
# 需要使用下行代码, 来启用rec模型的静态shape推理.
|
||||
# rec_model.preprocessor.static_shape_infer = True
|
||||
|
||||
# 创建PP-OCR,串联3个模型,其中cls_model可选,如无需求,可设置为None
|
||||
ppocr_v2 = fd.vision.ocr.PPOCRv2(
|
||||
det_model=det_model, cls_model=cls_model, rec_model=rec_model)
|
||||
|
@@ -43,6 +43,8 @@ wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_
|
||||
./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 3
|
||||
# 昆仑芯XPU推理
|
||||
./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 4
|
||||
# 华为昇腾推理, 请用户在代码里正确开启Rec模型的静态shape推理,并设置分类模型和识别模型的推理batch size为1.
|
||||
./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 5
|
||||
```
|
||||
|
||||
以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考:
|
||||
|
@@ -34,7 +34,7 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model
|
||||
auto rec_option = option;
|
||||
|
||||
// The cls and rec model can inference a batch of images now.
|
||||
// User could initialize the inference batch size and set them after create PPOCR model.
|
||||
// User could initialize the inference batch size and set them after create PP-OCR model.
|
||||
int cls_batch_size = 1;
|
||||
int rec_batch_size = 6;
|
||||
|
||||
@@ -56,6 +56,10 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model
|
||||
auto cls_model = fastdeploy::vision::ocr::Classifier(cls_model_file, cls_params_file, cls_option);
|
||||
auto rec_model = fastdeploy::vision::ocr::Recognizer(rec_model_file, rec_params_file, rec_label_file, rec_option);
|
||||
|
||||
// Users could enable static shape infer for rec model when deploy PP-OCR on hardware
|
||||
// which can not support dynamic shape infer well, like Huawei Ascend series.
|
||||
// rec_model.GetPreprocessor().SetStaticShapeInfer(true);
|
||||
|
||||
assert(det_model.Initialized());
|
||||
assert(cls_model.Initialized());
|
||||
assert(rec_model.Initialized());
|
||||
@@ -66,9 +70,12 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model
|
||||
|
||||
// Set inference batch size for cls model and rec model, the value could be -1 and 1 to positive infinity.
|
||||
// When inference batch size is set to -1, it means that the inference batch size
|
||||
// of the cls and rec models will be the same as the number of boxes detected by the det model.
|
||||
// of the cls and rec models will be the same as the number of boxes detected by the det model.
|
||||
// When users enable static shape infer for rec model, the batch size of cls and rec model needs to be set to 1.
|
||||
// ppocr_v3.SetClsBatchSize(1);
|
||||
// ppocr_v3.SetRecBatchSize(1);
|
||||
ppocr_v3.SetClsBatchSize(cls_batch_size);
|
||||
ppocr_v3.SetRecBatchSize(rec_batch_size);
|
||||
ppocr_v3.SetRecBatchSize(rec_batch_size);
|
||||
|
||||
if(!ppocr_v3.Initialized()){
|
||||
std::cerr << "Failed to initialize PP-OCR." << std::endl;
|
||||
@@ -123,6 +130,8 @@ int main(int argc, char* argv[]) {
|
||||
option.EnablePaddleToTrt();
|
||||
} else if (flag == 4) {
|
||||
option.UseKunlunXin();
|
||||
} else if (flag == 5) {
|
||||
option.UseAscend();
|
||||
}
|
||||
|
||||
std::string det_model_dir = argv[1];
|
||||
|
@@ -35,6 +35,8 @@ python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2
|
||||
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu --backend trt
|
||||
# 昆仑芯XPU推理
|
||||
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device kunlunxin
|
||||
# 华为昇腾推理
|
||||
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device ascend
|
||||
```
|
||||
|
||||
运行完成可视化结果如下图所示
|
||||
|
@@ -72,6 +72,10 @@ def build_option(args):
|
||||
option.use_kunlunxin()
|
||||
return option
|
||||
|
||||
if args.device.lower() == "ascend":
|
||||
option.use_ascend()
|
||||
return option
|
||||
|
||||
if args.backend.lower() == "trt":
|
||||
assert args.device.lower(
|
||||
) == "gpu", "TensorRT backend require inference on device GPU."
|
||||
@@ -112,6 +116,8 @@ runtime_option = build_option(args)
|
||||
|
||||
# PPOCR的cls和rec模型现在已经支持推理一个Batch的数据
|
||||
# 定义下面两个变量后, 可用于设置trt输入shape, 并在PPOCR模型初始化后, 完成Batch推理设置
|
||||
# 当用户要把PP-OCR部署在对动态shape推理支持有限的设备上时,(例如华为昇腾)
|
||||
# 需要把cls_batch_size和rec_batch_size都设置为1.
|
||||
cls_batch_size = 1
|
||||
rec_batch_size = 6
|
||||
|
||||
@@ -144,6 +150,10 @@ rec_option.set_trt_input_shape("x", [1, 3, 48, 10],
|
||||
rec_model = fd.vision.ocr.Recognizer(
|
||||
rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option)
|
||||
|
||||
# 当用户要把PP-OCR部署在对动态shape推理支持有限的设备上时,(例如华为昇腾)
|
||||
# 需要使用下行代码, 来启用rec模型的静态shape推理.
|
||||
# rec_model.preprocessor.static_shape_infer = True
|
||||
|
||||
# 创建PP-OCR,串联3个模型,其中cls_model可选,如无需求,可设置为None
|
||||
ppocr_v3 = fd.vision.ocr.PPOCRv3(
|
||||
det_model=det_model, cls_model=cls_model, rec_model=rec_model)
|
||||
|
@@ -68,11 +68,20 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel {
|
||||
std::vector<float>* cls_scores,
|
||||
size_t start_index, size_t end_index);
|
||||
|
||||
ClassifierPreprocessor preprocessor_;
|
||||
ClassifierPostprocessor postprocessor_;
|
||||
/// Get preprocessor reference of ClassifierPreprocessor
|
||||
virtual ClassifierPreprocessor& GetPreprocessor() {
|
||||
return preprocessor_;
|
||||
}
|
||||
|
||||
/// Get postprocessor reference of ClassifierPostprocessor
|
||||
virtual ClassifierPostprocessor& GetPostprocessor() {
|
||||
return postprocessor_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool Initialize();
|
||||
ClassifierPreprocessor preprocessor_;
|
||||
ClassifierPostprocessor postprocessor_;
|
||||
};
|
||||
|
||||
} // namespace ocr
|
||||
|
@@ -39,6 +39,12 @@ class FASTDEPLOY_DECL ClassifierPostprocessor {
|
||||
std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores,
|
||||
size_t start_index, size_t total_size);
|
||||
|
||||
/// Set threshold for the classification postprocess, default is 0.9
|
||||
void SetClsThresh(float cls_thresh) { cls_thresh_ = cls_thresh; }
|
||||
|
||||
/// Get threshold value of the classification postprocess.
|
||||
float GetClsThresh() const { return cls_thresh_; }
|
||||
|
||||
float cls_thresh_ = 0.9;
|
||||
};
|
||||
|
||||
|
@@ -34,6 +34,27 @@ class FASTDEPLOY_DECL ClassifierPreprocessor {
|
||||
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
|
||||
size_t start_index, size_t end_index);
|
||||
|
||||
/// Set mean value for the image normalization in classification preprocess
|
||||
void SetMean(std::vector<float> mean) { mean_ = mean; }
|
||||
/// Get mean value of the image normalization in classification preprocess
|
||||
std::vector<float> GetMean() const { return mean_; }
|
||||
|
||||
/// Set scale value for the image normalization in classification preprocess
|
||||
void SetScale(std::vector<float> scale) { scale_ = scale; }
|
||||
/// Get scale value of the image normalization in classification preprocess
|
||||
std::vector<float> GetScale() const { return scale_; }
|
||||
|
||||
/// Set is_scale for the image normalization in classification preprocess
|
||||
void SetIsScale(bool is_scale) { is_scale_ = is_scale; }
|
||||
/// Get is_scale of the image normalization in classification preprocess
|
||||
bool GetIsScale() const { return is_scale_; }
|
||||
|
||||
/// Set cls_image_shape for the classification preprocess
|
||||
void SetClsImageShape(std::vector<int> cls_image_shape)
|
||||
{ cls_image_shape_ = cls_image_shape; }
|
||||
/// Get cls_image_shape for the classification preprocess
|
||||
std::vector<int> GetClsImageShape() const { return cls_image_shape_; }
|
||||
|
||||
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
|
||||
std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
|
||||
bool is_scale_ = true;
|
||||
|
@@ -61,11 +61,20 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel {
|
||||
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
|
||||
std::vector<std::vector<std::array<int, 8>>>* det_results);
|
||||
|
||||
DBDetectorPreprocessor preprocessor_;
|
||||
DBDetectorPostprocessor postprocessor_;
|
||||
/// Get preprocessor reference of DBDetectorPreprocessor
|
||||
virtual DBDetectorPreprocessor& GetPreprocessor() {
|
||||
return preprocessor_;
|
||||
}
|
||||
|
||||
/// Get postprocessor reference of DBDetectorPostprocessor
|
||||
virtual DBDetectorPostprocessor& GetPostprocessor() {
|
||||
return postprocessor_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool Initialize();
|
||||
DBDetectorPreprocessor preprocessor_;
|
||||
DBDetectorPostprocessor postprocessor_;
|
||||
};
|
||||
|
||||
} // namespace ocr
|
||||
|
@@ -36,6 +36,34 @@ class FASTDEPLOY_DECL DBDetectorPostprocessor {
|
||||
std::vector<std::vector<std::array<int, 8>>>* results,
|
||||
const std::vector<std::array<int, 4>>& batch_det_img_info);
|
||||
|
||||
/// Set det_db_thresh for the detection postprocess, default is 0.3
|
||||
void SetDetDBThresh(double det_db_thresh) { det_db_thresh_ = det_db_thresh; }
|
||||
/// Get det_db_thresh of the detection postprocess
|
||||
double GetDetDBThresh() const { return det_db_thresh_; }
|
||||
|
||||
/// Set det_db_box_thresh for the detection postprocess, default is 0.6
|
||||
void SetDetDBBoxThresh(double det_db_box_thresh)
|
||||
{ det_db_box_thresh_ = det_db_box_thresh; }
|
||||
/// Get det_db_box_thresh of the detection postprocess
|
||||
double GetDetDBBoxThresh() const { return det_db_box_thresh_; }
|
||||
|
||||
/// Set det_db_unclip_ratio for the detection postprocess, default is 1.5
|
||||
void SetDetDBUnclipRatio(double det_db_unclip_ratio)
|
||||
{ det_db_unclip_ratio_ = det_db_unclip_ratio; }
|
||||
/// Get det_db_unclip_ratio_ of the detection postprocess
|
||||
double GetDetDBUnclipRatio() const { return det_db_unclip_ratio_; }
|
||||
|
||||
/// Set det_db_score_mode for the detection postprocess, default is 'slow'
|
||||
void SetDetDBScoreMode(std::string det_db_score_mode)
|
||||
{ det_db_score_mode_ = det_db_score_mode; }
|
||||
/// Get det_db_score_mode_ of the detection postprocess
|
||||
std::string GetDetDBScoreMode() const { return det_db_score_mode_; }
|
||||
|
||||
/// Set use_dilation for the detection postprocess, default is fasle
|
||||
void SetUseDilation(int use_dilation) { use_dilation_ = use_dilation; }
|
||||
/// Get use_dilation of the detection postprocess
|
||||
int GetUseDilation() const { return use_dilation_; }
|
||||
|
||||
double det_db_thresh_ = 0.3;
|
||||
double det_db_box_thresh_ = 0.6;
|
||||
double det_db_unclip_ratio_ = 1.5;
|
||||
|
@@ -35,6 +35,26 @@ class FASTDEPLOY_DECL DBDetectorPreprocessor {
|
||||
std::vector<FDTensor>* outputs,
|
||||
std::vector<std::array<int, 4>>* batch_det_img_info_ptr);
|
||||
|
||||
/// Set max_side_len for the detection preprocess, default is 960
|
||||
void SetMaxSideLen(int max_side_len) { max_side_len_ = max_side_len; }
|
||||
/// Get max_side_len of the detection preprocess
|
||||
int GetMaxSideLen() const { return max_side_len_; }
|
||||
|
||||
/// Set mean value for the image normalization in detection preprocess
|
||||
void SetMean(std::vector<float> mean) { mean_ = mean; }
|
||||
/// Get mean value of the image normalization in detection preprocess
|
||||
std::vector<float> GetMean() const { return mean_; }
|
||||
|
||||
/// Set scale value for the image normalization in detection preprocess
|
||||
void SetScale(std::vector<float> scale) { scale_ = scale; }
|
||||
/// Get scale value of the image normalization in detection preprocess
|
||||
std::vector<float> GetScale() const { return scale_; }
|
||||
|
||||
/// Set is_scale for the image normalization in detection preprocess
|
||||
void SetIsScale(bool is_scale) { is_scale_ = is_scale; }
|
||||
/// Get is_scale of the image normalization in detection preprocess
|
||||
bool GetIsScale() const { return is_scale_; }
|
||||
|
||||
int max_side_len_ = 960;
|
||||
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
|
||||
std::vector<float> scale_ = {0.229f, 0.224f, 0.225f};
|
||||
|
@@ -24,10 +24,10 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
// DBDetector
|
||||
pybind11::class_<vision::ocr::DBDetectorPreprocessor>(m, "DBDetectorPreprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def_readwrite("max_side_len", &vision::ocr::DBDetectorPreprocessor::max_side_len_)
|
||||
.def_readwrite("mean", &vision::ocr::DBDetectorPreprocessor::mean_)
|
||||
.def_readwrite("scale", &vision::ocr::DBDetectorPreprocessor::scale_)
|
||||
.def_readwrite("is_scale", &vision::ocr::DBDetectorPreprocessor::is_scale_)
|
||||
.def_property("max_side_len", &vision::ocr::DBDetectorPreprocessor::GetMaxSideLen, &vision::ocr::DBDetectorPreprocessor::SetMaxSideLen)
|
||||
.def_property("mean", &vision::ocr::DBDetectorPreprocessor::GetMean, &vision::ocr::DBDetectorPreprocessor::SetMean)
|
||||
.def_property("scale", &vision::ocr::DBDetectorPreprocessor::GetScale, &vision::ocr::DBDetectorPreprocessor::SetScale)
|
||||
.def_property("is_scale", &vision::ocr::DBDetectorPreprocessor::GetIsScale, &vision::ocr::DBDetectorPreprocessor::SetIsScale)
|
||||
.def("run", [](vision::ocr::DBDetectorPreprocessor& self, std::vector<pybind11::array>& im_list) {
|
||||
std::vector<vision::FDMat> images;
|
||||
for (size_t i = 0; i < im_list.size(); ++i) {
|
||||
@@ -44,11 +44,12 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
|
||||
pybind11::class_<vision::ocr::DBDetectorPostprocessor>(m, "DBDetectorPostprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def_readwrite("det_db_thresh", &vision::ocr::DBDetectorPostprocessor::det_db_thresh_)
|
||||
.def_readwrite("det_db_box_thresh", &vision::ocr::DBDetectorPostprocessor::det_db_box_thresh_)
|
||||
.def_readwrite("det_db_unclip_ratio", &vision::ocr::DBDetectorPostprocessor::det_db_unclip_ratio_)
|
||||
.def_readwrite("det_db_score_mode", &vision::ocr::DBDetectorPostprocessor::det_db_score_mode_)
|
||||
.def_readwrite("use_dilation", &vision::ocr::DBDetectorPostprocessor::use_dilation_)
|
||||
.def_property("det_db_thresh", &vision::ocr::DBDetectorPostprocessor::GetDetDBThresh, &vision::ocr::DBDetectorPostprocessor::SetDetDBThresh)
|
||||
.def_property("det_db_box_thresh", &vision::ocr::DBDetectorPostprocessor::GetDetDBBoxThresh, &vision::ocr::DBDetectorPostprocessor::SetDetDBBoxThresh)
|
||||
.def_property("det_db_unclip_ratio", &vision::ocr::DBDetectorPostprocessor::GetDetDBUnclipRatio, &vision::ocr::DBDetectorPostprocessor::SetDetDBUnclipRatio)
|
||||
.def_property("det_db_score_mode", &vision::ocr::DBDetectorPostprocessor::GetDetDBScoreMode, &vision::ocr::DBDetectorPostprocessor::SetDetDBScoreMode)
|
||||
.def_property("use_dilation", &vision::ocr::DBDetectorPostprocessor::GetUseDilation, &vision::ocr::DBDetectorPostprocessor::SetUseDilation)
|
||||
|
||||
.def("run", [](vision::ocr::DBDetectorPostprocessor& self,
|
||||
std::vector<FDTensor>& inputs,
|
||||
const std::vector<std::array<int, 4>>& batch_det_img_info) {
|
||||
@@ -75,8 +76,8 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
.def(pybind11::init<std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def(pybind11::init<>())
|
||||
.def_readwrite("preprocessor", &vision::ocr::DBDetector::preprocessor_)
|
||||
.def_readwrite("postprocessor", &vision::ocr::DBDetector::postprocessor_)
|
||||
.def_property_readonly("preprocessor", &vision::ocr::DBDetector::GetPreprocessor)
|
||||
.def_property_readonly("postprocessor", &vision::ocr::DBDetector::GetPostprocessor)
|
||||
.def("predict", [](vision::ocr::DBDetector& self,
|
||||
pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
@@ -97,10 +98,10 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
// Classifier
|
||||
pybind11::class_<vision::ocr::ClassifierPreprocessor>(m, "ClassifierPreprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def_readwrite("cls_image_shape", &vision::ocr::ClassifierPreprocessor::cls_image_shape_)
|
||||
.def_readwrite("mean", &vision::ocr::ClassifierPreprocessor::mean_)
|
||||
.def_readwrite("scale", &vision::ocr::ClassifierPreprocessor::scale_)
|
||||
.def_readwrite("is_scale", &vision::ocr::ClassifierPreprocessor::is_scale_)
|
||||
.def_property("cls_image_shape", &vision::ocr::ClassifierPreprocessor::GetClsImageShape, &vision::ocr::ClassifierPreprocessor::SetClsImageShape)
|
||||
.def_property("mean", &vision::ocr::ClassifierPreprocessor::GetMean, &vision::ocr::ClassifierPreprocessor::SetMean)
|
||||
.def_property("scale", &vision::ocr::ClassifierPreprocessor::GetScale, &vision::ocr::ClassifierPreprocessor::SetScale)
|
||||
.def_property("is_scale", &vision::ocr::ClassifierPreprocessor::GetIsScale, &vision::ocr::ClassifierPreprocessor::SetIsScale)
|
||||
.def("run", [](vision::ocr::ClassifierPreprocessor& self, std::vector<pybind11::array>& im_list) {
|
||||
std::vector<vision::FDMat> images;
|
||||
for (size_t i = 0; i < im_list.size(); ++i) {
|
||||
@@ -118,7 +119,7 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
|
||||
pybind11::class_<vision::ocr::ClassifierPostprocessor>(m, "ClassifierPostprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def_readwrite("cls_thresh", &vision::ocr::ClassifierPostprocessor::cls_thresh_)
|
||||
.def_property("cls_thresh", &vision::ocr::ClassifierPostprocessor::GetClsThresh, &vision::ocr::ClassifierPostprocessor::SetClsThresh)
|
||||
.def("run", [](vision::ocr::ClassifierPostprocessor& self,
|
||||
std::vector<FDTensor>& inputs) {
|
||||
std::vector<int> cls_labels;
|
||||
@@ -144,8 +145,8 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
.def(pybind11::init<std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def(pybind11::init<>())
|
||||
.def_readwrite("preprocessor", &vision::ocr::Classifier::preprocessor_)
|
||||
.def_readwrite("postprocessor", &vision::ocr::Classifier::postprocessor_)
|
||||
.def_property_readonly("preprocessor", &vision::ocr::Classifier::GetPreprocessor)
|
||||
.def_property_readonly("postprocessor", &vision::ocr::Classifier::GetPostprocessor)
|
||||
.def("predict", [](vision::ocr::Classifier& self,
|
||||
pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
@@ -168,11 +169,11 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
// Recognizer
|
||||
pybind11::class_<vision::ocr::RecognizerPreprocessor>(m, "RecognizerPreprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def_readwrite("rec_image_shape", &vision::ocr::RecognizerPreprocessor::rec_image_shape_)
|
||||
.def_readwrite("mean", &vision::ocr::RecognizerPreprocessor::mean_)
|
||||
.def_readwrite("scale", &vision::ocr::RecognizerPreprocessor::scale_)
|
||||
.def_readwrite("is_scale", &vision::ocr::RecognizerPreprocessor::is_scale_)
|
||||
.def_readwrite("static_shape", &vision::ocr::RecognizerPreprocessor::static_shape_)
|
||||
.def_property("static_shape_infer", &vision::ocr::RecognizerPreprocessor::GetStaticShapeInfer, &vision::ocr::RecognizerPreprocessor::SetStaticShapeInfer)
|
||||
.def_property("rec_image_shape", &vision::ocr::RecognizerPreprocessor::GetRecImageShape, &vision::ocr::RecognizerPreprocessor::SetRecImageShape)
|
||||
.def_property("mean", &vision::ocr::RecognizerPreprocessor::GetMean, &vision::ocr::RecognizerPreprocessor::SetMean)
|
||||
.def_property("scale", &vision::ocr::RecognizerPreprocessor::GetScale, &vision::ocr::RecognizerPreprocessor::SetScale)
|
||||
.def_property("is_scale", &vision::ocr::RecognizerPreprocessor::GetIsScale, &vision::ocr::RecognizerPreprocessor::SetIsScale)
|
||||
.def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector<pybind11::array>& im_list) {
|
||||
std::vector<vision::FDMat> images;
|
||||
for (size_t i = 0; i < im_list.size(); ++i) {
|
||||
@@ -215,8 +216,8 @@ void BindPPOCRModel(pybind11::module& m) {
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def(pybind11::init<>())
|
||||
.def_readwrite("preprocessor", &vision::ocr::Recognizer::preprocessor_)
|
||||
.def_readwrite("postprocessor", &vision::ocr::Recognizer::postprocessor_)
|
||||
.def_property_readonly("preprocessor", &vision::ocr::Recognizer::GetPreprocessor)
|
||||
.def_property_readonly("postprocessor", &vision::ocr::Recognizer::GetPostprocessor)
|
||||
.def("predict", [](vision::ocr::Recognizer& self,
|
||||
pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
|
@@ -23,14 +23,14 @@ PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model,
|
||||
fastdeploy::vision::ocr::Recognizer* rec_model)
|
||||
: detector_(det_model), classifier_(cls_model), recognizer_(rec_model) {
|
||||
Initialized();
|
||||
recognizer_->preprocessor_.rec_image_shape_[1] = 32;
|
||||
recognizer_->GetPreprocessor().rec_image_shape_[1] = 32;
|
||||
}
|
||||
|
||||
PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model,
|
||||
fastdeploy::vision::ocr::Recognizer* rec_model)
|
||||
: detector_(det_model), recognizer_(rec_model) {
|
||||
Initialized();
|
||||
recognizer_->preprocessor_.rec_image_shape_[1] = 32;
|
||||
recognizer_->GetPreprocessor().rec_image_shape_[1] = 32;
|
||||
}
|
||||
|
||||
bool PPOCRv2::SetClsBatchSize(int cls_batch_size) {
|
||||
@@ -134,7 +134,7 @@ bool PPOCRv2::BatchPredict(const std::vector<cv::Mat>& images,
|
||||
return false;
|
||||
}else{
|
||||
for (size_t i_img = start_index; i_img < end_index; ++i_img) {
|
||||
if(cls_labels_ptr->at(i_img) % 2 == 1 && cls_scores_ptr->at(i_img) > classifier_->postprocessor_.cls_thresh_) {
|
||||
if(cls_labels_ptr->at(i_img) % 2 == 1 && cls_scores_ptr->at(i_img) > classifier_->GetPostprocessor().cls_thresh_) {
|
||||
cv::rotate(image_list[i_img], image_list[i_img], 1);
|
||||
}
|
||||
}
|
||||
|
@@ -36,7 +36,7 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 {
|
||||
fastdeploy::vision::ocr::Recognizer* rec_model)
|
||||
: PPOCRv2(det_model, cls_model, rec_model) {
|
||||
// The only difference between v2 and v3
|
||||
recognizer_->preprocessor_.rec_image_shape_[1] = 48;
|
||||
recognizer_->GetPreprocessor().rec_image_shape_[1] = 48;
|
||||
}
|
||||
/** \brief Classification model is optional, so this function is set up the detection model path and recognition model path respectively.
|
||||
*
|
||||
@@ -47,7 +47,7 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 {
|
||||
fastdeploy::vision::ocr::Recognizer* rec_model)
|
||||
: PPOCRv2(det_model, rec_model) {
|
||||
// The only difference between v2 and v3
|
||||
recognizer_->preprocessor_.rec_image_shape_[1] = 48;
|
||||
recognizer_->GetPreprocessor().rec_image_shape_[1] = 48;
|
||||
}
|
||||
};
|
||||
|
||||
|
@@ -22,12 +22,12 @@ namespace vision {
|
||||
namespace ocr {
|
||||
|
||||
void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio,
|
||||
const std::vector<int>& rec_image_shape, bool static_shape) {
|
||||
const std::vector<int>& rec_image_shape, bool static_shape_infer) {
|
||||
int img_h, img_w;
|
||||
img_h = rec_image_shape[1];
|
||||
img_w = rec_image_shape[2];
|
||||
|
||||
if (!static_shape) {
|
||||
if (!static_shape_infer) {
|
||||
|
||||
img_w = int(img_h * max_wh_ratio);
|
||||
float ratio = float(mat->Width()) / float(mat->Height());
|
||||
@@ -52,23 +52,6 @@ void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio,
|
||||
}
|
||||
}
|
||||
|
||||
void OcrRecognizerResizeImageOnAscend(FDMat* mat,
|
||||
const std::vector<int>& rec_image_shape) {
|
||||
|
||||
int img_h, img_w;
|
||||
img_h = rec_image_shape[1];
|
||||
img_w = rec_image_shape[2];
|
||||
|
||||
if (mat->Width() >= img_w) {
|
||||
Resize::Run(mat, img_w, img_h); // Reszie W to 320
|
||||
} else {
|
||||
Resize::Run(mat, mat->Width(), img_h);
|
||||
Pad::Run(mat, 0, 0, 0, int(img_w - mat->Width()), {0,0,0});
|
||||
// Pad to 320
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
bool RecognizerPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
|
||||
return Run(images, outputs, 0, images->size(), {});
|
||||
}
|
||||
@@ -101,7 +84,7 @@ bool RecognizerPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTenso
|
||||
real_index = indices[i];
|
||||
}
|
||||
FDMat* mat = &(images->at(real_index));
|
||||
OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_, static_shape_);
|
||||
OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_, static_shape_infer_);
|
||||
NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
|
||||
}
|
||||
// Only have 1 output Tensor.
|
||||
|
@@ -35,11 +35,40 @@ class FASTDEPLOY_DECL RecognizerPreprocessor {
|
||||
size_t start_index, size_t end_index,
|
||||
const std::vector<int>& indices);
|
||||
|
||||
/// Set static_shape_infer is true or not. When deploy PP-OCR
|
||||
/// on hardware which can not support dynamic input shape very well,
|
||||
/// like Huawei Ascned, static_shape_infer needs to to be true.
|
||||
void SetStaticShapeInfer(bool static_shape_infer)
|
||||
{ static_shape_infer_ = static_shape_infer; }
|
||||
/// Get static_shape_infer of the recognition preprocess
|
||||
bool GetStaticShapeInfer() const { return static_shape_infer_; }
|
||||
|
||||
/// Set mean value for the image normalization in recognition preprocess
|
||||
void SetMean(std::vector<float> mean) { mean_ = mean; }
|
||||
/// Get mean value of the image normalization in recognition preprocess
|
||||
std::vector<float> GetMean() const { return mean_; }
|
||||
|
||||
/// Set scale value for the image normalization in recognition preprocess
|
||||
void SetScale(std::vector<float> scale) { scale_ = scale; }
|
||||
/// Get scale value of the image normalization in recognition preprocess
|
||||
std::vector<float> GetScale() const { return scale_; }
|
||||
|
||||
/// Set is_scale for the image normalization in recognition preprocess
|
||||
void SetIsScale(bool is_scale) { is_scale_ = is_scale; }
|
||||
/// Get is_scale of the image normalization in recognition preprocess
|
||||
bool GetIsScale() const { return is_scale_; }
|
||||
|
||||
/// Set rec_image_shape for the recognition preprocess
|
||||
void SetRecImageShape(std::vector<int> rec_image_shape)
|
||||
{ rec_image_shape_ = rec_image_shape; }
|
||||
/// Get rec_image_shape for the recognition preprocess
|
||||
std::vector<int> GetRecImageShape() const { return rec_image_shape_; }
|
||||
|
||||
std::vector<int> rec_image_shape_ = {3, 48, 320};
|
||||
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
|
||||
std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
|
||||
bool is_scale_ = true;
|
||||
bool static_shape_ = false;
|
||||
bool static_shape_infer_ = false;
|
||||
};
|
||||
|
||||
} // namespace ocr
|
||||
|
@@ -67,11 +67,20 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
|
||||
size_t start_index, size_t end_index,
|
||||
const std::vector<int>& indices);
|
||||
|
||||
RecognizerPreprocessor preprocessor_;
|
||||
RecognizerPostprocessor postprocessor_;
|
||||
/// Get preprocessor reference of DBDetectorPreprocessor
|
||||
virtual RecognizerPreprocessor& GetPreprocessor() {
|
||||
return preprocessor_;
|
||||
}
|
||||
|
||||
/// Get postprocessor reference of DBDetectorPostprocessor
|
||||
virtual RecognizerPostprocessor& GetPostprocessor() {
|
||||
return postprocessor_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool Initialize();
|
||||
RecognizerPreprocessor preprocessor_;
|
||||
RecognizerPostprocessor postprocessor_;
|
||||
};
|
||||
|
||||
} // namespace ocr
|
||||
|
@@ -509,15 +509,15 @@ class RecognizerPreprocessor:
|
||||
return self._preprocessor.run(input_ims)
|
||||
|
||||
@property
|
||||
def static_shape(self):
|
||||
return self._preprocessor.static_shape
|
||||
def static_shape_infer(self):
|
||||
return self._preprocessor.static_shape_infer
|
||||
|
||||
@static_shape.setter
|
||||
def static_shape(self, value):
|
||||
@static_shape_infer.setter
|
||||
def static_shape_infer(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
bool), "The value to set `static_shape` must be type of bool."
|
||||
self._preprocessor.static_shape = value
|
||||
bool), "The value to set `static_shape_infer` must be type of bool."
|
||||
self._preprocessor.static_shape_infer = value
|
||||
|
||||
@property
|
||||
def is_scale(self):
|
||||
@@ -638,15 +638,15 @@ class Recognizer(FastDeployModel):
|
||||
self._model.postprocessor = value
|
||||
|
||||
@property
|
||||
def static_shape(self):
|
||||
return self._model.preprocessor.static_shape
|
||||
def static_shape_infer(self):
|
||||
return self._model.preprocessor.static_shape_infer
|
||||
|
||||
@static_shape.setter
|
||||
def static_shape(self, value):
|
||||
@static_shape_infer.setter
|
||||
def static_shape_infer(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
bool), "The value to set `static_shape` must be type of bool."
|
||||
self._model.preprocessor.static_shape = value
|
||||
bool), "The value to set `static_shape_infer` must be type of bool."
|
||||
self._model.preprocessor.static_shape_infer = value
|
||||
|
||||
@property
|
||||
def is_scale(self):
|
||||
|
Reference in New Issue
Block a user