[Model] Support PP-StructureV2-Layout model (#1867)

* [Model] init pp-structurev2-layout code

* [Model] init pp-structurev2-layout code

* [Model] init pp-structurev2-layout code

* [Model] add structurev2_layout_preprocessor

* [PP-StructureV2] add postprocessor and layout detector class

* [PP-StructureV2] add postprocessor and layout detector class

* [PP-StructureV2] add postprocessor and layout detector class

* [PP-StructureV2] add postprocessor and layout detector class

* [PP-StructureV2] add postprocessor and layout detector class

* [pybind] add pp-structurev2-layout model pybind

* [pybind] add pp-structurev2-layout model pybind

* [Bug Fix] fixed code style

* [examples] add pp-structurev2-layout c++ examples

* [PP-StructureV2] add python example and docs

* [benchmark] add pp-structurev2-layout benchmark support
This commit is contained in:
DefTruth
2023-05-05 13:05:58 +08:00
committed by GitHub
parent 2c5fd91a7f
commit 6d0261e9e4
26 changed files with 1255 additions and 23 deletions

View File

@@ -23,6 +23,7 @@ add_executable(benchmark_ppocr_det ${PROJECT_SOURCE_DIR}/benchmark_ppocr_det.cc)
add_executable(benchmark_ppocr_cls ${PROJECT_SOURCE_DIR}/benchmark_ppocr_cls.cc) add_executable(benchmark_ppocr_cls ${PROJECT_SOURCE_DIR}/benchmark_ppocr_cls.cc)
add_executable(benchmark_ppocr_rec ${PROJECT_SOURCE_DIR}/benchmark_ppocr_rec.cc) add_executable(benchmark_ppocr_rec ${PROJECT_SOURCE_DIR}/benchmark_ppocr_rec.cc)
add_executable(benchmark_structurev2_table ${PROJECT_SOURCE_DIR}/benchmark_structurev2_table.cc) add_executable(benchmark_structurev2_table ${PROJECT_SOURCE_DIR}/benchmark_structurev2_table.cc)
add_executable(benchmark_structurev2_layout ${PROJECT_SOURCE_DIR}/benchmark_structurev2_layout.cc)
add_executable(benchmark_ppyoloe_r ${PROJECT_SOURCE_DIR}/benchmark_ppyoloe_r.cc) add_executable(benchmark_ppyoloe_r ${PROJECT_SOURCE_DIR}/benchmark_ppyoloe_r.cc)
add_executable(benchmark_ppyoloe_r_sophgo ${PROJECT_SOURCE_DIR}/benchmark_ppyoloe_r_sophgo.cc) add_executable(benchmark_ppyoloe_r_sophgo ${PROJECT_SOURCE_DIR}/benchmark_ppyoloe_r_sophgo.cc)
add_executable(benchmark_ppyolo ${PROJECT_SOURCE_DIR}/benchmark_ppyolo.cc) add_executable(benchmark_ppyolo ${PROJECT_SOURCE_DIR}/benchmark_ppyolo.cc)
@@ -57,6 +58,7 @@ if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
target_link_libraries(benchmark_ppocr_cls ${FASTDEPLOY_LIBS} gflags pthread) target_link_libraries(benchmark_ppocr_cls ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_ppocr_rec ${FASTDEPLOY_LIBS} gflags pthread) target_link_libraries(benchmark_ppocr_rec ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_structurev2_table ${FASTDEPLOY_LIBS} gflags pthread) target_link_libraries(benchmark_structurev2_table ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_ppyolo ${FASTDEPLOY_LIBS} gflags pthread) target_link_libraries(benchmark_ppyolo ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_yolov3 ${FASTDEPLOY_LIBS} gflags pthread) target_link_libraries(benchmark_yolov3 ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_fasterrcnn ${FASTDEPLOY_LIBS} gflags pthread) target_link_libraries(benchmark_fasterrcnn ${FASTDEPLOY_LIBS} gflags pthread)
@@ -88,6 +90,7 @@ else()
target_link_libraries(benchmark_ppocr_cls ${FASTDEPLOY_LIBS} gflags) target_link_libraries(benchmark_ppocr_cls ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppocr_rec ${FASTDEPLOY_LIBS} gflags) target_link_libraries(benchmark_ppocr_rec ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_structurev2_table ${FASTDEPLOY_LIBS} gflags) target_link_libraries(benchmark_structurev2_table ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppyolo ${FASTDEPLOY_LIBS} gflags) target_link_libraries(benchmark_ppyolo ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_yolov3 ${FASTDEPLOY_LIBS} gflags) target_link_libraries(benchmark_yolov3 ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_fasterrcnn ${FASTDEPLOY_LIBS} gflags) target_link_libraries(benchmark_fasterrcnn ${FASTDEPLOY_LIBS} gflags)

View File

@@ -0,0 +1,93 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "flags.h"
#include "macros.h"
#include "option.h"
namespace vision = fastdeploy::vision;
namespace benchmark = fastdeploy::benchmark;
int main(int argc, char* argv[]) {
#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION)
// Initialization
auto option = fastdeploy::RuntimeOption();
if (!CreateRuntimeOption(&option, argc, argv, true)) {
return -1;
}
auto im = cv::imread(FLAGS_image);
std::unordered_map<std::string, std::string> config_info;
benchmark::ResultManager::LoadBenchmarkConfig(FLAGS_config_path,
&config_info);
std::string model_name, params_name, config_name;
auto model_format = fastdeploy::ModelFormat::PADDLE;
if (!UpdateModelResourceName(&model_name, &params_name, &config_name,
&model_format, config_info, false)) {
return -1;
}
auto model_file = FLAGS_model + sep + model_name;
auto params_file = FLAGS_model + sep + params_name;
if (config_info["backend"] == "paddle_trt") {
option.paddle_infer_option.collect_trt_shape = true;
}
if (config_info["backend"] == "paddle_trt" ||
config_info["backend"] == "trt") {
option.trt_option.SetShape("image", {1, 3, 800, 608}, {1, 3, 800, 608},
{1, 3, 800, 608});
}
auto layout_model = vision::ocr::StructureV2Layout(model_file, params_file,
option, model_format);
// 5 for publaynet, 10 for cdla
layout_model.GetPostprocessor().SetNumClass(5);
vision::DetectionResult res;
if (config_info["precision_compare"] == "true") {
// Run once at least
layout_model.Predict(im, &res);
// 1. Test result diff
std::cout << "=============== Test result diff =================\n";
// Save result to -> disk.
std::string layout_result_path = "layout_result.txt";
benchmark::ResultManager::SaveDetectionResult(res, layout_result_path);
// Load result from <- disk.
vision::DetectionResult res_loaded;
benchmark::ResultManager::LoadDetectionResult(&res_loaded,
layout_result_path);
// Calculate diff between two results.
auto det_diff =
benchmark::ResultManager::CalculateDiffStatis(res, res_loaded);
std::cout << "Boxes diff: mean=" << det_diff.boxes.mean
<< ", max=" << det_diff.boxes.max
<< ", min=" << det_diff.boxes.min << std::endl;
std::cout << "Label_ids diff: mean=" << det_diff.labels.mean
<< ", max=" << det_diff.labels.max
<< ", min=" << det_diff.labels.min << std::endl;
}
// Run profiling
BENCHMARK_MODEL(layout_model, layout_model.Predict(im, &res))
std::vector<std::string> labels = {"text", "title", "list", "table",
"figure"};
if (layout_model.GetPostprocessor().GetNumClass() == 10) {
labels = {"text", "title", "figure", "figure_caption",
"table", "table_caption", "header", "footer",
"reference", "equation"};
}
auto vis_im =
vision::VisDetection(im, res, labels, 0.3, 2, .5f, {255, 0, 0}, 2);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
#endif
return 0;
}

View File

@@ -38,3 +38,8 @@ target_link_libraries(infer_rec ${FASTDEPLOY_LIBS})
add_executable(infer_structurev2_table ${PROJECT_SOURCE_DIR}/infer_structurev2_table.cc) add_executable(infer_structurev2_table ${PROJECT_SOURCE_DIR}/infer_structurev2_table.cc)
# 添加FastDeploy库依赖 # 添加FastDeploy库依赖
target_link_libraries(infer_structurev2_table ${FASTDEPLOY_LIBS}) target_link_libraries(infer_structurev2_table ${FASTDEPLOY_LIBS})
# Only Layout
add_executable(infer_structurev2_layout ${PROJECT_SOURCE_DIR}/infer_structurev2_layout.cc)
# 添加FastDeploy库依赖
target_link_libraries(infer_structurev2_layout ${FASTDEPLOY_LIBS})

View File

@@ -46,12 +46,18 @@ tar -xvf ch_PP-OCRv3_rec_infer.tar
# 下载PPStructureV2表格识别模型 # 下载PPStructureV2表格识别模型
wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar
tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
# 下载PP-StructureV2版面分析模型
wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout_infer.tar
tar -xvf picodet_lcnet_x1_0_fgd_layout_infer.tar
# 下载预测图片与字典文件 # 下载预测图片与字典文件
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppstructure/docs/table/table.jpg wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppstructure/docs/table/table.jpg
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppstructure/docs/table/layout.jpg
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/table_structure_dict_ch.txt wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/table_structure_dict_ch.txt
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/layout_dict/layout_cdla_dict.txt
# 运行部署示例 # 运行部署示例
# 在CPU上使用Paddle Inference推理 # 在CPU上使用Paddle Inference推理
@@ -71,7 +77,7 @@ wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/t
# 在GPU上使用Nvidia TensorRT推理 # 在GPU上使用Nvidia TensorRT推理
./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 7 ./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 7
# 同时, FastDeploy提供文字检测,文字分类,文字识别三个模型的单独推理, # 同时, FastDeploy提供文字检测,文字分类,文字识别,表格识别,版面分析等模型的单独推理,
# 有需要的用户, 请准备合适的图片, 同时根据自己的需求, 参考infer.cc来配置自定义硬件与推理后端. # 有需要的用户, 请准备合适的图片, 同时根据自己的需求, 参考infer.cc来配置自定义硬件与推理后端.
# 在CPU上,单独使用文字检测模型部署 # 在CPU上,单独使用文字检测模型部署
@@ -85,6 +91,9 @@ wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/t
# 在CPU上,单独使用表格识别模型部署 # 在CPU上,单独使用表格识别模型部署
./infer_structurev2_table ./ch_ppstructure_mobile_v2.0_SLANet_infer ./table_structure_dict_ch.txt ./table.jpg 0 ./infer_structurev2_table ./ch_ppstructure_mobile_v2.0_SLANet_infer ./table_structure_dict_ch.txt ./table.jpg 0
# 在CPU上,单独使用版面分析模型部署
./infer_structurev2_layout ./picodet_lcnet_x1_0_fgd_layout_infer ./layout.jpg 0
``` ```
运行完成可视化结果如下图所示 运行完成可视化结果如下图所示

View File

@@ -0,0 +1,87 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void InitAndInfer(const std::string &layout_model_dir,
const std::string &image_file,
const fastdeploy::RuntimeOption &option) {
auto layout_model_file = layout_model_dir + sep + "model.pdmodel";
auto layout_params_file = layout_model_dir + sep + "model.pdiparams";
auto layout_model = fastdeploy::vision::ocr::StructureV2Layout(
layout_model_file, layout_params_file, option);
if (!layout_model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto im = cv::imread(image_file);
// 5 for publaynet, 10 for cdla
layout_model.GetPostprocessor().SetNumClass(5);
fastdeploy::vision::DetectionResult res;
if (!layout_model.Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;
std::vector<std::string> labels = {"text", "title", "list", "table",
"figure"};
if (layout_model.GetPostprocessor().GetNumClass() == 10) {
labels = {"text", "title", "figure", "figure_caption",
"table", "table_caption", "header", "footer",
"reference", "equation"};
}
auto vis_im = fastdeploy::vision::VisDetection(im, res, labels, 0.3, 2, .5f,
{255, 0, 0}, 2);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
int main(int argc, char *argv[]) {
if (argc < 4) {
std::cout
<< "Usage: infer_demo path/to/layout_model path/to/image "
"run_option, "
"e.g ./infer_structurev2_layout picodet_lcnet_x1_0_fgd_layout_infer "
"layout.png 0"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu;."
<< std::endl;
return -1;
}
fastdeploy::RuntimeOption option;
int flag = std::atoi(argv[3]);
if (flag == 0) {
option.UseCpu();
} else if (flag == 1) {
option.UseGpu();
}
std::string layout_model_dir = argv[1];
std::string image_file = argv[2];
InitAndInfer(layout_model_dir, image_file, option);
return 0;
}

View File

@@ -39,12 +39,18 @@ tar -xvf ch_PP-OCRv3_rec_infer.tar
# 下载PPStructureV2表格识别模型 # 下载PPStructureV2表格识别模型
wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar
tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
# 下载PP-StructureV2版面分析模型
wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_fgd_layout_infer.tar
tar -xvf picodet_lcnet_x1_0_fgd_layout_infer.tar
# 下载预测图片与字典文件 # 下载预测图片与字典文件
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppstructure/docs/table/table.jpg wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppstructure/docs/table/table.jpg
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppstructure/docs/table/layout.jpg
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/table_structure_dict_ch.txt wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/table_structure_dict_ch.txt
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/dict/layout_dict/layout_cdla_dict.txt
# 运行部署示例 # 运行部署示例
# 在CPU上使用Paddle Inference推理 # 在CPU上使用Paddle Inference推理
@@ -64,7 +70,7 @@ python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2
# 在GPU上使用Nvidia TensorRT推理 # 在GPU上使用Nvidia TensorRT推理
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu --backend trt python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu --backend trt
# 同时, FastDeploy提供文字检测,文字分类,文字识别三个模型的单独推理, # 同时, FastDeploy提供文字检测,文字分类,文字识别,表格识别,版面分析等模型的单独推理,
# 有需要的用户, 请准备合适的图片, 同时根据自己的需求, 参考infer.py来配置自定义硬件与推理后端. # 有需要的用户, 请准备合适的图片, 同时根据自己的需求, 参考infer.py来配置自定义硬件与推理后端.
# 在CPU上,单独使用文字检测模型部署 # 在CPU上,单独使用文字检测模型部署
@@ -76,8 +82,11 @@ python infer_cls.py --cls_model ch_ppocr_mobile_v2.0_cls_infer --image 12.jpg --
# 在CPU上,单独使用文字识别模型部署 # 在CPU上,单独使用文字识别模型部署
python infer_rec.py --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device cpu python infer_rec.py --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device cpu
# 在CPU上,单独使用文字识别模型部署 # 在CPU上,单独使用表格识别模型部署
python infer_structurev2_table.py --table_model ./ch_ppstructure_mobile_v2.0_SLANet_infer --table_char_dict_path ./table_structure_dict_ch.txt --image table.jpg --device cpu python infer_structurev2_table.py --table_model ./ch_ppstructure_mobile_v2.0_SLANet_infer --table_char_dict_path ./table_structure_dict_ch.txt --image table.jpg --device cpu
# 在CPU上,单独使用版面分析模型部署
python infer_structurev2_layout.py --layout_model ./picodet_lcnet_x1_0_fgd_layout_infer --image layout.jpg --device cpu
``` ```
运行完成可视化结果如下图所示 运行完成可视化结果如下图所示

View File

@@ -0,0 +1,91 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
def parse_arguments():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--layout_model",
required=True,
help="Path of Layout detection model of PP-StructureV2.")
parser.add_argument(
"--image", type=str, required=True, help="Path of test image file.")
parser.add_argument(
"--device",
type=str,
default='cpu',
help="Type of inference device, support 'cpu' or 'gpu'.")
parser.add_argument(
"--device_id",
type=int,
default=0,
help="Define which GPU card used to run model.")
return parser.parse_args()
def build_option(args):
layout_option = fd.RuntimeOption()
if args.device.lower() == "gpu":
layout_option.use_gpu(args.device_id)
return layout_option
args = parse_arguments()
layout_model_file = os.path.join(args.layout_model, "model.pdmodel")
layout_params_file = os.path.join(args.layout_model, "model.pdiparams")
# Set the runtime option
layout_option = build_option(args)
# Create the table_model
layout_model = fd.vision.ocr.StructureV2Layout(
layout_model_file, layout_params_file, layout_option)
layout_model.postprocessor.num_class = 5
# Read the image
im = cv2.imread(args.image)
# Predict and return the results
result = layout_model.predict(im)
print(result)
# Visualize the results
labels = ["text", "title", "list", "table", "figure"]
if layout_model.postprocessor.num_class == 10:
labels = [
"text", "title", "figure", "figure_caption", "table", "table_caption",
"header", "footer", "reference", "equation"
]
vis_im = fd.vision.vis_detection(
im,
result,
labels,
score_threshold=0.5,
font_color=[255, 0, 0],
font_thickness=2)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg")

View File

@@ -23,7 +23,7 @@ def parse_arguments():
parser.add_argument( parser.add_argument(
"--table_model", "--table_model",
required=True, required=True,
help="Path of Table recognition model of PPOCR.") help="Path of Table recognition model of PP-StructureV2.")
parser.add_argument( parser.add_argument(
"--table_char_dict_path", "--table_char_dict_path",
type=str, type=str,

View File

@@ -54,9 +54,11 @@
#include "fastdeploy/vision/ocr/ppocr/classifier.h" #include "fastdeploy/vision/ocr/ppocr/classifier.h"
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h" #include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
#include "fastdeploy/vision/ocr/ppocr/structurev2_table.h" #include "fastdeploy/vision/ocr/ppocr/structurev2_table.h"
#include "fastdeploy/vision/ocr/ppocr/structurev2_layout.h"
#include "fastdeploy/vision/ocr/ppocr/ppocr_v2.h" #include "fastdeploy/vision/ocr/ppocr/ppocr_v2.h"
#include "fastdeploy/vision/ocr/ppocr/ppocr_v3.h" #include "fastdeploy/vision/ocr/ppocr/ppocr_v3.h"
#include "fastdeploy/vision/ocr/ppocr/ppstructurev2_table.h" #include "fastdeploy/vision/ocr/ppocr/ppstructurev2_table.h"
#include "fastdeploy/vision/ocr/ppocr/ppstructurev2_layout.h"
#include "fastdeploy/vision/ocr/ppocr/recognizer.h" #include "fastdeploy/vision/ocr/ppocr/recognizer.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h" #include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
#include "fastdeploy/vision/segmentation/ppseg/model.h" #include "fastdeploy/vision/segmentation/ppseg/model.h"

View File

@@ -33,6 +33,9 @@ std::ostream& operator<<(std::ostream& out, const ProcLib& p) {
case ProcLib::CUDA: case ProcLib::CUDA:
out << "ProcLib::CUDA"; out << "ProcLib::CUDA";
break; break;
case ProcLib::CVCUDA:
out << "ProcLib::CVCUDA";
break;
default: default:
FDASSERT(false, "Unknow type of ProcLib."); FDASSERT(false, "Unknow type of ProcLib.");
} }

View File

@@ -153,10 +153,16 @@ std::string DetectionResult::Str() {
std::string out; std::string out;
if (!contain_masks) { if (!contain_masks) {
out = "DetectionResult: [xmin, ymin, xmax, ymax, score, label_id]\n"; out = "DetectionResult: [xmin, ymin, xmax, ymax, score, label_id]\n";
if (!rotated_boxes.empty()) {
out = "DetectionResult: [x1, y1, x2, y2, x3, y3, x4, y4, score, label_id]\n";
}
} else { } else {
out = out =
"DetectionResult: [xmin, ymin, xmax, ymax, score, label_id, " "DetectionResult: [xmin, ymin, xmax, ymax, score, label_id, "
"mask_shape]\n"; "mask_shape]\n";
if (!rotated_boxes.empty()) {
out = "DetectionResult: [x1, y1, x2, y2, x3, y3, x4, y4, score, label_id, mask_shape]\n";
}
} }
for (size_t i = 0; i < boxes.size(); ++i) { for (size_t i = 0; i < boxes.size(); ++i) {
out = out + std::to_string(boxes[i][0]) + "," + out = out + std::to_string(boxes[i][0]) + "," +

View File

@@ -304,6 +304,7 @@ void BindPPOCRModel(pybind11::module& m) {
&vision::ocr::Recognizer::GetPreprocessor) &vision::ocr::Recognizer::GetPreprocessor)
.def_property_readonly("postprocessor", .def_property_readonly("postprocessor",
&vision::ocr::Recognizer::GetPostprocessor) &vision::ocr::Recognizer::GetPostprocessor)
.def("clone", [](vision::ocr::Recognizer& self) { return self.Clone(); })
.def("predict", .def("predict",
[](vision::ocr::Recognizer& self, pybind11::array& data) { [](vision::ocr::Recognizer& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data); auto mat = PyArrayToCvMat(data);
@@ -360,7 +361,7 @@ void BindPPOCRModel(pybind11::module& m) {
if (!self.Run(inputs, &boxes, &structure_list, if (!self.Run(inputs, &boxes, &structure_list,
batch_det_img_info)) { batch_det_img_info)) {
throw std::runtime_error( throw std::runtime_error(
"Failed to preprocess the input data in " "Failed to postprocess the input data in "
"StructureV2TablePostprocessor."); "StructureV2TablePostprocessor.");
} }
return std::make_pair(boxes, structure_list); return std::make_pair(boxes, structure_list);
@@ -377,7 +378,7 @@ void BindPPOCRModel(pybind11::module& m) {
if (!self.Run(inputs, &boxes, &structure_list, if (!self.Run(inputs, &boxes, &structure_list,
batch_det_img_info)) { batch_det_img_info)) {
throw std::runtime_error( throw std::runtime_error(
"Failed to preprocess the input data in " "Failed to postprocess the input data in "
"StructureV2TablePostprocessor."); "StructureV2TablePostprocessor.");
} }
return std::make_pair(boxes, structure_list); return std::make_pair(boxes, structure_list);
@@ -392,6 +393,8 @@ void BindPPOCRModel(pybind11::module& m) {
&vision::ocr::StructureV2Table::GetPreprocessor) &vision::ocr::StructureV2Table::GetPreprocessor)
.def_property_readonly("postprocessor", .def_property_readonly("postprocessor",
&vision::ocr::StructureV2Table::GetPostprocessor) &vision::ocr::StructureV2Table::GetPostprocessor)
.def("clone",
[](vision::ocr::StructureV2Table& self) { return self.Clone(); })
.def("predict", .def("predict",
[](vision::ocr::StructureV2Table& self, pybind11::array& data) { [](vision::ocr::StructureV2Table& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data); auto mat = PyArrayToCvMat(data);
@@ -410,5 +413,114 @@ void BindPPOCRModel(pybind11::module& m) {
self.BatchPredict(images, &ocr_results); self.BatchPredict(images, &ocr_results);
return ocr_results; return ocr_results;
}); });
// Layout
pybind11::class_<vision::ocr::StructureV2LayoutPreprocessor,
vision::ProcessorManager>(m, "StructureV2LayoutPreprocessor")
.def(pybind11::init<>())
.def_property(
"static_shape_infer",
&vision::ocr::StructureV2LayoutPreprocessor::GetStaticShapeInfer,
&vision::ocr::StructureV2LayoutPreprocessor::SetStaticShapeInfer)
.def_property(
"layout_image_shape",
&vision::ocr::StructureV2LayoutPreprocessor::GetLayoutImageShape,
&vision::ocr::StructureV2LayoutPreprocessor::SetLayoutImageShape)
.def("set_normalize",
[](vision::ocr::StructureV2LayoutPreprocessor& self,
const std::vector<float>& mean, const std::vector<float>& std,
bool is_scale) { self.SetNormalize(mean, std, is_scale); })
.def("run",
[](vision::ocr::StructureV2LayoutPreprocessor& self,
std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) {
throw std::runtime_error(
"Failed to preprocess the input data in "
"StructureV2LayoutPreprocessor.");
}
auto batch_layout_img_info = self.GetBatchLayoutImgInfo();
for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].StopSharing();
}
return std::make_pair(outputs, *batch_layout_img_info);
})
.def("disable_normalize",
[](vision::ocr::StructureV2LayoutPreprocessor& self) {
self.DisableNormalize();
})
.def("disable_permute",
[](vision::ocr::StructureV2LayoutPreprocessor& self) {
self.DisablePermute();
});
pybind11::class_<vision::ocr::StructureV2LayoutPostprocessor>(
m, "StructureV2LayoutPostprocessor")
.def(pybind11::init<>())
.def_property(
"score_threshold",
&vision::ocr::StructureV2LayoutPostprocessor::GetScoreThreshold,
&vision::ocr::StructureV2LayoutPostprocessor::SetScoreThreshold)
.def_property(
"nms_threshold",
&vision::ocr::StructureV2LayoutPostprocessor::GetNMSThreshold,
&vision::ocr::StructureV2LayoutPostprocessor::SetNMSThreshold)
.def_property("num_class",
&vision::ocr::StructureV2LayoutPostprocessor::GetNumClass,
&vision::ocr::StructureV2LayoutPostprocessor::SetNumClass)
.def_property("fpn_stride",
&vision::ocr::StructureV2LayoutPostprocessor::GetFPNStride,
&vision::ocr::StructureV2LayoutPostprocessor::SetFPNStride)
.def_property("reg_max",
&vision::ocr::StructureV2LayoutPostprocessor::GetRegMax,
&vision::ocr::StructureV2LayoutPostprocessor::SetRegMax)
.def("run",
[](vision::ocr::StructureV2LayoutPostprocessor& self,
std::vector<FDTensor>& inputs,
const std::vector<std::array<int, 4>>& batch_layout_img_info) {
std::vector<vision::DetectionResult> results;
if (!self.Run(inputs, &results, batch_layout_img_info)) {
throw std::runtime_error(
"Failed to postprocess the input data in "
"StructureV2LayoutPostprocessor.");
}
return results;
});
pybind11::class_<vision::ocr::StructureV2Layout, FastDeployModel>(
m, "StructureV2Layout")
.def(pybind11::init<std::string, std::string, RuntimeOption,
ModelFormat>())
.def(pybind11::init<>())
.def_property_readonly("preprocessor",
&vision::ocr::StructureV2Layout::GetPreprocessor)
.def_property_readonly("postprocessor",
&vision::ocr::StructureV2Layout::GetPostprocessor)
.def("clone",
[](vision::ocr::StructureV2Layout& self) { return self.Clone(); })
.def("predict",
[](vision::ocr::StructureV2Layout& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult result;
self.Predict(mat, &result);
return result;
})
.def("batch_predict", [](vision::ocr::StructureV2Layout& self,
std::vector<pybind11::array>& data) {
std::vector<cv::Mat> images;
for (size_t i = 0; i < data.size(); ++i) {
images.push_back(PyArrayToCvMat(data[i]));
}
std::vector<vision::DetectionResult> results;
self.BatchPredict(images, &results);
return results;
});
} }
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -0,0 +1,40 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/structurev2_layout.h"
#include "fastdeploy/utils/unique_ptr.h"
namespace fastdeploy {
namespace pipeline {
typedef fastdeploy::vision::ocr::StructureV2Layout PPStructureV2Layout;
namespace application {
namespace ocrsystem {
// TODO(qiuyanjun): This pipeline may not need
typedef pipeline::PPStructureV2Layout PPStructureV2LayoutSystem;
} // namespace ocrsystem
} // namespace application
} // namespace pipeline
} // namespace fastdeploy

View File

@@ -0,0 +1,102 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/structurev2_layout.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
StructureV2Layout::StructureV2Layout() {}
StructureV2Layout::StructureV2Layout(const std::string& model_file,
const std::string& params_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) {
if (model_format == ModelFormat::ONNX) {
valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
valid_gpu_backends = {Backend::ORT, Backend::TRT};
} else {
valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::OPENVINO,
Backend::LITE};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
valid_kunlunxin_backends = {Backend::LITE};
valid_ascend_backends = {Backend::LITE};
valid_sophgonpu_backends = {Backend::SOPHGOTPU};
valid_rknpu_backends = {Backend::RKNPU2};
}
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
bool StructureV2Layout::Initialize() {
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
std::unique_ptr<StructureV2Layout> StructureV2Layout::Clone() const {
std::unique_ptr<StructureV2Layout> clone_model =
utils::make_unique<StructureV2Layout>(StructureV2Layout(*this));
clone_model->SetRuntime(clone_model->CloneRuntime());
return clone_model;
}
bool StructureV2Layout::Predict(cv::Mat* im, DetectionResult* result) {
return Predict(*im, result);
}
bool StructureV2Layout::Predict(const cv::Mat& im, DetectionResult* result) {
std::vector<DetectionResult> results;
if (!BatchPredict({im}, &results)) {
return false;
}
*result = std::move(results[0]);
return true;
}
bool StructureV2Layout::BatchPredict(const std::vector<cv::Mat>& images,
std::vector<DetectionResult>* results) {
std::vector<FDMat> fd_images = WrapMat(images);
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
auto batch_layout_img_info = preprocessor_.GetBatchLayoutImgInfo();
reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
FDERROR << "Failed to inference by runtime." << std::endl;
return false;
}
if (!postprocessor_.Run(reused_output_tensors_, results,
*batch_layout_img_info)) {
FDERROR << "Failed to postprocess the inference results." << std::endl;
return false;
}
return true;
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,94 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
#include "fastdeploy/vision/ocr/ppocr/structurev2_layout_preprocessor.h"
#include "fastdeploy/vision/ocr/ppocr/structurev2_layout_postprocessor.h"
#include "fastdeploy/utils/unique_ptr.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief StructureV2Layout object is used to load the PP-StructureV2-Layout detection model.
*/
class FASTDEPLOY_DECL StructureV2Layout : public FastDeployModel {
public:
StructureV2Layout();
/** \brief Set path of model file, and the configuration of runtime
*
* \param[in] model_file Path of model file, e.g ./picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdmodel.
* \param[in] params_file Path of parameter file, e.g ./picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdiparams, if the model format is ONNX, this parameter will be ignored.
* \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`.
* \param[in] model_format Model format of the loaded model, default is Paddle format.
*/
StructureV2Layout(const std::string& model_file,
const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE);
/** \brief Clone a new StructureV2Layout with less memory usage when multiple instances of the same model are created
*
* \return newStructureV2Layout* type unique pointer
*/
virtual std::unique_ptr<StructureV2Layout> Clone() const;
/// Get model's name
std::string ModelName() const { return "pp-structurev2-layout"; }
/** \brief DEPRECATED Predict the detection result for an input image
*
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
* \param[in] result The output detection result
* \return true if the prediction successed, otherwise false
*/
virtual bool Predict(cv::Mat* im, DetectionResult* result);
/** \brief Predict the detection result for an input image
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
* \param[in] result The output detection result
* \return true if the prediction successed, otherwise false
*/
virtual bool Predict(const cv::Mat& im, DetectionResult* result);
/** \brief Predict the detection result for an input image list
* \param[in] im The input image list, all the elements come from cv::imread(), is a 3-D array with layout HWC, BGR format
* \param[in] results The output detection result list
* \return true if the prediction successed, otherwise false
*/
virtual bool BatchPredict(const std::vector<cv::Mat>& imgs,
std::vector<DetectionResult>* results);
/// Get preprocessor reference ofStructureV2LayoutPreprocessor
virtual StructureV2LayoutPreprocessor& GetPreprocessor() {
return preprocessor_;
}
/// Get postprocessor reference ofStructureV2LayoutPostprocessor
virtual StructureV2LayoutPostprocessor& GetPostprocessor() {
return postprocessor_;
}
private:
bool Initialize();
StructureV2LayoutPreprocessor preprocessor_;
StructureV2LayoutPostprocessor postprocessor_;
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,171 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/structurev2_layout_postprocessor.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
bool StructureV2LayoutPostprocessor::Run(
const std::vector<FDTensor>& tensors, std::vector<DetectionResult>* results,
const std::vector<std::array<int, 4>>& batch_layout_img_info) {
// A StructureV2Layout has 8 output tensors on which it then runs
// a GFL regression (namely, DisPred2Box), reference:
// PaddleOCR/blob/release/2.6/deploy/cpp_infer/src/postprocess_op.cpp#L511
int tensor_size = tensors.size();
FDASSERT(tensor_size == 8,
"StructureV2Layout should has 8 output tensors,"
"but got %d now!",
tensor_size)
FDASSERT((tensor_size / 2) == fpn_stride_.size(),
"found (tensor_size / 2) != fpn_stride_.size() !")
// TODO(qiuyanjun): may need to reorder the tensors according to
// fpn_stride_ and the shape of output tensors.
size_t batch = tensors[0].Shape()[0]; // [batch, ...]
results->resize(batch);
SetRegMax(tensors[fpn_stride_.size()].Shape()[2] / 4);
for (int batch_idx = 0; batch_idx < batch; ++batch_idx) {
std::vector<FDTensor> single_batch_tensors(8);
SetSingleBatchExternalData(tensors, single_batch_tensors, batch_idx);
SingleBatchPostprocessor(single_batch_tensors,
batch_layout_img_info[batch_idx],
&results->at(batch_idx));
}
return true;
}
void StructureV2LayoutPostprocessor::SetSingleBatchExternalData(
const std::vector<FDTensor>& tensors,
std::vector<FDTensor>& single_batch_tensors, size_t batch_idx) {
single_batch_tensors.resize(tensors.size());
for (int j = 0; j < tensors.size(); ++j) {
auto j_shape = tensors[j].Shape();
j_shape[0] = 1; // process b=1 per loop
size_t j_step =
accumulate(j_shape.begin(), j_shape.end(), 1, std::multiplies<int>());
const float* j_data_ptr = reinterpret_cast<const float*>(tensors[j].Data());
const float* j_start_ptr = j_data_ptr + j_step * batch_idx;
single_batch_tensors[j].SetExternalData(
j_shape, tensors[j].Dtype(),
const_cast<void*>(reinterpret_cast<const void*>(j_start_ptr)),
tensors[j].device, tensors[j].device_id);
}
}
bool StructureV2LayoutPostprocessor::SingleBatchPostprocessor(
const std::vector<FDTensor>& single_batch_tensors,
const std::array<int, 4>& layout_img_info, DetectionResult* result) {
FDASSERT(single_batch_tensors.size() == 8,
"StructureV2Layout should has 8 output tensors,"
"but got %d now!",
static_cast<int>(single_batch_tensors.size()))
// layout_img_info: {image width, image height, resize width, resize height}
int img_w = layout_img_info[0];
int img_h = layout_img_info[1];
int in_w = layout_img_info[2];
int in_h = layout_img_info[3];
float scale_factor_w = static_cast<float>(in_w) / static_cast<float>(img_w);
float scale_factor_h = static_cast<float>(in_h) / static_cast<float>(img_h);
std::vector<DetectionResult> bbox_results;
bbox_results.resize(num_class_); // tmp result for each class
// decode score, label, box
for (int i = 0; i < fpn_stride_.size(); ++i) {
int feature_h = std::ceil(static_cast<float>(in_h) / fpn_stride_[i]);
int feature_w = std::ceil(static_cast<float>(in_w) / fpn_stride_[i]);
const FDTensor& prob_tensor = single_batch_tensors[i];
const FDTensor& bbox_tensor = single_batch_tensors[i + fpn_stride_.size()];
const float* prob_data = reinterpret_cast<const float*>(prob_tensor.Data());
const float* bbox_data = reinterpret_cast<const float*>(bbox_tensor.Data());
for (int idx = 0; idx < feature_h * feature_w; ++idx) {
// score and label
float score = 0.f;
int label = 0;
for (int j = 0; j < num_class_; ++j) {
if (prob_data[idx * num_class_ + j] > score) {
score = prob_data[idx * num_class_ + j];
label = j;
}
}
// bbox
if (score > score_threshold_) {
int row = idx / feature_w;
int col = idx % feature_w;
std::vector<float> bbox_pred(bbox_data + idx * 4 * reg_max_,
bbox_data + (idx + 1) * 4 * reg_max_);
bbox_results[label].boxes.push_back(DisPred2Bbox(
bbox_pred, col, row, fpn_stride_[i], in_w, in_h, reg_max_));
bbox_results[label].scores.push_back(score);
bbox_results[label].label_ids.push_back(label);
}
}
}
result->Clear();
// nms for per class, i in [0~num_class-1]
for (int i = 0; i < bbox_results.size(); ++i) {
if (bbox_results[i].boxes.size() <= 0) {
continue;
}
vision::utils::NMS(&bbox_results[i], nms_threshold_);
// fill output results
for (int j = 0; j < bbox_results[i].boxes.size(); ++j) {
result->scores.push_back(bbox_results[i].scores[j]);
result->label_ids.push_back(bbox_results[i].label_ids[j]);
result->boxes.push_back({
bbox_results[i].boxes[j][0] / scale_factor_w,
bbox_results[i].boxes[j][1] / scale_factor_h,
bbox_results[i].boxes[j][2] / scale_factor_w,
bbox_results[i].boxes[j][3] / scale_factor_h,
});
}
}
return true;
}
std::array<float, 4> StructureV2LayoutPostprocessor::DisPred2Bbox(
const std::vector<float>& bbox_pred, int x, int y, int stride, int resize_w,
int resize_h, int reg_max) {
float ct_x = (static_cast<float>(x) + 0.5f) * static_cast<float>(stride);
float ct_y = (static_cast<float>(y) + 0.5f) * static_cast<float>(stride);
std::vector<float> dis_pred;
dis_pred.resize(4);
for (int i = 0; i < 4; i++) {
std::vector<float> bbox_pred_i(bbox_pred.begin() + i * reg_max,
bbox_pred.begin() + (i + 1) * reg_max);
std::vector<float> dis_after_sm = ocr::Softmax(bbox_pred_i);
float dis = 0.0f;
for (int j = 0; j < reg_max; j++) {
dis += static_cast<float>(j) * dis_after_sm[j];
}
dis *= static_cast<float>(stride);
dis_pred[i] = dis;
}
float xmin = std::max(ct_x - dis_pred[0], 0.0f);
float ymin = std::max(ct_y - dis_pred[1], 0.0f);
float xmax = std::min(ct_x + dis_pred[2], static_cast<float>(resize_w));
float ymax = std::min(ct_y + dis_pred[3], static_cast<float>(resize_h));
return {xmin, ymin, xmax, ymax};
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,80 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief Postprocessor object for PaddleDet serials model.
*/
class FASTDEPLOY_DECL StructureV2LayoutPostprocessor {
public:
StructureV2LayoutPostprocessor() {}
/** \brief Process the result of runtime and fill to batch DetectionResult
*
* \param[in] tensors The inference result from runtime
* \param[in] results The output result of layout detection
* \param[in] batch_layout_img_info The image info of input images,
* {{image width, image height, resize width, resize height},...}
* \return true if the postprocess successed, otherwise false
*/
bool Run(const std::vector<FDTensor>& tensors,
std::vector<DetectionResult>* results,
const std::vector<std::array<int, 4>>& batch_layout_img_info);
/// Set score_threshold_ for layout detection postprocess, default is 0.4
void SetScoreThreshold(float score_threshold) { score_threshold_ = score_threshold; }
/// Set nms_threshold_ for layout detection postprocess, default is 0.5
void SetNMSThreshold(float nms_threshold) { nms_threshold_ = nms_threshold; }
/// Set num_class_ for layout detection postprocess, default is 5
void SetNumClass(int num_class) { num_class_ = num_class; }
/// Set fpn_stride_ for layout detection postprocess, default is {8, 16, 32, 64}
void SetFPNStride(const std::vector<int>& fpn_stride) { fpn_stride_ = fpn_stride; }
/// Set reg_max_ for layout detection postprocess, default is 8
void SetRegMax(int reg_max) { reg_max_ = reg_max; } // should private ?
/// Get score_threshold_ of layout detection postprocess, default is 0.4
float GetScoreThreshold() const { return score_threshold_; }
/// Get nms_threshold_ of layout detection postprocess, default is 0.5
float GetNMSThreshold() const { return nms_threshold_; }
/// Get num_class_ of layout detection postprocess, default is 5
int GetNumClass() const { return num_class_; }
/// Get fpn_stride_ of layout detection postprocess, default is {8, 16, 32, 64}
std::vector<int> GetFPNStride() const { return fpn_stride_; }
/// Get reg_max_ of layout detection postprocess, default is 8
int GetRegMax() const { return reg_max_; }
private:
std::array<float, 4> DisPred2Bbox(const std::vector<float>& bbox_pred, int x, int y,
int stride, int resize_w, int resize_h, int reg_max);
bool SingleBatchPostprocessor(const std::vector<FDTensor>& single_batch_tensors,
const std::array<int, 4>& layout_img_info,
DetectionResult* result);
void SetSingleBatchExternalData(const std::vector<FDTensor>& tensors,
std::vector<FDTensor>& single_batch_tensors,
size_t batch_idx);
std::vector<int> fpn_stride_ = {8, 16, 32, 64};
float score_threshold_ = 0.4;
float nms_threshold_ = 0.5;
int num_class_ = 5;
int reg_max_ = 8;
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,72 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/structurev2_layout_preprocessor.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
StructureV2LayoutPreprocessor::StructureV2LayoutPreprocessor() {
// default width(608) and height(900)
resize_op_ =
std::make_shared<Resize>(layout_image_shape_[2], layout_image_shape_[1]);
normalize_permute_op_ = std::make_shared<NormalizeAndPermute>(
std::vector<float>({0.485f, 0.456f, 0.406f}),
std::vector<float>({0.229f, 0.224f, 0.225f}), true);
}
std::array<int, 4> StructureV2LayoutPreprocessor::GetLayoutImgInfo(FDMat* img) {
if (static_shape_infer_) {
return {img->Width(), img->Height(), layout_image_shape_[2],
layout_image_shape_[1]};
} else {
FDASSERT(false, "not support dynamic shape inference now!")
}
return {img->Width(), img->Height(), layout_image_shape_[2],
layout_image_shape_[1]};
}
bool StructureV2LayoutPreprocessor::ResizeLayoutImage(FDMat* img, int resize_w,
int resize_h) {
resize_op_->SetWidthAndHeight(resize_w, resize_h);
(*resize_op_)(img);
return true;
}
bool StructureV2LayoutPreprocessor::Apply(FDMatBatch* image_batch,
std::vector<FDTensor>* outputs) {
batch_layout_img_info_.clear();
batch_layout_img_info_.resize(image_batch->mats->size());
for (size_t i = 0; i < image_batch->mats->size(); ++i) {
FDMat* mat = &(image_batch->mats->at(i));
batch_layout_img_info_[i] = GetLayoutImgInfo(mat);
ResizeLayoutImage(mat, batch_layout_img_info_[i][2],
batch_layout_img_info_[i][3]);
}
if (!disable_normalize_ && !disable_permute_) {
(*normalize_permute_op_)(image_batch);
}
outputs->resize(1);
FDTensor* tensor = image_batch->Tensor();
(*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
tensor->Data(), tensor->device,
tensor->device_id);
return true;
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,91 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/manager.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief Preprocessor object for DBDetector serials model.
*/
class FASTDEPLOY_DECL StructureV2LayoutPreprocessor : public ProcessorManager {
public:
StructureV2LayoutPreprocessor();
/** \brief Process the input image and prepare input tensors for runtime
*
* \param[in] image_batch The input image batch
* \param[in] outputs The output tensors which will feed in runtime
* \return true if the preprocess successed, otherwise false
*/
virtual bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs);
/// Set preprocess normalize parameters, please call this API to customize
/// the normalize parameters, otherwise it will use the default normalize
/// parameters.
void SetNormalize(const std::vector<float>& mean,
const std::vector<float>& std,
bool is_scale) {
normalize_permute_op_ =
std::make_shared<NormalizeAndPermute>(mean, std, is_scale);
}
/// Get the image info of the last batch, return a list of array
/// {image width, image height, resize width, resize height}
const std::vector<std::array<int, 4>>* GetBatchLayoutImgInfo() {
return &batch_layout_img_info_;
}
/// This function will disable normalize in preprocessing step.
void DisableNormalize() { disable_permute_ = true; }
/// This function will disable hwc2chw in preprocessing step.
void DisablePermute() { disable_normalize_ = true; }
/// Set image_shape for the detection preprocess.
/// This api is usually used when you retrain the model.
/// Generally, you do not need to use it.
void SetLayoutImageShape(const std::vector<int>& image_shape) {
layout_image_shape_ = image_shape;
}
/// Get cls_image_shape for the classification preprocess
std::vector<int> GetLayoutImageShape() const { return layout_image_shape_; }
/// Set static_shape_infer is true or not. When deploy PP-StructureV2
/// on hardware which can not support dynamic input shape very well,
/// like Huawei Ascned, static_shape_infer needs to to be true.
void SetStaticShapeInfer(bool static_shape_infer) {
static_shape_infer_ = static_shape_infer;
}
/// Get static_shape_infer of the recognition preprocess
bool GetStaticShapeInfer() const { return static_shape_infer_; }
private:
bool ResizeLayoutImage(FDMat* img, int resize_w, int resize_h);
// for recording the switch of hwc2chw
bool disable_permute_ = false;
// for recording the switch of normalize
bool disable_normalize_ = false;
std::vector<std::array<int, 4>> batch_layout_img_info_;
std::shared_ptr<Resize> resize_op_;
std::shared_ptr<NormalizeAndPermute> normalize_permute_op_;
std::vector<int> layout_image_shape_ = {3, 800, 608}; // c,h,w
// default true for pp-structurev2-layout model, backbone picodet.
bool static_shape_infer_ = true;
std::array<int, 4> GetLayoutImgInfo(FDMat* img);
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -34,6 +34,8 @@ FASTDEPLOY_DECL void SortBoxes(std::vector<std::array<int, 8>>* boxes);
FASTDEPLOY_DECL std::vector<int> ArgSort(const std::vector<float> &array); FASTDEPLOY_DECL std::vector<int> ArgSort(const std::vector<float> &array);
FASTDEPLOY_DECL std::vector<float> Softmax(std::vector<float> &src);
FASTDEPLOY_DECL std::vector<int> Xyxyxyxy2Xyxy(std::array<int, 8> &box); FASTDEPLOY_DECL std::vector<int> Xyxyxyxy2Xyxy(std::array<int, 8> &box);
FASTDEPLOY_DECL float Dis(std::vector<int> &box1, std::vector<int> &box2); FASTDEPLOY_DECL float Dis(std::vector<int> &box1, std::vector<int> &box2);
@@ -42,7 +44,6 @@ FASTDEPLOY_DECL float Iou(std::vector<int> &box1, std::vector<int> &box2);
FASTDEPLOY_DECL bool ComparisonDis(const std::vector<float> &dis1, FASTDEPLOY_DECL bool ComparisonDis(const std::vector<float> &dis1,
const std::vector<float> &dis2); const std::vector<float> &dis2);
} // namespace ocr } // namespace ocr
} // namespace vision } // namespace vision
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -0,0 +1,48 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
static inline float FastExp(float x) {
union { uint32_t i; float f; } v{};
v.i = (1 << 23) * (1.4426950409 * x + 126.93490512f);
return v.f;
}
std::vector<float> Softmax(std::vector<float> &src) {
int length = src.size();
std::vector<float> dst;
dst.resize(length);
const float alpha = static_cast<float>(
*std::max_element(&src[0], &src[0 + length]));
float denominator{0};
for (int i = 0; i < length; ++i) {
dst[i] = FastExp(src[i] - alpha);
denominator += dst[i];
}
for (int i = 0; i < length; ++i) {
dst[i] /= denominator;
}
return dst;
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -147,7 +147,8 @@ cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
// Visualize DetectionResult with custom labels. // Visualize DetectionResult with custom labels.
cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result, cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
const std::vector<std::string>& labels, const std::vector<std::string>& labels,
float score_threshold, int line_size, float font_size) { float score_threshold, int line_size, float font_size,
std::vector<int> font_color, int font_thickness) {
if (result.boxes.empty()) { if (result.boxes.empty()) {
return im; return im;
} }
@@ -164,6 +165,7 @@ cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
int h = im.rows; int h = im.rows;
int w = im.cols; int w = im.cols;
auto vis_im = im.clone(); auto vis_im = im.clone();
auto font_color_ = cv::Scalar(font_color[0], font_color[1], font_color[2]);
for (size_t i = 0; i < result.rotated_boxes.size(); ++i) { for (size_t i = 0; i < result.rotated_boxes.size(); ++i) {
if (result.scores[i] < score_threshold) { if (result.scores[i] < score_threshold) {
continue; continue;
@@ -195,8 +197,8 @@ cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
} else { } else {
end = cv::Point(static_cast<int>(round(result.rotated_boxes[i][0])), end = cv::Point(static_cast<int>(round(result.rotated_boxes[i][0])),
static_cast<int>(round(result.rotated_boxes[i][1]))); static_cast<int>(round(result.rotated_boxes[i][1])));
cv::putText(vis_im, text, end, font, font_size, cv::putText(vis_im, text, end, font, font_size, font_color_,
cv::Scalar(255, 255, 255), 1); font_thickness);
} }
cv::line(vis_im, start, end, cv::Scalar(255, 255, 255), 3, cv::LINE_AA, cv::line(vis_im, start, end, cv::Scalar(255, 255, 255), 3, cv::LINE_AA,
0); 0);
@@ -239,8 +241,8 @@ cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
origin.y = y1; origin.y = y1;
cv::Rect rect(x1, y1, box_w, box_h); cv::Rect rect(x1, y1, box_w, box_h);
cv::rectangle(vis_im, rect, rect_color, line_size); cv::rectangle(vis_im, rect, rect_color, line_size);
cv::putText(vis_im, text, origin, font, font_size, cv::putText(vis_im, text, origin, font, font_size, font_color_,
cv::Scalar(255, 255, 255), 1); font_thickness);
if (result.contain_masks) { if (result.contain_masks) {
int mask_h = static_cast<int>(result.masks[i].shape[0]); int mask_h = static_cast<int>(result.masks[i].shape[0]);
int mask_w = static_cast<int>(result.masks[i].shape[1]); int mask_w = static_cast<int>(result.masks[i].shape[1]);

View File

@@ -81,13 +81,17 @@ FASTDEPLOY_DECL cv::Mat VisDetection(const cv::Mat& im,
* \param[in] score_threshold threshold for result scores, the bounding box will not be shown if the score is less than score_threshold * \param[in] score_threshold threshold for result scores, the bounding box will not be shown if the score is less than score_threshold
* \param[in] line_size line size for bounding boxes * \param[in] line_size line size for bounding boxes
* \param[in] font_size font size for text * \param[in] font_size font size for text
* \param[in] font_color font color for bounding text
* \param[in] font_thickness font thickness for text
* \return cv::Mat type stores the visualized results * \return cv::Mat type stores the visualized results
*/ */
FASTDEPLOY_DECL cv::Mat VisDetection(const cv::Mat& im, FASTDEPLOY_DECL cv::Mat VisDetection(const cv::Mat& im,
const DetectionResult& result, const DetectionResult& result,
const std::vector<std::string>& labels, const std::vector<std::string>& labels,
float score_threshold = 0.0, float score_threshold = 0.0,
int line_size = 1, float font_size = 0.5f); int line_size = 1, float font_size = 0.5f,
std::vector<int> font_color = {255, 255, 255},
int font_thickness = 1);
/** \brief Show the visualized results with custom labels for detection models /** \brief Show the visualized results with custom labels for detection models
* *
* \param[in] im the input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format * \param[in] im the input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format

View File

@@ -19,7 +19,8 @@ void BindVisualize(pybind11::module& m) {
m.def("vis_detection", m.def("vis_detection",
[](pybind11::array& im_data, vision::DetectionResult& result, [](pybind11::array& im_data, vision::DetectionResult& result,
std::vector<std::string>& labels, float score_threshold, std::vector<std::string>& labels, float score_threshold,
int line_size, float font_size) { int line_size, float font_size, std::vector<int> font_color,
int font_thickness) {
auto im = PyArrayToCvMat(im_data); auto im = PyArrayToCvMat(im_data);
cv::Mat vis_im; cv::Mat vis_im;
if (labels.empty()) { if (labels.empty()) {
@@ -27,7 +28,8 @@ void BindVisualize(pybind11::module& m) {
line_size, font_size); line_size, font_size);
} else { } else {
vis_im = vision::VisDetection(im, result, labels, score_threshold, vis_im = vision::VisDetection(im, result, labels, score_threshold,
line_size, font_size); line_size, font_size, font_color,
font_thickness);
} }
FDTensor out; FDTensor out;
vision::Mat(vis_im).ShareWithTensor(&out); vision::Mat(vis_im).ShareWithTensor(&out);

View File

@@ -683,10 +683,11 @@ class StructureV2Table(FastDeployModel):
table_char_dict_path="", table_char_dict_path="",
runtime_option=None, runtime_option=None,
model_format=ModelFormat.PADDLE): model_format=ModelFormat.PADDLE):
"""Load OCR StructureV2Table model provided by PaddleOCR. """Load StructureV2Table model provided by PP-StructureV2.
:param model_file: (str)Path of model file, e.g ./ch_ppocr_mobile_v2.0_cls_infer/model.pdmodel. :param model_file: (str)Path of model file, e.g ./ch_ppocr_mobile_v2.0_cls_infer/model.pdmodel.
:param params_file: (str)Path of parameter file, e.g ./ch_ppocr_mobile_v2.0_cls_infer/model.pdiparams, if the model format is ONNX, this parameter will be ignored. :param params_file: (str)Path of parameter file, e.g ./ch_ppocr_mobile_v2.0_cls_infer/model.pdiparams, if the model format is ONNX, this parameter will be ignored.
:param table_char_dict_path: (str)Path of table_char_dict file, e.g ../ppocr/utils/dict/table_structure_dict_ch.txt
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU. :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU.
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model. :param model_format: (fastdeploy.ModelForamt)Model format of the loaded model.
""" """
@@ -703,8 +704,8 @@ class StructureV2Table(FastDeployModel):
self._runnable = True self._runnable = True
def clone(self): def clone(self):
"""Clone OCR StructureV2Table model object """Clone StructureV2Table model object
:return: a new OCR StructureV2Table model object :return: a new StructureV2Table model object
""" """
class StructureV2TableClone(StructureV2Table): class StructureV2TableClone(StructureV2Table):
@@ -749,6 +750,105 @@ class StructureV2Table(FastDeployModel):
self._model.postprocessor = value self._model.postprocessor = value
class StructureV2LayoutPreprocessor:
def __init__(self):
"""Create a preprocessor for StructureV2Layout Model
"""
self._preprocessor = C.vision.ocr.StructureV2LayoutPreprocessor()
def run(self, input_ims):
"""Preprocess input images for StructureV2Layout Model
:param: input_ims: (list of numpy.ndarray)The input image
:return: list of FDTensor
"""
return self._preprocessor.run(input_ims)
class StructureV2LayoutPostprocessor:
def __init__(self):
"""Create a postprocessor for StructureV2Layout Model
"""
self._postprocessor = C.vision.ocr.StructureV2LayoutPostprocessor()
def run(self, runtime_results):
"""Postprocess the runtime results for StructureV2Layout Model
:param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime
:return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results)
class StructureV2Layout(FastDeployModel):
def __init__(self,
model_file="",
params_file="",
runtime_option=None,
model_format=ModelFormat.PADDLE):
"""Load StructureV2Layout model provided by PP-StructureV2.
:param model_file: (str)Path of model file, e.g ./picodet_lcnet_x1_0_fgd_layout_infer/model.pdmodel.
:param params_file: (str)Path of parameter file, e.g ./picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams, if the model format is ONNX, this parameter will be ignored.
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU.
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model.
"""
super(StructureV2Layout, self).__init__(runtime_option)
if (len(model_file) == 0):
self._model = C.vision.ocr.StructureV2Layout()
self._runnable = False
else:
self._model = C.vision.ocr.StructureV2Layout(
model_file, params_file, self._runtime_option, model_format)
assert self.initialized, "StructureV2Layout model initialize failed."
self._runnable = True
def clone(self):
"""Clone StructureV2Layout model object
:return: a new StructureV2Table model object
"""
class StructureV2LayoutClone(StructureV2Layout):
def __init__(self, model):
self._model = model
clone_model = StructureV2LayoutClone(self._model.clone())
return clone_model
def predict(self, input_image):
"""Predict an input image
:param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
:return: bboxes
"""
if self._runnable:
return self._model.predict(input_image)
return False
def batch_predict(self, images):
"""Predict a batch of input image
:param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
:return: list of bboxes list
"""
if self._runnable:
return self._model.batch_predict(images)
return False
@property
def preprocessor(self):
return self._model.preprocessor
@preprocessor.setter
def preprocessor(self, value):
self._model.preprocessor = value
@property
def postprocessor(self):
return self._model.postprocessor
@postprocessor.setter
def postprocessor(self, value):
self._model.postprocessor = value
class PPOCRv3(FastDeployModel): class PPOCRv3(FastDeployModel):
def __init__(self, det_model=None, cls_model=None, rec_model=None): def __init__(self, det_model=None, cls_model=None, rec_model=None):
"""Consruct a pipeline with text detector, direction classifier and text recognizer models """Consruct a pipeline with text detector, direction classifier and text recognizer models

View File

@@ -23,7 +23,9 @@ def vis_detection(im_data,
labels=[], labels=[],
score_threshold=0.0, score_threshold=0.0,
line_size=1, line_size=1,
font_size=0.5): font_size=0.5,
font_color=[255, 255, 255],
font_thickness=1):
"""Show the visualized results for detection models """Show the visualized results for detection models
:param im_data: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format :param im_data: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
@@ -32,10 +34,13 @@ def vis_detection(im_data,
:param score_threshold: (float) score_threshold threshold for result scores, the bounding box will not be shown if the score is less than score_threshold :param score_threshold: (float) score_threshold threshold for result scores, the bounding box will not be shown if the score is less than score_threshold
:param line_size: (float) line_size line size for bounding boxes :param line_size: (float) line_size line size for bounding boxes
:param font_size: (float) font_size font size for text :param font_size: (float) font_size font size for text
:param font_color: (list of int) font_color for text
:param font_thickness: (int) font_thickness for text
:return: (numpy.ndarray) image with visualized results :return: (numpy.ndarray) image with visualized results
""" """
return C.vision.vis_detection(im_data, det_result, labels, score_threshold, return C.vision.vis_detection(im_data, det_result, labels, score_threshold,
line_size, font_size) line_size, font_size, font_color,
font_thickness)
def vis_perception(im_data, def vis_perception(im_data,