diff --git a/benchmark/cpp/CMakeLists.txt b/benchmark/cpp/CMakeLists.txt index 41373e1a5..dc0599f4f 100755 --- a/benchmark/cpp/CMakeLists.txt +++ b/benchmark/cpp/CMakeLists.txt @@ -23,6 +23,7 @@ add_executable(benchmark_ppocr_det ${PROJECT_SOURCE_DIR}/benchmark_ppocr_det.cc) add_executable(benchmark_ppocr_cls ${PROJECT_SOURCE_DIR}/benchmark_ppocr_cls.cc) add_executable(benchmark_ppocr_rec ${PROJECT_SOURCE_DIR}/benchmark_ppocr_rec.cc) add_executable(benchmark_structurev2_table ${PROJECT_SOURCE_DIR}/benchmark_structurev2_table.cc) +add_executable(benchmark_structurev2_layout ${PROJECT_SOURCE_DIR}/benchmark_structurev2_layout.cc) add_executable(benchmark_ppyoloe_r ${PROJECT_SOURCE_DIR}/benchmark_ppyoloe_r.cc) add_executable(benchmark_ppyoloe_r_sophgo ${PROJECT_SOURCE_DIR}/benchmark_ppyoloe_r_sophgo.cc) add_executable(benchmark_ppyolo ${PROJECT_SOURCE_DIR}/benchmark_ppyolo.cc) @@ -57,6 +58,7 @@ if(UNIX AND (NOT APPLE) AND (NOT ANDROID)) target_link_libraries(benchmark_ppocr_cls ${FASTDEPLOY_LIBS} gflags pthread) target_link_libraries(benchmark_ppocr_rec ${FASTDEPLOY_LIBS} gflags pthread) target_link_libraries(benchmark_structurev2_table ${FASTDEPLOY_LIBS} gflags pthread) + target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags pthread) target_link_libraries(benchmark_ppyolo ${FASTDEPLOY_LIBS} gflags pthread) target_link_libraries(benchmark_yolov3 ${FASTDEPLOY_LIBS} gflags pthread) target_link_libraries(benchmark_fasterrcnn ${FASTDEPLOY_LIBS} gflags pthread) @@ -88,6 +90,7 @@ else() target_link_libraries(benchmark_ppocr_cls ${FASTDEPLOY_LIBS} gflags) target_link_libraries(benchmark_ppocr_rec ${FASTDEPLOY_LIBS} gflags) target_link_libraries(benchmark_structurev2_table ${FASTDEPLOY_LIBS} gflags) + target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags) target_link_libraries(benchmark_ppyolo ${FASTDEPLOY_LIBS} gflags) target_link_libraries(benchmark_yolov3 ${FASTDEPLOY_LIBS} gflags) target_link_libraries(benchmark_fasterrcnn ${FASTDEPLOY_LIBS} gflags) diff --git a/benchmark/cpp/benchmark_structurev2_layout.cc b/benchmark/cpp/benchmark_structurev2_layout.cc new file mode 100644 index 000000000..c88c025c8 --- /dev/null +++ b/benchmark/cpp/benchmark_structurev2_layout.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
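+//
+// Usage sketch (hedged): the flags come from the shared benchmark flags.h
+// used by the sibling benchmark_* targets, and the paths below are
+// illustrative only:
+//   ./benchmark_structurev2_layout --model ./picodet_lcnet_x1_0_fgd_layout_infer \
+//       --image layout.jpg --config_path config.txt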
+
+#include "flags.h"
+#include "macros.h"
+#include "option.h"
+
+namespace vision = fastdeploy::vision;
+namespace benchmark = fastdeploy::benchmark;
+
+int main(int argc, char* argv[]) {
+#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION)
+  // Initialization
+  auto option = fastdeploy::RuntimeOption();
+  if (!CreateRuntimeOption(&option, argc, argv, true)) {
+    return -1;
+  }
+  auto im = cv::imread(FLAGS_image);
+  std::unordered_map<std::string, std::string> config_info;
+  benchmark::ResultManager::LoadBenchmarkConfig(FLAGS_config_path,
+                                                &config_info);
+  std::string model_name, params_name, config_name;
+  auto model_format = fastdeploy::ModelFormat::PADDLE;
+  if (!UpdateModelResourceName(&model_name, &params_name, &config_name,
+                               &model_format, config_info, false)) {
+    return -1;
+  }
+
+  auto model_file = FLAGS_model + sep + model_name;
+  auto params_file = FLAGS_model + sep + params_name;
+  if (config_info["backend"] == "paddle_trt") {
+    option.paddle_infer_option.collect_trt_shape = true;
+  }
+  if (config_info["backend"] == "paddle_trt" ||
+      config_info["backend"] == "trt") {
+    option.trt_option.SetShape("image", {1, 3, 800, 608}, {1, 3, 800, 608},
+                               {1, 3, 800, 608});
+  }
+  auto layout_model = vision::ocr::StructureV2Layout(model_file, params_file,
+                                                     option, model_format);
+  // 5 for publaynet, 10 for cdla
+  layout_model.GetPostprocessor().SetNumClass(5);
+
+  vision::DetectionResult res;
+  if (config_info["precision_compare"] == "true") {
+    // Run once at least
+    layout_model.Predict(im, &res);
+    // 1. Test result diff
+    std::cout << "=============== Test result diff =================\n";
+    // Save result to -> disk.
+    std::string layout_result_path = "layout_result.txt";
+    benchmark::ResultManager::SaveDetectionResult(res, layout_result_path);
+    // Load result from <- disk.
+    vision::DetectionResult res_loaded;
+    benchmark::ResultManager::LoadDetectionResult(&res_loaded,
+                                                  layout_result_path);
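+    // Note (hedged assumption): CalculateDiffStatis is taken to compare the
+    // two DetectionResults element-wise and report mean/max/min deltas for
+    // boxes and label_ids, so a near-zero diff below means the save/load
+    // round trip above is lossless.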
+    // Calculate diff between two results.
+    auto det_diff =
+        benchmark::ResultManager::CalculateDiffStatis(res, res_loaded);
+    std::cout << "Boxes diff: mean=" << det_diff.boxes.mean
+              << ", max=" << det_diff.boxes.max
+              << ", min=" << det_diff.boxes.min << std::endl;
+    std::cout << "Label_ids diff: mean=" << det_diff.labels.mean
+              << ", max=" << det_diff.labels.max
+              << ", min=" << det_diff.labels.min << std::endl;
+  }
+  // Run profiling
+  BENCHMARK_MODEL(layout_model, layout_model.Predict(im, &res))
+  std::vector<std::string> labels = {"text", "title", "list", "table",
+                                     "figure"};
+  if (layout_model.GetPostprocessor().GetNumClass() == 10) {
+    labels = {"text",  "title",         "figure", "figure_caption",
+              "table", "table_caption", "header", "footer",
+              "reference", "equation"};
+  }
+  auto vis_im =
+      vision::VisDetection(im, res, labels, 0.3, 2, .5f, {255, 0, 0}, 2);
+  cv::imwrite("vis_result.jpg", vis_im);
+  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
+#endif
+  return 0;
+}
\ No newline at end of file
diff --git a/examples/vision/ocr/PP-OCR/cpu-gpu/cpp/CMakeLists.txt b/examples/vision/ocr/PP-OCR/cpu-gpu/cpp/CMakeLists.txt
index ac0101c93..43d199d83 100644
--- a/examples/vision/ocr/PP-OCR/cpu-gpu/cpp/CMakeLists.txt
+++ b/examples/vision/ocr/PP-OCR/cpu-gpu/cpp/CMakeLists.txt
@@ -38,3 +38,8 @@ target_link_libraries(infer_rec ${FASTDEPLOY_LIBS})
 add_executable(infer_structurev2_table ${PROJECT_SOURCE_DIR}/infer_structurev2_table.cc)
 # Add the FastDeploy library dependency
 target_link_libraries(infer_structurev2_table ${FASTDEPLOY_LIBS})
+
+# Only Layout
+add_executable(infer_structurev2_layout ${PROJECT_SOURCE_DIR}/infer_structurev2_layout.cc)
+# Add the FastDeploy library dependency
+target_link_libraries(infer_structurev2_layout ${FASTDEPLOY_LIBS})
diff --git a/examples/vision/ocr/PP-OCR/cpu-gpu/cpp/README.md b/examples/vision/ocr/PP-OCR/cpu-gpu/cpp/README.md
index 17332de19..b0e735e72 100644
--- a/examples/vision/ocr/PP-OCR/cpu-gpu/cpp/README.md
+++ b/examples/vision/ocr/PP-OCR/cpu-gpu/cpp/README.md
@@ -46,12 +46,18 @@ tar -xvf ch_PP-OCRv3_rec_infer.tar
 # Download the PP-StructureV2 table recognition model
 wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar
 tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
+# Download the PP-StructureV2 layout analysis model
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout_infer.tar
+tar -xvf picodet_lcnet_x1_0_fgd_layout_infer.tar
 # Download the test images and dictionary files
 wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg
 wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppstructure/docs/table/table.jpg
+wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppstructure/docs/table/layout.jpg
 wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
 wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/table_structure_dict_ch.txt
+wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt
+wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/layout_dict/layout_cdla_dict.txt
 
 # Run the deployment examples
 # Inference with Paddle Inference on CPU
@@ -71,7 +77,7 @@
 # Inference with Nvidia TensorRT on GPU
 ./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 7
 
-# FastDeploy also supports standalone inference for the three models: text detection, text classification, and text recognition.
+# FastDeploy also supports standalone inference for text detection, text classification, text recognition, table recognition, layout analysis, and other models.
 # If needed, prepare suitable images and refer to infer.cc to configure custom hardware and inference backends.
 # Deploy the text detection model standalone on CPU
@@ -85,6 +91,9 @@
 # Deploy the table recognition model standalone on CPU
 ./infer_structurev2_table ./ch_ppstructure_mobile_v2.0_SLANet_infer ./table_structure_dict_ch.txt ./table.jpg 0
+
+# Deploy the layout analysis model standalone on CPU
+./infer_structurev2_layout ./picodet_lcnet_x1_0_fgd_layout_infer ./layout.jpg 0
 ```
 
 The visualized results after running are shown in the images below
diff --git a/examples/vision/ocr/PP-OCR/cpu-gpu/cpp/infer_structurev2_layout.cc b/examples/vision/ocr/PP-OCR/cpu-gpu/cpp/infer_structurev2_layout.cc
new file mode 100644
index 000000000..25ed739dd
--- /dev/null
+++ b/examples/vision/ocr/PP-OCR/cpu-gpu/cpp/infer_structurev2_layout.cc
@@ -0,0 +1,87 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+#ifdef WIN32
+const char sep = '\\';
+#else
+const char sep = '/';
+#endif
+
+void InitAndInfer(const std::string &layout_model_dir,
+                  const std::string &image_file,
+                  const fastdeploy::RuntimeOption &option) {
+  auto layout_model_file = layout_model_dir + sep + "model.pdmodel";
+  auto layout_params_file = layout_model_dir + sep + "model.pdiparams";
+
+  auto layout_model = fastdeploy::vision::ocr::StructureV2Layout(
+      layout_model_file, layout_params_file, option);
+
+  if (!layout_model.Initialized()) {
+    std::cerr << "Failed to initialize." << std::endl;
+    return;
+  }
+
+  auto im = cv::imread(image_file);
+
+  // 5 for publaynet, 10 for cdla
+  layout_model.GetPostprocessor().SetNumClass(5);
+  fastdeploy::vision::DetectionResult res;
+  if (!layout_model.Predict(im, &res)) {
+    std::cerr << "Failed to predict." << std::endl;
+    return;
+  }
+
+  std::cout << res.Str() << std::endl;
+  std::vector<std::string> labels = {"text", "title", "list", "table",
+                                     "figure"};
+  if (layout_model.GetPostprocessor().GetNumClass() == 10) {
+    labels = {"text",  "title",         "figure", "figure_caption",
+              "table", "table_caption", "header", "footer",
+              "reference", "equation"};
+  }
+  auto vis_im = fastdeploy::vision::VisDetection(im, res, labels, 0.3, 2, .5f,
+                                                 {255, 0, 0}, 2);
+  cv::imwrite("vis_result.jpg", vis_im);
+  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
+}
+
+int main(int argc, char *argv[]) {
+  if (argc < 4) {
+    std::cout
+        << "Usage: infer_demo path/to/layout_model path/to/image "
+           "run_option, "
+           "e.g ./infer_structurev2_layout picodet_lcnet_x1_0_fgd_layout_infer "
+           "layout.png 0"
+        << std::endl;
+    std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
+                 "with gpu."
+              << std::endl;
+    return -1;
+  }
+
+  fastdeploy::RuntimeOption option;
+  int flag = std::atoi(argv[3]);
+
+  if (flag == 0) {
+    option.UseCpu();
+  } else if (flag == 1) {
+    option.UseGpu();
+  }
+
+  std::string layout_model_dir = argv[1];
+  std::string image_file = argv[2];
+  InitAndInfer(layout_model_dir, image_file, option);
+  return 0;
+}
diff --git a/examples/vision/ocr/PP-OCR/cpu-gpu/python/README.md b/examples/vision/ocr/PP-OCR/cpu-gpu/python/README.md
index 60e8dd0c7..29e396283 100644
--- a/examples/vision/ocr/PP-OCR/cpu-gpu/python/README.md
+++ b/examples/vision/ocr/PP-OCR/cpu-gpu/python/README.md
@@ -39,12 +39,18 @@ tar -xvf ch_PP-OCRv3_rec_infer.tar
 # Download the PP-StructureV2 table recognition model
 wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar
 tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
+# Download the PP-StructureV2 layout analysis model
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout_infer.tar
+tar -xvf picodet_lcnet_x1_0_fgd_layout_infer.tar
 # Download the test images and dictionary files
 wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg
 wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppstructure/docs/table/table.jpg
+wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppstructure/docs/table/layout.jpg
 wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
 wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/table_structure_dict_ch.txt
+wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt
+wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/layout_dict/layout_cdla_dict.txt
 
 # Run the deployment examples
 # Inference with Paddle Inference on CPU
@@ -64,7 +70,7 @@
 # Inference with Nvidia TensorRT on GPU
 python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu --backend trt
 
-# FastDeploy also supports standalone inference for the three models: text detection, text classification, and text recognition.
+# FastDeploy also supports standalone inference for text detection, text classification, text recognition, table recognition, layout analysis, and other models.
 # If needed, prepare suitable images and refer to infer.py to configure custom hardware and inference backends.
 
 # Deploy the text detection model standalone on CPU
@@ -76,8 +82,11 @@
 # Deploy the text recognition model standalone on CPU
 python infer_rec.py --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device cpu
 
-# Deploy the text recognition model standalone on CPU
+# Deploy the table recognition model standalone on CPU
 python infer_structurev2_table.py --table_model ./ch_ppstructure_mobile_v2.0_SLANet_infer --table_char_dict_path ./table_structure_dict_ch.txt --image table.jpg --device cpu
+
+# Deploy the layout analysis model standalone on CPU
+python infer_structurev2_layout.py --layout_model ./picodet_lcnet_x1_0_fgd_layout_infer --image layout.jpg --device cpu
 ```
 
 The visualized results after running are shown in the images below
diff --git a/examples/vision/ocr/PP-OCR/cpu-gpu/python/infer_structurev2_layout.py b/examples/vision/ocr/PP-OCR/cpu-gpu/python/infer_structurev2_layout.py
new file mode 100644
index 000000000..a68969697
--- /dev/null
+++ b/examples/vision/ocr/PP-OCR/cpu-gpu/python/infer_structurev2_layout.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import fastdeploy as fd
+import cv2
+import os
+
+
+def parse_arguments():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--layout_model",
+        required=True,
+        help="Path of the layout detection model of PP-StructureV2.")
+    parser.add_argument(
+        "--image", type=str, required=True, help="Path of the test image file.")
+    parser.add_argument(
+        "--device",
+        type=str,
+        default='cpu',
+        help="Type of inference device, support 'cpu' or 'gpu'.")
+    parser.add_argument(
+        "--device_id",
+        type=int,
+        default=0,
+        help="Define which GPU card is used to run the model.")
+
+    return parser.parse_args()
+
+
+def build_option(args):
+
+    layout_option = fd.RuntimeOption()
+
+    if args.device.lower() == "gpu":
+        layout_option.use_gpu(args.device_id)
+
+    return layout_option
+
+
+args = parse_arguments()
+
+layout_model_file = os.path.join(args.layout_model, "model.pdmodel")
+layout_params_file = os.path.join(args.layout_model, "model.pdiparams")
+
+# Set the runtime option
+layout_option = build_option(args)
+
+# Create the layout model
+layout_model = fd.vision.ocr.StructureV2Layout(
+    layout_model_file, layout_params_file, layout_option)
+
+# 5 for publaynet, 10 for cdla
+layout_model.postprocessor.num_class = 5
+
+# Read the image
+im = cv2.imread(args.image)
+
+# Predict and return the results
+result = layout_model.predict(im)
+
+print(result)
+
+# Visualize the results
+labels = ["text", "title", "list", "table", "figure"]
+if layout_model.postprocessor.num_class == 10:
+    labels = [
+        "text", "title", "figure", "figure_caption", "table", "table_caption",
+        "header", "footer", "reference", "equation"
+    ]
+
+vis_im = fd.vision.vis_detection(
+    im,
+    result,
+    labels,
+    score_threshold=0.5,
+    font_color=[255, 0, 0],
+    font_thickness=2)
+cv2.imwrite("visualized_result.jpg", vis_im)
+print("Visualized result saved in ./visualized_result.jpg")
diff --git a/examples/vision/ocr/PP-OCR/cpu-gpu/python/infer_structurev2_table.py b/examples/vision/ocr/PP-OCR/cpu-gpu/python/infer_structurev2_table.py
index 45344d503..21650265f 100755
--- a/examples/vision/ocr/PP-OCR/cpu-gpu/python/infer_structurev2_table.py
+++ b/examples/vision/ocr/PP-OCR/cpu-gpu/python/infer_structurev2_table.py
@@ -23,7 +23,7 @@ def parse_arguments():
     parser.add_argument(
         "--table_model",
         required=True,
-        help="Path of Table recognition model of PPOCR.")
+        help="Path of Table recognition model of PP-StructureV2.")
     parser.add_argument(
         "--table_char_dict_path",
         type=str,
diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h
index 3800b56f2..0e8f7a9f6 100755
--- a/fastdeploy/vision.h
+++ b/fastdeploy/vision.h
@@ -54,9 +54,11 @@
 #include "fastdeploy/vision/ocr/ppocr/classifier.h"
 #include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
 #include "fastdeploy/vision/ocr/ppocr/structurev2_table.h"
+#include "fastdeploy/vision/ocr/ppocr/structurev2_layout.h"
 #include "fastdeploy/vision/ocr/ppocr/ppocr_v2.h"
 #include "fastdeploy/vision/ocr/ppocr/ppocr_v3.h"
 #include "fastdeploy/vision/ocr/ppocr/ppstructurev2_table.h"
+#include "fastdeploy/vision/ocr/ppocr/ppstructurev2_layout.h"
 #include "fastdeploy/vision/ocr/ppocr/recognizer.h"
 #include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
"fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h" #include "fastdeploy/vision/segmentation/ppseg/model.h" diff --git a/fastdeploy/vision/common/processors/proc_lib.cc b/fastdeploy/vision/common/processors/proc_lib.cc index a09ed0ab5..6d518657a 100644 --- a/fastdeploy/vision/common/processors/proc_lib.cc +++ b/fastdeploy/vision/common/processors/proc_lib.cc @@ -33,6 +33,9 @@ std::ostream& operator<<(std::ostream& out, const ProcLib& p) { case ProcLib::CUDA: out << "ProcLib::CUDA"; break; + case ProcLib::CVCUDA: + out << "ProcLib::CVCUDA"; + break; default: FDASSERT(false, "Unknow type of ProcLib."); } diff --git a/fastdeploy/vision/common/result.cc b/fastdeploy/vision/common/result.cc index a52cc95f6..05bc75ba5 100644 --- a/fastdeploy/vision/common/result.cc +++ b/fastdeploy/vision/common/result.cc @@ -153,10 +153,16 @@ std::string DetectionResult::Str() { std::string out; if (!contain_masks) { out = "DetectionResult: [xmin, ymin, xmax, ymax, score, label_id]\n"; + if (!rotated_boxes.empty()) { + out = "DetectionResult: [x1, y1, x2, y2, x3, y3, x4, y4, score, label_id]\n"; + } } else { out = "DetectionResult: [xmin, ymin, xmax, ymax, score, label_id, " "mask_shape]\n"; + if (!rotated_boxes.empty()) { + out = "DetectionResult: [x1, y1, x2, y2, x3, y3, x4, y4, score, label_id, mask_shape]\n"; + } } for (size_t i = 0; i < boxes.size(); ++i) { out = out + std::to_string(boxes[i][0]) + "," + diff --git a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc index 019a11f91..243d93e26 100644 --- a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc +++ b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc @@ -304,6 +304,7 @@ void BindPPOCRModel(pybind11::module& m) { &vision::ocr::Recognizer::GetPreprocessor) .def_property_readonly("postprocessor", &vision::ocr::Recognizer::GetPostprocessor) + .def("clone", [](vision::ocr::Recognizer& self) { return self.Clone(); }) .def("predict", [](vision::ocr::Recognizer& self, pybind11::array& data) { auto mat = PyArrayToCvMat(data); @@ -360,7 +361,7 @@ void BindPPOCRModel(pybind11::module& m) { if (!self.Run(inputs, &boxes, &structure_list, batch_det_img_info)) { throw std::runtime_error( - "Failed to preprocess the input data in " + "Failed to postprocess the input data in " "StructureV2TablePostprocessor."); } return std::make_pair(boxes, structure_list); @@ -377,7 +378,7 @@ void BindPPOCRModel(pybind11::module& m) { if (!self.Run(inputs, &boxes, &structure_list, batch_det_img_info)) { throw std::runtime_error( - "Failed to preprocess the input data in " + "Failed to postprocess the input data in " "StructureV2TablePostprocessor."); } return std::make_pair(boxes, structure_list); @@ -392,6 +393,8 @@ void BindPPOCRModel(pybind11::module& m) { &vision::ocr::StructureV2Table::GetPreprocessor) .def_property_readonly("postprocessor", &vision::ocr::StructureV2Table::GetPostprocessor) + .def("clone", + [](vision::ocr::StructureV2Table& self) { return self.Clone(); }) .def("predict", [](vision::ocr::StructureV2Table& self, pybind11::array& data) { auto mat = PyArrayToCvMat(data); @@ -410,5 +413,114 @@ void BindPPOCRModel(pybind11::module& m) { self.BatchPredict(images, &ocr_results); return ocr_results; }); + + // Layout + pybind11::class_(m, "StructureV2LayoutPreprocessor") + .def(pybind11::init<>()) + .def_property( + "static_shape_infer", + &vision::ocr::StructureV2LayoutPreprocessor::GetStaticShapeInfer, + &vision::ocr::StructureV2LayoutPreprocessor::SetStaticShapeInfer) + .def_property( + "layout_image_shape", + 
+          &vision::ocr::StructureV2LayoutPreprocessor::GetLayoutImageShape,
+          &vision::ocr::StructureV2LayoutPreprocessor::SetLayoutImageShape)
+      .def("set_normalize",
+           [](vision::ocr::StructureV2LayoutPreprocessor& self,
+              const std::vector<float>& mean, const std::vector<float>& std,
+              bool is_scale) { self.SetNormalize(mean, std, is_scale); })
+      .def("run",
+           [](vision::ocr::StructureV2LayoutPreprocessor& self,
+              std::vector<pybind11::array>& im_list) {
+             std::vector<vision::FDMat> images;
+             for (size_t i = 0; i < im_list.size(); ++i) {
+               images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
+             }
+             std::vector<FDTensor> outputs;
+             if (!self.Run(&images, &outputs)) {
+               throw std::runtime_error(
+                   "Failed to preprocess the input data in "
+                   "StructureV2LayoutPreprocessor.");
+             }
+
+             auto batch_layout_img_info = self.GetBatchLayoutImgInfo();
+             for (size_t i = 0; i < outputs.size(); ++i) {
+               outputs[i].StopSharing();
+             }
+
+             return std::make_pair(outputs, *batch_layout_img_info);
+           })
+      .def("disable_normalize",
+           [](vision::ocr::StructureV2LayoutPreprocessor& self) {
+             self.DisableNormalize();
+           })
+      .def("disable_permute",
+           [](vision::ocr::StructureV2LayoutPreprocessor& self) {
+             self.DisablePermute();
+           });
+
+  pybind11::class_<vision::ocr::StructureV2LayoutPostprocessor>(
+      m, "StructureV2LayoutPostprocessor")
+      .def(pybind11::init<>())
+      .def_property(
+          "score_threshold",
+          &vision::ocr::StructureV2LayoutPostprocessor::GetScoreThreshold,
+          &vision::ocr::StructureV2LayoutPostprocessor::SetScoreThreshold)
+      .def_property(
+          "nms_threshold",
+          &vision::ocr::StructureV2LayoutPostprocessor::GetNMSThreshold,
+          &vision::ocr::StructureV2LayoutPostprocessor::SetNMSThreshold)
+      .def_property("num_class",
+                    &vision::ocr::StructureV2LayoutPostprocessor::GetNumClass,
+                    &vision::ocr::StructureV2LayoutPostprocessor::SetNumClass)
+      .def_property("fpn_stride",
+                    &vision::ocr::StructureV2LayoutPostprocessor::GetFPNStride,
+                    &vision::ocr::StructureV2LayoutPostprocessor::SetFPNStride)
+      .def_property("reg_max",
+                    &vision::ocr::StructureV2LayoutPostprocessor::GetRegMax,
+                    &vision::ocr::StructureV2LayoutPostprocessor::SetRegMax)
+      .def("run",
+           [](vision::ocr::StructureV2LayoutPostprocessor& self,
+              std::vector<FDTensor>& inputs,
+              const std::vector<std::array<int, 4>>& batch_layout_img_info) {
+             std::vector<vision::DetectionResult> results;
+
+             if (!self.Run(inputs, &results, batch_layout_img_info)) {
+               throw std::runtime_error(
+                   "Failed to postprocess the input data in "
+                   "StructureV2LayoutPostprocessor.");
+             }
+             return results;
+           });
+
+  pybind11::class_<vision::ocr::StructureV2Layout, FastDeployModel>(
+      m, "StructureV2Layout")
+      .def(pybind11::init<std::string, std::string, RuntimeOption,
+                          ModelFormat>())
+      .def(pybind11::init<>())
+      .def_property_readonly("preprocessor",
+                             &vision::ocr::StructureV2Layout::GetPreprocessor)
+      .def_property_readonly("postprocessor",
+                             &vision::ocr::StructureV2Layout::GetPostprocessor)
+      .def("clone",
+           [](vision::ocr::StructureV2Layout& self) { return self.Clone(); })
+      .def("predict",
+           [](vision::ocr::StructureV2Layout& self, pybind11::array& data) {
+             auto mat = PyArrayToCvMat(data);
+             vision::DetectionResult result;
+             self.Predict(mat, &result);
+             return result;
+           })
+      .def("batch_predict", [](vision::ocr::StructureV2Layout& self,
+                               std::vector<pybind11::array>& data) {
+        std::vector<cv::Mat> images;
+        for (size_t i = 0; i < data.size(); ++i) {
+          images.push_back(PyArrayToCvMat(data[i]));
+        }
+        std::vector<vision::DetectionResult> results;
+        self.BatchPredict(images, &results);
+        return results;
+      });
 }
 }  // namespace fastdeploy
diff --git a/fastdeploy/vision/ocr/ppocr/ppstructurev2_layout.h b/fastdeploy/vision/ocr/ppocr/ppstructurev2_layout.h
new file mode 100644
index 000000000..088aeea48
--- /dev/null
+++ b/fastdeploy/vision/ocr/ppocr/ppstructurev2_layout.h
@@ -0,0 +1,40 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+#include "fastdeploy/vision/ocr/ppocr/structurev2_layout.h"
+#include "fastdeploy/utils/unique_ptr.h"
+
+namespace fastdeploy {
+
+namespace pipeline {
+typedef fastdeploy::vision::ocr::StructureV2Layout PPStructureV2Layout;
+
+namespace application {
+namespace ocrsystem {
+
+// TODO(qiuyanjun): This pipeline may not be needed
+typedef pipeline::PPStructureV2Layout PPStructureV2LayoutSystem;
+}  // namespace ocrsystem
+}  // namespace application
+
+}  // namespace pipeline
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/ocr/ppocr/structurev2_layout.cc b/fastdeploy/vision/ocr/ppocr/structurev2_layout.cc
new file mode 100644
index 000000000..d4cb33013
--- /dev/null
+++ b/fastdeploy/vision/ocr/ppocr/structurev2_layout.cc
@@ -0,0 +1,102 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/ocr/ppocr/structurev2_layout.h"
+
+#include "fastdeploy/utils/perf.h"
+#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace ocr {
+
+StructureV2Layout::StructureV2Layout() {}
+StructureV2Layout::StructureV2Layout(const std::string& model_file,
+                                     const std::string& params_file,
+                                     const RuntimeOption& custom_option,
+                                     const ModelFormat& model_format) {
+  if (model_format == ModelFormat::ONNX) {
+    valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
+    valid_gpu_backends = {Backend::ORT, Backend::TRT};
+  } else {
+    valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO,
+                          Backend::LITE};
+    valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
+    valid_kunlunxin_backends = {Backend::LITE};
+    valid_ascend_backends = {Backend::LITE};
+    valid_sophgonpu_backends = {Backend::SOPHGOTPU};
+    valid_rknpu_backends = {Backend::RKNPU2};
+  }
+
+  runtime_option = custom_option;
+  runtime_option.model_format = model_format;
+  runtime_option.model_file = model_file;
+  runtime_option.params_file = params_file;
+  initialized = Initialize();
+}
+
+bool StructureV2Layout::Initialize() {
+  if (!InitRuntime()) {
+    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
+    return false;
+  }
+  return true;
+}
+
+std::unique_ptr<StructureV2Layout> StructureV2Layout::Clone() const {
+  std::unique_ptr<StructureV2Layout> clone_model =
+      utils::make_unique<StructureV2Layout>(StructureV2Layout(*this));
+  clone_model->SetRuntime(clone_model->CloneRuntime());
+  return clone_model;
+}
+
+bool StructureV2Layout::Predict(cv::Mat* im, DetectionResult* result) {
+  return Predict(*im, result);
+}
+
+bool StructureV2Layout::Predict(const cv::Mat& im, DetectionResult* result) {
+  std::vector<DetectionResult> results;
+  if (!BatchPredict({im}, &results)) {
+    return false;
+  }
+  *result = std::move(results[0]);
+  return true;
+}
+
+bool StructureV2Layout::BatchPredict(const std::vector<cv::Mat>& images,
+                                     std::vector<DetectionResult>* results) {
+  std::vector<FDMat> fd_images = WrapMat(images);
+  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
+    FDERROR << "Failed to preprocess input image." << std::endl;
+    return false;
+  }
+  auto batch_layout_img_info = preprocessor_.GetBatchLayoutImgInfo();
+
+  reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
+  if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
+    FDERROR << "Failed to inference by runtime." << std::endl;
+    return false;
+  }
+
+  if (!postprocessor_.Run(reused_output_tensors_, results,
+                          *batch_layout_img_info)) {
+    FDERROR << "Failed to postprocess the inference results." << std::endl;
+    return false;
+  }
+  return true;
+}
+
+}  // namespace ocr
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/ocr/ppocr/structurev2_layout.h b/fastdeploy/vision/ocr/ppocr/structurev2_layout.h
new file mode 100644
index 000000000..79b0ea60b
--- /dev/null
+++ b/fastdeploy/vision/ocr/ppocr/structurev2_layout.h
@@ -0,0 +1,94 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
+#include "fastdeploy/vision/ocr/ppocr/structurev2_layout_preprocessor.h"
+#include "fastdeploy/vision/ocr/ppocr/structurev2_layout_postprocessor.h"
+#include "fastdeploy/utils/unique_ptr.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace ocr {
+/*! @brief StructureV2Layout object is used to load the PP-StructureV2 layout detection model.
+ */
+class FASTDEPLOY_DECL StructureV2Layout : public FastDeployModel {
+ public:
+  StructureV2Layout();
+  /** \brief Set path of model file, and the configuration of runtime
+   *
+   * \param[in] model_file Path of model file, e.g ./picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdmodel.
+   * \param[in] params_file Path of parameter file, e.g ./picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdiparams, if the model format is ONNX, this parameter will be ignored.
+   * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`.
+   * \param[in] model_format Model format of the loaded model, default is Paddle format.
+   */
+  StructureV2Layout(const std::string& model_file,
+                    const std::string& params_file = "",
+                    const RuntimeOption& custom_option = RuntimeOption(),
+                    const ModelFormat& model_format = ModelFormat::PADDLE);
+
+  /** \brief Clone a new StructureV2Layout with less memory usage when multiple instances of the same model are created
+   *
+   * \return new StructureV2Layout* type unique pointer
+   */
+  virtual std::unique_ptr<StructureV2Layout> Clone() const;
+
+  /// Get model's name
+  std::string ModelName() const { return "pp-structurev2-layout"; }
+
+  /** \brief DEPRECATED Predict the detection result for an input image
+   *
+   * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
+   * \param[in] result The output detection result
+   * \return true if the prediction succeeded, otherwise false
+   */
+  virtual bool Predict(cv::Mat* im, DetectionResult* result);
+
+  /** \brief Predict the detection result for an input image
+   * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
+   * \param[in] result The output detection result
+   * \return true if the prediction succeeded, otherwise false
+   */
+  virtual bool Predict(const cv::Mat& im, DetectionResult* result);
+
+  /** \brief Predict the detection result for an input image list
+   * \param[in] imgs The input image list, all the elements come from cv::imread(), each is a 3-D array with layout HWC, BGR format
+   * \param[in] results The output detection result list
+   * \return true if the prediction succeeded, otherwise false
+   */
+  virtual bool BatchPredict(const std::vector<cv::Mat>& imgs,
+                            std::vector<DetectionResult>* results);
+
+  /// Get preprocessor reference of StructureV2LayoutPreprocessor
+  virtual StructureV2LayoutPreprocessor& GetPreprocessor() {
+    return preprocessor_;
+  }
+
+  /// Get postprocessor reference of StructureV2LayoutPostprocessor
+  virtual StructureV2LayoutPostprocessor& GetPostprocessor() {
+    return postprocessor_;
+  }
+
+ private:
+  bool Initialize();
+  StructureV2LayoutPreprocessor preprocessor_;
+  StructureV2LayoutPostprocessor postprocessor_;
+};
+
+}  // namespace ocr
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/ocr/ppocr/structurev2_layout_postprocessor.cc b/fastdeploy/vision/ocr/ppocr/structurev2_layout_postprocessor.cc
new file mode 100644
index 000000000..ad40c8f45
--- /dev/null
+++ b/fastdeploy/vision/ocr/ppocr/structurev2_layout_postprocessor.cc
@@ -0,0 +1,171 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
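+//
+// Note (hedged): the layout model is a PicoDet-style detector, so the decode
+// below follows the GFL recipe. For each FPN level, every cell predicts class
+// scores plus four discrete distance distributions (reg_max bins per side);
+// DisPred2Bbox softmaxes each distribution, takes its expectation, scales it
+// by the stride, and clips the resulting box to the resized input.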
+
+#include "fastdeploy/vision/ocr/ppocr/structurev2_layout_postprocessor.h"
+#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
+#include "fastdeploy/vision/utils/utils.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace ocr {
+
+bool StructureV2LayoutPostprocessor::Run(
+    const std::vector<FDTensor>& tensors,
+    std::vector<DetectionResult>* results,
+    const std::vector<std::array<int, 4>>& batch_layout_img_info) {
+  // A StructureV2Layout model has 8 output tensors on which it then runs
+  // a GFL regression (namely, DisPred2Bbox), reference:
+  // PaddleOCR/blob/release/2.6/deploy/cpp_infer/src/postprocess_op.cpp#L511
+  int tensor_size = tensors.size();
+  FDASSERT(tensor_size == 8,
+           "StructureV2Layout should have 8 output tensors,"
+           "but got %d now!",
+           tensor_size)
+  FDASSERT((tensor_size / 2) == fpn_stride_.size(),
+           "found (tensor_size / 2) != fpn_stride_.size() !")
+  // TODO(qiuyanjun): may need to reorder the tensors according to
+  // fpn_stride_ and the shape of output tensors.
+  size_t batch = tensors[0].Shape()[0];  // [batch, ...]
+
+  results->resize(batch);
+  SetRegMax(tensors[fpn_stride_.size()].Shape()[2] / 4);
+  for (int batch_idx = 0; batch_idx < batch; ++batch_idx) {
+    std::vector<FDTensor> single_batch_tensors(8);
+    SetSingleBatchExternalData(tensors, single_batch_tensors, batch_idx);
+    SingleBatchPostprocessor(single_batch_tensors,
+                             batch_layout_img_info[batch_idx],
+                             &results->at(batch_idx));
+  }
+  return true;
+}
+
+void StructureV2LayoutPostprocessor::SetSingleBatchExternalData(
+    const std::vector<FDTensor>& tensors,
+    std::vector<FDTensor>& single_batch_tensors, size_t batch_idx) {
+  single_batch_tensors.resize(tensors.size());
+  for (int j = 0; j < tensors.size(); ++j) {
+    auto j_shape = tensors[j].Shape();
+    j_shape[0] = 1;  // process b=1 per loop
+    size_t j_step =
+        accumulate(j_shape.begin(), j_shape.end(), 1, std::multiplies<int>());
+    const float* j_data_ptr =
+        reinterpret_cast<const float*>(tensors[j].Data());
+    const float* j_start_ptr = j_data_ptr + j_step * batch_idx;
+    single_batch_tensors[j].SetExternalData(
+        j_shape, tensors[j].Dtype(),
+        const_cast<void*>(reinterpret_cast<const void*>(j_start_ptr)),
+        tensors[j].device, tensors[j].device_id);
+  }
+}
+
+bool StructureV2LayoutPostprocessor::SingleBatchPostprocessor(
+    const std::vector<FDTensor>& single_batch_tensors,
+    const std::array<int, 4>& layout_img_info, DetectionResult* result) {
+  FDASSERT(single_batch_tensors.size() == 8,
+           "StructureV2Layout should have 8 output tensors,"
+           "but got %d now!",
+           static_cast<int>(single_batch_tensors.size()))
+  // layout_img_info: {image width, image height, resize width, resize height}
+  int img_w = layout_img_info[0];
+  int img_h = layout_img_info[1];
+  int in_w = layout_img_info[2];
+  int in_h = layout_img_info[3];
+  float scale_factor_w = static_cast<float>(in_w) / static_cast<float>(img_w);
+  float scale_factor_h = static_cast<float>(in_h) / static_cast<float>(img_h);
+
+  std::vector<DetectionResult> bbox_results;
+  bbox_results.resize(num_class_);  // tmp result for each class
+
+  // decode score, label, box
+  for (int i = 0; i < fpn_stride_.size(); ++i) {
+    int feature_h = std::ceil(static_cast<float>(in_h) / fpn_stride_[i]);
+    int feature_w = std::ceil(static_cast<float>(in_w) / fpn_stride_[i]);
+    const FDTensor& prob_tensor = single_batch_tensors[i];
+    const FDTensor& bbox_tensor = single_batch_tensors[i + fpn_stride_.size()];
+    const float* prob_data =
+        reinterpret_cast<const float*>(prob_tensor.Data());
+    const float* bbox_data =
+        reinterpret_cast<const float*>(bbox_tensor.Data());
+    for (int idx = 0; idx < feature_h * feature_w; ++idx) {
+      // score and label
+      float score = 0.f;
+      int label = 0;
+      for (int j = 0; j < num_class_; ++j) {
+        if (prob_data[idx * num_class_ + j] > score) {
+          score = prob_data[idx * num_class_ + j];
+          label = j;
+        }
+      }
+      // bbox
+      if (score > score_threshold_) {
+        int row = idx / feature_w;
+        int col = idx % feature_w;
+        std::vector<float> bbox_pred(bbox_data + idx * 4 * reg_max_,
+                                     bbox_data + (idx + 1) * 4 * reg_max_);
+        bbox_results[label].boxes.push_back(DisPred2Bbox(
+            bbox_pred, col, row, fpn_stride_[i], in_w, in_h, reg_max_));
+        bbox_results[label].scores.push_back(score);
+        bbox_results[label].label_ids.push_back(label);
+      }
+    }
+  }
+
+  result->Clear();
+  // nms per class, i in [0~num_class-1]
+  for (int i = 0; i < bbox_results.size(); ++i) {
+    if (bbox_results[i].boxes.size() <= 0) {
+      continue;
+    }
+    vision::utils::NMS(&bbox_results[i], nms_threshold_);
+    // fill output results
+    for (int j = 0; j < bbox_results[i].boxes.size(); ++j) {
+      result->scores.push_back(bbox_results[i].scores[j]);
+      result->label_ids.push_back(bbox_results[i].label_ids[j]);
+      result->boxes.push_back({
+          bbox_results[i].boxes[j][0] / scale_factor_w,
+          bbox_results[i].boxes[j][1] / scale_factor_h,
+          bbox_results[i].boxes[j][2] / scale_factor_w,
+          bbox_results[i].boxes[j][3] / scale_factor_h,
+      });
+    }
+  }
+  return true;
+}
+
+std::array<float, 4> StructureV2LayoutPostprocessor::DisPred2Bbox(
+    const std::vector<float>& bbox_pred, int x, int y, int stride,
+    int resize_w, int resize_h, int reg_max) {
+  float ct_x = (static_cast<float>(x) + 0.5f) * static_cast<float>(stride);
+  float ct_y = (static_cast<float>(y) + 0.5f) * static_cast<float>(stride);
+  std::vector<float> dis_pred;
+  dis_pred.resize(4);
+  for (int i = 0; i < 4; i++) {
+    std::vector<float> bbox_pred_i(bbox_pred.begin() + i * reg_max,
+                                   bbox_pred.begin() + (i + 1) * reg_max);
+    std::vector<float> dis_after_sm = ocr::Softmax(bbox_pred_i);
+    float dis = 0.0f;
+    for (int j = 0; j < reg_max; j++) {
+      dis += static_cast<float>(j) * dis_after_sm[j];
+    }
+    dis *= static_cast<float>(stride);
+    dis_pred[i] = dis;
+  }
+
+  float xmin = std::max(ct_x - dis_pred[0], 0.0f);
+  float ymin = std::max(ct_y - dis_pred[1], 0.0f);
+  float xmax = std::min(ct_x + dis_pred[2], static_cast<float>(resize_w));
+  float ymax = std::min(ct_y + dis_pred[3], static_cast<float>(resize_h));
+
+  return {xmin, ymin, xmax, ymax};
+}
+
+}  // namespace ocr
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/ocr/ppocr/structurev2_layout_postprocessor.h b/fastdeploy/vision/ocr/ppocr/structurev2_layout_postprocessor.h
new file mode 100644
index 000000000..c1ceef7ff
--- /dev/null
+++ b/fastdeploy/vision/ocr/ppocr/structurev2_layout_postprocessor.h
@@ -0,0 +1,80 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace ocr {
+
+/*! @brief Postprocessor object for the PP-StructureV2 layout detection model.
+ */
+class FASTDEPLOY_DECL StructureV2LayoutPostprocessor {
+ public:
+  StructureV2LayoutPostprocessor() {}
+  /** \brief Process the results of the runtime and fill them into a batch of DetectionResult
+   *
+   * \param[in] tensors The inference results from the runtime
+   * \param[in] results The output results of layout detection
+   * \param[in] batch_layout_img_info The image info of the input images,
+   *            {{image width, image height, resize width, resize height},...}
+   * \return true if the postprocess succeeded, otherwise false
+   */
+  bool Run(const std::vector<FDTensor>& tensors,
+           std::vector<DetectionResult>* results,
+           const std::vector<std::array<int, 4>>& batch_layout_img_info);
+
+  /// Set score_threshold_ for layout detection postprocess, default is 0.4
+  void SetScoreThreshold(float score_threshold) { score_threshold_ = score_threshold; }
+  /// Set nms_threshold_ for layout detection postprocess, default is 0.5
+  void SetNMSThreshold(float nms_threshold) { nms_threshold_ = nms_threshold; }
+  /// Set num_class_ for layout detection postprocess, default is 5
+  void SetNumClass(int num_class) { num_class_ = num_class; }
+  /// Set fpn_stride_ for layout detection postprocess, default is {8, 16, 32, 64}
+  void SetFPNStride(const std::vector<int>& fpn_stride) { fpn_stride_ = fpn_stride; }
+  /// Set reg_max_ for layout detection postprocess, default is 8
+  void SetRegMax(int reg_max) { reg_max_ = reg_max; }  // should be private?
+  /// Get score_threshold_ of layout detection postprocess, default is 0.4
+  float GetScoreThreshold() const { return score_threshold_; }
+  /// Get nms_threshold_ of layout detection postprocess, default is 0.5
+  float GetNMSThreshold() const { return nms_threshold_; }
+  /// Get num_class_ of layout detection postprocess, default is 5
+  int GetNumClass() const { return num_class_; }
+  /// Get fpn_stride_ of layout detection postprocess, default is {8, 16, 32, 64}
+  std::vector<int> GetFPNStride() const { return fpn_stride_; }
+  /// Get reg_max_ of layout detection postprocess, default is 8
+  int GetRegMax() const { return reg_max_; }
+
+ private:
+  std::array<float, 4> DisPred2Bbox(const std::vector<float>& bbox_pred,
+                                    int x, int y, int stride, int resize_w,
+                                    int resize_h, int reg_max);
+  bool SingleBatchPostprocessor(
+      const std::vector<FDTensor>& single_batch_tensors,
+      const std::array<int, 4>& layout_img_info, DetectionResult* result);
+  void SetSingleBatchExternalData(const std::vector<FDTensor>& tensors,
+                                  std::vector<FDTensor>& single_batch_tensors,
+                                  size_t batch_idx);
+
+  std::vector<int> fpn_stride_ = {8, 16, 32, 64};
+  float score_threshold_ = 0.4;
+  float nms_threshold_ = 0.5;
+  int num_class_ = 5;
+  int reg_max_ = 8;
+};
+
+}  // namespace ocr
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/ocr/ppocr/structurev2_layout_preprocessor.cc b/fastdeploy/vision/ocr/ppocr/structurev2_layout_preprocessor.cc
new file mode 100644
index 000000000..a6fbc3f7b
--- /dev/null
+++ b/fastdeploy/vision/ocr/ppocr/structurev2_layout_preprocessor.cc
@@ -0,0 +1,72 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "fastdeploy/vision/ocr/ppocr/structurev2_layout_preprocessor.h"
+#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace ocr {
+
+StructureV2LayoutPreprocessor::StructureV2LayoutPreprocessor() {
+  // default width (608) and height (800); layout_image_shape_ is {c, h, w}
+  resize_op_ = std::make_shared<Resize>(layout_image_shape_[2],
+                                        layout_image_shape_[1]);
+  normalize_permute_op_ = std::make_shared<NormalizeAndPermute>(
+      std::vector<float>({0.485f, 0.456f, 0.406f}),
+      std::vector<float>({0.229f, 0.224f, 0.225f}), true);
+}
+
+std::array<int, 4> StructureV2LayoutPreprocessor::GetLayoutImgInfo(
+    FDMat* img) {
+  if (static_shape_infer_) {
+    return {img->Width(), img->Height(), layout_image_shape_[2],
+            layout_image_shape_[1]};
+  } else {
+    FDASSERT(false, "dynamic shape inference is not supported yet!")
+  }
+  return {img->Width(), img->Height(), layout_image_shape_[2],
+          layout_image_shape_[1]};
+}
+
+bool StructureV2LayoutPreprocessor::ResizeLayoutImage(FDMat* img, int resize_w,
+                                                      int resize_h) {
+  resize_op_->SetWidthAndHeight(resize_w, resize_h);
+  (*resize_op_)(img);
+  return true;
+}
+
+bool StructureV2LayoutPreprocessor::Apply(FDMatBatch* image_batch,
+                                          std::vector<FDTensor>* outputs) {
+  batch_layout_img_info_.clear();
+  batch_layout_img_info_.resize(image_batch->mats->size());
+  for (size_t i = 0; i < image_batch->mats->size(); ++i) {
+    FDMat* mat = &(image_batch->mats->at(i));
+    batch_layout_img_info_[i] = GetLayoutImgInfo(mat);
+    ResizeLayoutImage(mat, batch_layout_img_info_[i][2],
+                      batch_layout_img_info_[i][3]);
+  }
+  if (!disable_normalize_ && !disable_permute_) {
+    (*normalize_permute_op_)(image_batch);
+  }
+
+  outputs->resize(1);
+  FDTensor* tensor = image_batch->Tensor();
+  (*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
+                                tensor->Data(), tensor->device,
+                                tensor->device_id);
+  return true;
+}
+
+}  // namespace ocr
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/ocr/ppocr/structurev2_layout_preprocessor.h b/fastdeploy/vision/ocr/ppocr/structurev2_layout_preprocessor.h
new file mode 100644
index 000000000..f15f9f2b8
--- /dev/null
+++ b/fastdeploy/vision/ocr/ppocr/structurev2_layout_preprocessor.h
@@ -0,0 +1,91 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include "fastdeploy/vision/common/processors/manager.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+namespace fastdeploy {
+namespace vision {
+
+namespace ocr {
+/*! @brief Preprocessor object for the PP-StructureV2 layout detection model.
+ */
+class FASTDEPLOY_DECL StructureV2LayoutPreprocessor : public ProcessorManager {
+ public:
+  StructureV2LayoutPreprocessor();
+
+  /** \brief Process the input image and prepare the input tensors for the runtime
+   *
+   * \param[in] image_batch The input image batch
+   * \param[in] outputs The output tensors which will be fed into the runtime
+   * \return true if the preprocess succeeded, otherwise false
+   */
+  virtual bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs);
+
+  /// Set preprocess normalize parameters, please call this API to customize
+  /// the normalize parameters, otherwise it will use the default normalize
+  /// parameters.
+  void SetNormalize(const std::vector<float>& mean,
+                    const std::vector<float>& std,
+                    bool is_scale) {
+    normalize_permute_op_ =
+        std::make_shared<NormalizeAndPermute>(mean, std, is_scale);
+  }
+
+  /// Get the image info of the last batch; returns a list of arrays
+  /// {image width, image height, resize width, resize height}
+  const std::vector<std::array<int, 4>>* GetBatchLayoutImgInfo() {
+    return &batch_layout_img_info_;
+  }
+
+  /// This function will disable normalize in preprocessing step.
+  void DisableNormalize() { disable_normalize_ = true; }
+  /// This function will disable hwc2chw in preprocessing step.
+  void DisablePermute() { disable_permute_ = true; }
+  /// Set image_shape for the layout detection preprocess.
+  /// This api is usually used when you retrain the model.
+  /// Generally, you do not need to use it.
+  void SetLayoutImageShape(const std::vector<int>& image_shape) {
+    layout_image_shape_ = image_shape;
+  }
+  /// Get layout_image_shape of the layout detection preprocess
+  std::vector<int> GetLayoutImageShape() const { return layout_image_shape_; }
+  /// Set static_shape_infer to true or false. When deploying PP-StructureV2
+  /// on hardware which can not support dynamic input shapes very well,
+  /// like Huawei Ascend, static_shape_infer needs to be true.
+  void SetStaticShapeInfer(bool static_shape_infer) {
+    static_shape_infer_ = static_shape_infer;
+  }
+  /// Get static_shape_infer of the layout detection preprocess
+  bool GetStaticShapeInfer() const { return static_shape_infer_; }
+
+ private:
+  bool ResizeLayoutImage(FDMat* img, int resize_w, int resize_h);
+  // for recording the switch of hwc2chw
+  bool disable_permute_ = false;
+  // for recording the switch of normalize
+  bool disable_normalize_ = false;
+  std::vector<std::array<int, 4>> batch_layout_img_info_;
+  std::shared_ptr<Resize> resize_op_;
+  std::shared_ptr<NormalizeAndPermute> normalize_permute_op_;
+  std::vector<int> layout_image_shape_ = {3, 800, 608};  // c,h,w
+  // default true for the pp-structurev2-layout model (PicoDet backbone).
+  bool static_shape_infer_ = true;
+  std::array<int, 4> GetLayoutImgInfo(FDMat* img);
+};
+
+}  // namespace ocr
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h b/fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h
index 07ff854a3..fd6b277d5 100755
--- a/fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h
+++ b/fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h
@@ -34,6 +34,8 @@ FASTDEPLOY_DECL void SortBoxes(std::vector<std::array<int, 8>>* boxes);
 
 FASTDEPLOY_DECL std::vector<int> ArgSort(const std::vector<float> &array);
 
+FASTDEPLOY_DECL std::vector<float> Softmax(std::vector<float> &src);
+
 FASTDEPLOY_DECL std::vector<int> Xyxyxyxy2Xyxy(std::array<int, 8> &box);
 
 FASTDEPLOY_DECL float Dis(std::vector<int> &box1, std::vector<int> &box2);
@@ -42,7 +44,6 @@ FASTDEPLOY_DECL float Iou(std::vector<int> &box1, std::vector<int> &box2);
 
 FASTDEPLOY_DECL bool ComparisonDis(const std::vector<float> &dis1,
                                    const std::vector<float> &dis2);
-
 }  // namespace ocr
 }  // namespace vision
 }  // namespace fastdeploy
diff --git a/fastdeploy/vision/ocr/ppocr/utils/softmax.cc b/fastdeploy/vision/ocr/ppocr/utils/softmax.cc
new file mode 100644
index 000000000..6fcf7013f
--- /dev/null
+++ b/fastdeploy/vision/ocr/ppocr/utils/softmax.cc
@@ -0,0 +1,48 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace ocr {
+
+static inline float FastExp(float x) {
+  union { uint32_t i; float f; } v{};
+  v.i = (1 << 23) * (1.4426950409 * x + 126.93490512f);
+  return v.f;
+}
+
+std::vector<float> Softmax(std::vector<float> &src) {
+  int length = src.size();
+  std::vector<float> dst;
+  dst.resize(length);
+  const float alpha = static_cast<float>(
+      *std::max_element(&src[0], &src[0 + length]));
+  float denominator{0};
+
+  for (int i = 0; i < length; ++i) {
+    dst[i] = FastExp(src[i] - alpha);
+    denominator += dst[i];
+  }
+
+  for (int i = 0; i < length; ++i) {
+    dst[i] /= denominator;
+  }
+  return dst;
+}
+
+}  // namespace ocr
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/visualize/detection.cc b/fastdeploy/vision/visualize/detection.cc
index feb2d2bad..b0a9f0525 100644
--- a/fastdeploy/vision/visualize/detection.cc
+++ b/fastdeploy/vision/visualize/detection.cc
@@ -147,7 +147,8 @@
 // Visualize DetectionResult with custom labels.
 cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
                      const std::vector<std::string>& labels,
-                     float score_threshold, int line_size, float font_size) {
+                     float score_threshold, int line_size, float font_size,
+                     std::vector<int> font_color, int font_thickness) {
   if (result.boxes.empty()) {
     return im;
   }
@@ -164,6 +165,7 @@
   int h = im.rows;
   int w = im.cols;
   auto vis_im = im.clone();
+  auto font_color_ = cv::Scalar(font_color[0], font_color[1], font_color[2]);
   for (size_t i = 0; i < result.rotated_boxes.size(); ++i) {
     if (result.scores[i] < score_threshold) {
       continue;
@@ -195,8 +197,8 @@
     } else {
       end = cv::Point(static_cast<int>(round(result.rotated_boxes[i][0])),
                       static_cast<int>(round(result.rotated_boxes[i][1])));
-      cv::putText(vis_im, text, end, font, font_size,
-                  cv::Scalar(255, 255, 255), 1);
+      cv::putText(vis_im, text, end, font, font_size, font_color_,
+                  font_thickness);
     }
     cv::line(vis_im, start, end, cv::Scalar(255, 255, 255), 3, cv::LINE_AA,
              0);
@@ -239,8 +241,8 @@
     origin.y = y1;
     cv::Rect rect(x1, y1, box_w, box_h);
     cv::rectangle(vis_im, rect, rect_color, line_size);
-    cv::putText(vis_im, text, origin, font, font_size,
-                cv::Scalar(255, 255, 255), 1);
+    cv::putText(vis_im, text, origin, font, font_size, font_color_,
+                font_thickness);
     if (result.contain_masks) {
       int mask_h = static_cast<int>(result.masks[i].shape[0]);
      int mask_w = static_cast<int>(result.masks[i].shape[1]);
diff --git a/fastdeploy/vision/visualize/visualize.h b/fastdeploy/vision/visualize/visualize.h
index c2d168b6e..f5ceb8558 100755
--- a/fastdeploy/vision/visualize/visualize.h
+++ b/fastdeploy/vision/visualize/visualize.h
@@ -81,13 +81,17 @@
  * \param[in] score_threshold threshold for result scores, the bounding box will not be shown if the score is less than score_threshold
  * \param[in] line_size line size for bounding boxes
  * \param[in] font_size font size for text
+ * \param[in] font_color font color for bounding box text
+ * \param[in] font_thickness font thickness for text
  * \return cv::Mat type stores the visualized results
  */
 FASTDEPLOY_DECL cv::Mat VisDetection(const cv::Mat& im,
                                      const DetectionResult& result,
                                      const std::vector<std::string>& labels,
                                      float score_threshold = 0.0,
-                                     int line_size = 1, float font_size = 0.5f);
+                                     int line_size = 1, float font_size = 0.5f,
+                                     std::vector<int> font_color = {255, 255, 255},
+                                     int font_thickness = 1);
 /** \brief Show the visualized results with custom labels for detection models
  *
  * \param[in] im the input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
diff --git a/fastdeploy/vision/visualize/visualize_pybind.cc b/fastdeploy/vision/visualize/visualize_pybind.cc
index f5d9799d0..1ecd517da 100644
--- a/fastdeploy/vision/visualize/visualize_pybind.cc
+++ b/fastdeploy/vision/visualize/visualize_pybind.cc
@@ -19,7 +19,8 @@ void BindVisualize(pybind11::module& m) {
   m.def("vis_detection",
         [](pybind11::array& im_data, vision::DetectionResult& result,
           std::vector<std::string>& labels, float score_threshold,
-           int line_size, float font_size) {
+           int line_size, float font_size, std::vector<int> font_color,
+           int font_thickness) {
          auto im = PyArrayToCvMat(im_data);
          cv::Mat vis_im;
          if (labels.empty()) {
            vis_im = vision::VisDetection(im, result, score_threshold,
                                          line_size, font_size);
          } else {
            vis_im = vision::VisDetection(im, result, labels, score_threshold,
diff --git a/python/fastdeploy/vision/ocr/ppocr/__init__.py b/python/fastdeploy/vision/ocr/ppocr/__init__.py
index 30dcf8a83..3cd1c62be 100755
--- a/python/fastdeploy/vision/ocr/ppocr/__init__.py
+++ b/python/fastdeploy/vision/ocr/ppocr/__init__.py
@@ -650,7 +650,7 @@ class Recognizer(FastDeployModel):
 
 class StructureV2TablePreprocessor:
     def __init__(self):
-        """Create a preprocessor for StructureV2TableModel
+        """Create a preprocessor for StructureV2Table Model
         """
         self._preprocessor = C.vision.ocr.StructureV2TablePreprocessor()
 
@@ -664,12 +664,12 @@ class StructureV2TablePreprocessor:
 
 class StructureV2TablePostprocessor:
     def __init__(self):
-        """Create a postprocessor for StructureV2TableModel
+        """Create a postprocessor for StructureV2Table Model
         """
         self._postprocessor = C.vision.ocr.StructureV2TablePostprocessor()
 
     def run(self, runtime_results):
-        """Postprocess the runtime results for StructureV2TableModel
+        """Postprocess the runtime results for StructureV2Table Model
         :param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime
         :return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
         """
@@ -683,10 +683,11 @@ class StructureV2Table(FastDeployModel):
                  table_char_dict_path="",
                  runtime_option=None,
                  model_format=ModelFormat.PADDLE):
-        """Load OCR StructureV2Table model provided by PaddleOCR.
+        """Load StructureV2Table model provided by PP-StructureV2.
         :param model_file: (str)Path of model file, e.g ./ch_ppocr_mobile_v2.0_cls_infer/model.pdmodel.
         :param params_file: (str)Path of parameter file, e.g ./ch_ppocr_mobile_v2.0_cls_infer/model.pdiparams, if the model format is ONNX, this parameter will be ignored.
+        :param table_char_dict_path: (str)Path of table_char_dict file, e.g ../ppocr/utils/dict/table_structure_dict_ch.txt
         :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU.
         :param model_format: (fastdeploy.ModelForamt)Model format of the loaded model.
         """
@@ -703,8 +704,8 @@ class StructureV2Table(FastDeployModel):
         self._runnable = True
 
     def clone(self):
-        """Clone OCR StructureV2Table model object
-        :return: a new OCR StructureV2Table model object
+        """Clone StructureV2Table model object
+        :return: a new StructureV2Table model object
         """
 
         class StructureV2TableClone(StructureV2Table):
@@ -749,6 +750,105 @@ class StructureV2Table(FastDeployModel):
         self._model.postprocessor = value
 
 
""" @@ -703,8 +704,8 @@ class StructureV2Table(FastDeployModel): self._runnable = True def clone(self): - """Clone OCR StructureV2Table model object - :return: a new OCR StructureV2Table model object + """Clone StructureV2Table model object + :return: a new StructureV2Table model object """ class StructureV2TableClone(StructureV2Table): @@ -749,6 +750,105 @@ class StructureV2Table(FastDeployModel): self._model.postprocessor = value +class StructureV2LayoutPreprocessor: + def __init__(self): + """Create a preprocessor for StructureV2Layout Model + """ + self._preprocessor = C.vision.ocr.StructureV2LayoutPreprocessor() + + def run(self, input_ims): + """Preprocess input images for StructureV2Layout Model + :param: input_ims: (list of numpy.ndarray)The input image + :return: list of FDTensor + """ + return self._preprocessor.run(input_ims) + + +class StructureV2LayoutPostprocessor: + def __init__(self): + """Create a postprocessor for StructureV2Layout Model + """ + self._postprocessor = C.vision.ocr.StructureV2LayoutPostprocessor() + + def run(self, runtime_results): + """Postprocess the runtime results for StructureV2Layout Model + :param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime + :return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size) + """ + return self._postprocessor.run(runtime_results) + + +class StructureV2Layout(FastDeployModel): + def __init__(self, + model_file="", + params_file="", + runtime_option=None, + model_format=ModelFormat.PADDLE): + """Load StructureV2Layout model provided by PP-StructureV2. + + :param model_file: (str)Path of model file, e.g ./picodet_lcnet_x1_0_fgd_layout_infer/model.pdmodel. + :param params_file: (str)Path of parameter file, e.g ./picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams, if the model format is ONNX, this parameter will be ignored. + :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU. + :param model_format: (fastdeploy.ModelForamt)Model format of the loaded model. + """ + super(StructureV2Layout, self).__init__(runtime_option) + + if (len(model_file) == 0): + self._model = C.vision.ocr.StructureV2Layout() + self._runnable = False + else: + self._model = C.vision.ocr.StructureV2Layout( + model_file, params_file, self._runtime_option, model_format) + assert self.initialized, "StructureV2Layout model initialize failed." 
diff --git a/python/fastdeploy/vision/visualize/__init__.py b/python/fastdeploy/vision/visualize/__init__.py
index 930f49376..df74091a2 100755
--- a/python/fastdeploy/vision/visualize/__init__.py
+++ b/python/fastdeploy/vision/visualize/__init__.py
@@ -23,7 +23,9 @@ def vis_detection(im_data,
                   labels=[],
                   score_threshold=0.0,
                   line_size=1,
-                  font_size=0.5):
+                  font_size=0.5,
+                  font_color=[255, 255, 255],
+                  font_thickness=1):
     """Show the visualized results for detection models
 
     :param im_data: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
@@ -32,10 +34,13 @@ def vis_detection(im_data,
     :param score_threshold: (float) score_threshold threshold for result scores, the bounding box will not be shown if the score is less than score_threshold
     :param line_size: (float) line_size line size for bounding boxes
     :param font_size: (float) font_size font size for text
+    :param font_color: (list of int) font color for text
+    :param font_thickness: (int) font thickness for text
     :return: (numpy.ndarray) image with visualized results
     """
     return C.vision.vis_detection(im_data, det_result, labels, score_threshold,
-                                  line_size, font_size)
+                                  line_size, font_size, font_color,
+                                  font_thickness)
 
 
 def vis_perception(im_data,
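Putting the two halves of this change together, a sketch of visualizing layout predictions with the new font controls. The five-entry label list assumes a PubLayNet-trained model (matching the SetNumClass(5) default in the C++ benchmark; a CDLA-trained model has 10 classes), so verify the label set and order against the model you actually deploy:

import cv2
import fastdeploy as fd

# Assumed PubLayNet label order; placeholder paths as before.
labels = ["text", "title", "list", "table", "figure"]

model = fd.vision.ocr.StructureV2Layout(
    "picodet_lcnet_x1_0_fgd_layout_infer/model.pdmodel",
    "picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams")
im = cv2.imread("doc_page.jpg")
result = model.predict(im)

# Red label text (BGR order), thicker than the old hard-coded default.
vis = fd.vision.vis_detection(im, result, labels=labels,
                              score_threshold=0.5,
                              font_color=[0, 0, 255], font_thickness=2)
cv2.imwrite("layout_vis.jpg", vis)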