Add PaddleOCRv3 & PaddleOCRv2 Support (#139)

* Add PaddleOCR Support

* Add PaddleOCR Support

* Add PaddleOCRv3 Support

* Add PaddleOCRv3 Support

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Add PaddleOCRv3 Support

* Add PaddleOCRv3 Supports

* Add PaddleOCRv3 Suport

* Fix Rec diff

* Remove useless functions

* Remove useless comments

* Add PaddleOCRv2 Support
This commit is contained in:
yunyaoXYY
2022-08-27 15:09:30 +08:00
committed by GitHub
parent 820a5c5647
commit d96e98cd4d
45 changed files with 8323 additions and 2 deletions

View File

@@ -29,6 +29,13 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
if (!option.enable_log_info) {
config_.DisableGlogInfo();
}
if (!option.delete_pass_names.empty()) {
auto pass_builder = config_.pass_builder();
for (int i = 0; i < option.delete_pass_names.size(); i++) {
FDINFO << "Delete pass : " << option.delete_pass_names[i] << std::endl;
pass_builder->DeletePass(option.delete_pass_names[i]);
}
}
if (option.cpu_thread_num <= 0) {
config_.SetCpuMathLibraryNumThreads(8);
} else {
@@ -105,4 +112,4 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
return true;
}
} // namespace fastdeploy
} // namespace fastdeploy

View File

@@ -40,6 +40,8 @@ struct PaddleBackendOption {
int gpu_mem_init_size = 100;
// gpu device id
int gpu_id = 0;
std::vector<std::string> delete_pass_names = {};
};
// Share memory buffer with paddle_infer::Tensor from fastdeploy::FDTensor

View File

@@ -197,6 +197,9 @@ void RuntimeOption::EnablePaddleMKLDNN() { pd_enable_mkldnn = true; }
void RuntimeOption::DisablePaddleMKLDNN() { pd_enable_mkldnn = false; }
void RuntimeOption::DeletePaddleBackendPass(const std::string& pass_name) {
pd_delete_pass_names.push_back(pass_name);
}
void RuntimeOption::EnablePaddleLogInfo() { pd_enable_log_info = true; }
void RuntimeOption::DisablePaddleLogInfo() { pd_enable_log_info = false; }
@@ -307,6 +310,7 @@ void Runtime::CreatePaddleBackend() {
pd_option.mkldnn_cache_size = option.pd_mkldnn_cache_size;
pd_option.use_gpu = (option.device == Device::GPU) ? true : false;
pd_option.gpu_id = option.device_id;
pd_option.delete_pass_names = option.pd_delete_pass_names;
pd_option.cpu_thread_num = option.cpu_thread_num;
FDASSERT(option.model_format == Frontend::PADDLE,
"PaddleBackend only support model format of Frontend::PADDLE.");

View File

@@ -70,6 +70,8 @@ struct FASTDEPLOY_DECL RuntimeOption {
void EnablePaddleMKLDNN();
// disable mkldnn while use paddle inference in CPU
void DisablePaddleMKLDNN();
// Enable delete in pass
void DeletePaddleBackendPass(const std::string& delete_pass_name);
// enable debug information of paddle backend
void EnablePaddleLogInfo();
@@ -119,6 +121,7 @@ struct FASTDEPLOY_DECL RuntimeOption {
bool pd_enable_mkldnn = true;
bool pd_enable_log_info = false;
int pd_mkldnn_cache_size = 1;
std::vector<std::string> pd_delete_pass_names;
// ======Only for Trt Backend=======
std::map<std::string, std::vector<int32_t>> trt_max_shape;

View File

@@ -35,6 +35,11 @@
#include "fastdeploy/vision/faceid/contrib/partial_fc.h"
#include "fastdeploy/vision/faceid/contrib/vpl.h"
#include "fastdeploy/vision/matting/contrib/modnet.h"
#include "fastdeploy/vision/ocr/ppocr/classifier.h"
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
#include "fastdeploy/vision/ocr/ppocr/ppocr_system_v2.h"
#include "fastdeploy/vision/ocr/ppocr/ppocr_system_v3.h"
#include "fastdeploy/vision/ocr/ppocr/recognizer.h"
#include "fastdeploy/vision/segmentation/ppseg/model.h"
#endif

View File

@@ -302,5 +302,68 @@ std::string MattingResult::Str() {
return out;
}
std::string OCRResult::Str() {
std::string no_result;
if (boxes.size() > 0) {
std::string out;
for (int n = 0; n < boxes.size(); n++) {
out = out + "det boxes: [";
for (int i = 0; i < 4; i++) {
out = out + "[" + std::to_string(boxes[n][i * 2]) + "," +
std::to_string(boxes[n][i * 2 + 1]) + "]";
if (i != 3) {
out = out + ",";
}
}
out = out + "]";
if (rec_scores.size() > 0) {
out = out + "rec text: " + text[n] + " rec scores:" +
std::to_string(rec_scores[n]) + " ";
}
if (cls_label.size() > 0) {
out = out + "cls label: " + std::to_string(cls_label[n]) +
" cls score: " + std::to_string(cls_scores[n]);
}
out = out + "\n";
}
return out;
} else if (boxes.size() == 0 && rec_scores.size() > 0 &&
cls_scores.size() > 0) {
std::string out;
for (int i = 0; i < rec_scores.size(); i++) {
out = out + "rec text: " + text[i] + " rec scores:" +
std::to_string(rec_scores[i]) + " ";
out = out + "cls label: " + std::to_string(cls_label[i]) +
" cls score: " + std::to_string(cls_scores[i]);
out = out + "\n";
}
return out;
} else if (boxes.size() == 0 && rec_scores.size() == 0 &&
cls_scores.size() > 0) {
std::string out;
for (int i = 0; i < cls_scores.size(); i++) {
out = out + "cls label: " + std::to_string(cls_label[i]) +
" cls score: " + std::to_string(cls_scores[i]);
out = out + "\n";
}
return out;
} else if (boxes.size() == 0 && rec_scores.size() > 0 &&
cls_scores.size() == 0) {
std::string out;
for (int i = 0; i < rec_scores.size(); i++) {
out = out + "rec text: " + text[i] + " rec scores:" +
std::to_string(rec_scores[i]) + " ";
out = out + "\n";
}
return out;
}
no_result = no_result + "No Results!";
return no_result;
}
} // namespace vision
} // namespace fastdeploy

View File

@@ -22,6 +22,7 @@ enum FASTDEPLOY_DECL ResultType {
CLASSIFY,
DETECTION,
SEGMENTATION,
OCR,
FACE_DETECTION,
FACE_RECOGNITION,
MATTING
@@ -59,6 +60,20 @@ struct FASTDEPLOY_DECL DetectionResult : public BaseResult {
std::string Str();
};
struct FASTDEPLOY_DECL OCRResult : public BaseResult {
std::vector<std::array<int, 8>> boxes;
std::vector<std::string> text;
std::vector<float> rec_scores;
std::vector<float> cls_scores;
std::vector<int32_t> cls_label;
ResultType type = ResultType::OCR;
std::string Str();
};
struct FASTDEPLOY_DECL FaceDetectionResult : public BaseResult {
// box: xmin, ymin, xmax, ymax
std::vector<std::array<float, 4>> boxes;

View File

@@ -0,0 +1,29 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindPPOCRModel(pybind11::module& m);
void BindPPOCRSystemv3(pybind11::module& m);
void BindPPOCRSystemv2(pybind11::module& m);
void BindOcr(pybind11::module& m) {
auto ocr_module = m.def_submodule("ocr", "Module to deploy OCR models");
BindPPOCRModel(ocr_module);
BindPPOCRSystemv3(ocr_module);
BindPPOCRSystemv2(ocr_module);
}
} // namespace fastdeploy

View File

@@ -0,0 +1,148 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/classifier.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
Classifier::Classifier() {}
Classifier::Classifier(const std::string& model_file,
const std::string& params_file,
const RuntimeOption& custom_option,
const Frontend& model_format) {
if (model_format == Frontend::ONNX) {
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
} else {
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
valid_gpu_backends = {Backend::PDINFER, Backend::TRT, Backend::ORT};
}
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
// Init
bool Classifier::Initialize() {
// pre&post process parameters
cls_thresh = 0.9;
cls_image_shape = {3, 48, 192};
cls_batch_num = 1;
mean = {0.485f, 0.456f, 0.406f};
scale = {0.5f, 0.5f, 0.5f};
is_scale = true;
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
void OcrClassifierResizeImage(Mat* mat,
const std::vector<int>& rec_image_shape) {
int imgC = rec_image_shape[0];
int imgH = rec_image_shape[1];
int imgW = rec_image_shape[2];
float ratio = float(mat->Width()) / float(mat->Height());
int resize_w;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
Resize::Run(mat, resize_w, imgH);
std::vector<float> value = {0, 0, 0};
if (resize_w < imgW) {
Pad::Run(mat, 0, 0, 0, imgW - resize_w, value);
}
}
//预处理
bool Classifier::Preprocess(Mat* mat, FDTensor* output) {
// 1. cls resizes
// 2. normalize
// 3. batch_permute
OcrClassifierResizeImage(mat, cls_image_shape);
Normalize::Run(mat, mean, scale, true);
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1);
return true;
}
//后处理
bool Classifier::Postprocess(FDTensor& infer_result, int& cls_labels,
float& cls_scores) {
std::vector<int64_t> output_shape = infer_result.shape;
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
float* out_data = static_cast<float*>(infer_result.Data());
int label = std::distance(
&out_data[0], std::max_element(&out_data[0], &out_data[output_shape[1]]));
float score =
float(*std::max_element(&out_data[0], &out_data[output_shape[1]]));
cls_labels = label;
cls_scores = score;
return true;
}
//预测
bool Classifier::Predict(cv::Mat* img, int& cls_labels, float& cls_socres) {
Mat mat(*img);
std::vector<FDTensor> input_tensors(1);
if (!Preprocess(&mat, &input_tensors[0])) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
input_tensors[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
return false;
}
if (!Postprocess(output_tensors[0], cls_labels, cls_socres)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}
return true;
}
} // namesapce ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,64 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
class FASTDEPLOY_DECL Classifier : public FastDeployModel {
public:
Classifier();
// 当model_format为ONNX时无需指定params_file
// 当model_format为Paddle时则需同时指定model_file & params_file
Classifier(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::PADDLE);
// 定义模型的名称
std::string ModelName() const { return "ppocr/ocr_cls"; }
// 模型预测接口,即用户调用的接口
virtual bool Predict(cv::Mat* img, int& cls_labels, float& cls_socres);
// pre & post parameters
float cls_thresh;
std::vector<int> cls_image_shape;
int cls_batch_num;
std::vector<float> mean;
std::vector<float> scale;
bool is_scale;
private:
// 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作
bool Initialize();
// 输入图像预处理操作
// FDTensor为预处理后的Tensor数据传给后端进行推理
bool Preprocess(Mat* img, FDTensor* output);
// 后端推理结果后处理,输出给用户
// infer_result 为后端推理后的输出Tensor
bool Postprocess(FDTensor& infer_result, int& cls_labels, float& cls_scores);
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,190 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
DBDetector::DBDetector() {}
DBDetector::DBDetector(const std::string& model_file,
const std::string& params_file,
const RuntimeOption& custom_option,
const Frontend& model_format) {
if (model_format == Frontend::ONNX) {
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
} else {
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
}
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
// Init
bool DBDetector::Initialize() {
// pre&post process parameters
max_side_len = 960;
det_db_thresh = 0.3;
det_db_box_thresh = 0.6;
det_db_unclip_ratio = 1.5;
det_db_score_mode = "slow";
use_dilation = false;
mean = {0.485f, 0.456f, 0.406f};
scale = {0.229f, 0.224f, 0.225f};
is_scale = true;
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
void OcrDetectorResizeImage(Mat* img, int max_size_len, float& ratio_h,
float& ratio_w) {
int w = img->Width();
int h = img->Height();
float ratio = 1.f;
int max_wh = w >= h ? w : h;
if (max_wh > max_size_len) {
if (h > w) {
ratio = float(max_size_len) / float(h);
} else {
ratio = float(max_size_len) / float(w);
}
}
int resize_h = int(float(h) * ratio);
int resize_w = int(float(w) * ratio);
resize_h = std::max(int(std::round(float(resize_h) / 32) * 32), 32);
resize_w = std::max(int(std::round(float(resize_w) / 32) * 32), 32);
Resize::Run(img, resize_w, resize_h);
ratio_h = float(resize_h) / float(h);
ratio_w = float(resize_w) / float(w);
}
//预处理
bool DBDetector::Preprocess(
Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info) {
// Resize
OcrDetectorResizeImage(mat, max_side_len, ratio_h, ratio_w);
// Normalize
Normalize::Run(mat, mean, scale, true);
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};
//-CHW
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1);
return true;
}
//后处理
bool DBDetector::Postprocess(
FDTensor& infer_result, std::vector<std::vector<std::vector<int>>>* boxes,
const std::map<std::string, std::array<float, 2>>& im_info) {
std::vector<int64_t> output_shape = infer_result.shape;
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
int n2 = output_shape[2];
int n3 = output_shape[3];
int n = n2 * n3;
float* out_data = static_cast<float*>(infer_result.Data());
// prepare bitmap
std::vector<float> pred(n, 0.0);
std::vector<unsigned char> cbuf(n, ' ');
for (int i = 0; i < n; i++) {
pred[i] = float(out_data[i]);
cbuf[i] = (unsigned char)((out_data[i]) * 255);
}
cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char*)cbuf.data());
cv::Mat pred_map(n2, n3, CV_32F, (float*)pred.data());
const double threshold = det_db_thresh * 255;
const double maxvalue = 255;
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
if (use_dilation) {
cv::Mat dila_ele =
cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
cv::dilate(bit_map, bit_map, dila_ele);
}
post_processor_.BoxesFromBitmap(pred_map, boxes, bit_map, det_db_box_thresh,
det_db_unclip_ratio, det_db_score_mode);
post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, im_info);
return true;
}
//预测
bool DBDetector::Predict(
cv::Mat* img, std::vector<std::vector<std::vector<int>>>* boxes_result) {
Mat mat(*img);
std::vector<FDTensor> input_tensors(1);
std::map<std::string, std::array<float, 2>> im_info;
// Record the shape of image and the shape of preprocessed image
im_info["input_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};
im_info["output_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
input_tensors[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
return false;
}
if (!Postprocess(output_tensors[0], boxes_result, im_info)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}
return true;
}
} // namesapce ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,76 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
class FASTDEPLOY_DECL DBDetector : public FastDeployModel {
public:
DBDetector();
DBDetector(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::PADDLE);
// 定义模型的名称
std::string ModelName() const { return "ppocr/ocr_det"; }
// 模型预测接口,即用户调用的接口
virtual bool Predict(cv::Mat* im,
std::vector<std::vector<std::vector<int>>>* boxes);
// pre&post process parameters
int max_side_len;
float ratio_h{};
float ratio_w{};
double det_db_thresh;
double det_db_box_thresh;
double det_db_unclip_ratio;
std::string det_db_score_mode;
bool use_dilation;
std::vector<float> mean;
std::vector<float> scale;
bool is_scale;
private:
// 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作
bool Initialize();
// FDTensor为预处理后的Tensor数据传给后端进行推理
// im_info为预处理过程保存的数据在后处理中需要用到
bool Preprocess(Mat* mat, FDTensor* outputs,
std::map<std::string, std::array<float, 2>>* im_info);
// 后端推理结果后处理,输出给用户
bool Postprocess(FDTensor& infer_result,
std::vector<std::vector<std::vector<int>>>* boxes,
const std::map<std::string, std::array<float, 2>>& im_info);
// OCR后处理类
PostProcessor post_processor_;
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,58 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <pybind11/stl.h>
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindPPOCRModel(pybind11::module& m) {
// DBDetector
pybind11::class_<vision::ocr::DBDetector, FastDeployModel>(m, "DBDetector")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def(pybind11::init<>())
.def_readwrite("max_side_len", &vision::ocr::DBDetector::max_side_len)
.def_readwrite("det_db_thresh", &vision::ocr::DBDetector::det_db_thresh)
.def_readwrite("det_db_box_thresh",
&vision::ocr::DBDetector::det_db_box_thresh)
.def_readwrite("det_db_unclip_ratio",
&vision::ocr::DBDetector::det_db_unclip_ratio)
.def_readwrite("det_db_score_mode",
&vision::ocr::DBDetector::det_db_score_mode)
.def_readwrite("use_dilation", &vision::ocr::DBDetector::use_dilation)
.def_readwrite("mean", &vision::ocr::DBDetector::mean)
.def_readwrite("scale", &vision::ocr::DBDetector::scale)
.def_readwrite("is_scale", &vision::ocr::DBDetector::is_scale);
// Classifier
pybind11::class_<vision::ocr::Classifier, FastDeployModel>(m, "Classifier")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def(pybind11::init<>())
.def_readwrite("cls_thresh", &vision::ocr::Classifier::cls_thresh)
.def_readwrite("cls_image_shape",
&vision::ocr::Classifier::cls_image_shape)
.def_readwrite("cls_batch_num", &vision::ocr::Classifier::cls_batch_num);
// Recognizer
pybind11::class_<vision::ocr::Recognizer, FastDeployModel>(m, "Recognizer")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
Frontend>())
.def(pybind11::init<>())
.def_readwrite("rec_img_h", &vision::ocr::Recognizer::rec_img_h)
.def_readwrite("rec_img_w", &vision::ocr::Recognizer::rec_img_w)
.def_readwrite("rec_batch_num", &vision::ocr::Recognizer::rec_batch_num);
}
} // namespace fastdeploy

View File

@@ -0,0 +1,56 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <pybind11/stl.h>
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindPPOCRSystemv3(pybind11::module& m) {
// OCRSys
pybind11::class_<application::ocrsystem::PPOCRSystemv3, FastDeployModel>(
m, "PPOCRSystemv3")
.def(pybind11::init<fastdeploy::vision::ocr::DBDetector*,
fastdeploy::vision::ocr::Classifier*,
fastdeploy::vision::ocr::Recognizer*>())
.def("predict", [](application::ocrsystem::PPOCRSystemv3& self,
pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::OCRResult res;
self.Predict(&mat, &res);
return res;
});
}
void BindPPOCRSystemv2(pybind11::module& m) {
// OCRSys
pybind11::class_<application::ocrsystem::PPOCRSystemv2, FastDeployModel>(
m, "PPOCRSystemv2")
.def(pybind11::init<fastdeploy::vision::ocr::DBDetector*,
fastdeploy::vision::ocr::Classifier*,
fastdeploy::vision::ocr::Recognizer*>())
.def("predict", [](application::ocrsystem::PPOCRSystemv2& self,
pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::OCRResult res;
self.Predict(&mat, &res);
return res;
});
}
} // namespace fastdeploy

View File

@@ -0,0 +1,130 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/ppocr_system_v2.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace application {
namespace ocrsystem {
PPOCRSystemv2::PPOCRSystemv2(fastdeploy::vision::ocr::DBDetector* ocr_det,
fastdeploy::vision::ocr::Classifier* ocr_cls,
fastdeploy::vision::ocr::Recognizer* ocr_rec)
: detector(ocr_det), classifier(ocr_cls), recognizer(ocr_rec) {}
void PPOCRSystemv2::Detect(cv::Mat* img,
fastdeploy::vision::OCRResult* result) {
std::vector<std::vector<std::vector<int>>> boxes;
this->detector->Predict(img, &boxes);
// vector<vector>转array
for (int i = 0; i < boxes.size(); i++) {
std::array<int, 8> new_box;
int k = 0;
for (auto& vec : boxes[i]) {
for (auto& e : vec) {
new_box[k++] = e;
}
}
(result->boxes).push_back(new_box);
}
}
void PPOCRSystemv2::Recognize(cv::Mat* img,
fastdeploy::vision::OCRResult* result) {
std::string rec_texts = "";
float rec_text_scores = 0;
this->recognizer->rec_image_shape[1] =
32; // OCRv2模型此处需要设置为32其他与OCRv3一致
this->recognizer->Predict(img, rec_texts, rec_text_scores);
result->text.push_back(rec_texts);
result->rec_scores.push_back(rec_text_scores);
}
void PPOCRSystemv2::Classify(cv::Mat* img,
fastdeploy::vision::OCRResult* result) {
int cls_label = 0;
float cls_scores = 0;
this->classifier->Predict(img, cls_label, cls_scores);
result->cls_label.push_back(cls_label);
result->cls_scores.push_back(cls_scores);
}
bool PPOCRSystemv2::Predict(cv::Mat* img,
fastdeploy::vision::OCRResult* result) {
if (this->detector->initialized == 0) { //没det
//输入单张“小图片”给分类器
if (this->classifier->initialized != 0) {
this->Classify(img, result);
//摆正单张图像
if ((result->cls_label)[0] % 2 == 1 &&
(result->cls_scores)[0] > this->classifier->cls_thresh) {
cv::rotate(*img, *img, 1);
}
}
//输入单张“小图片”给识别器
if (this->recognizer->initialized != 0) {
this->Recognize(img, result);
}
} else {
//从DET模型开始
//一张图,会输出多个“小图片”,送给后续模型
this->Detect(img, result);
std::cout << "Finish Det Prediction!" << std::endl;
// crop image
std::vector<cv::Mat> img_list;
for (int j = 0; j < (result->boxes).size(); j++) {
cv::Mat crop_img;
crop_img =
fastdeploy::vision::ocr::GetRotateCropImage(*img, (result->boxes)[j]);
img_list.push_back(crop_img);
}
// cls
if (this->classifier->initialized != 0) {
for (int i = 0; i < img_list.size(); i++) {
this->Classify(&img_list[0], result);
}
for (int i = 0; i < img_list.size(); i++) {
if ((result->cls_label)[i] % 2 == 1 &&
(result->cls_scores)[i] > this->classifier->cls_thresh) {
std::cout << "Rotate this image " << std::endl;
cv::rotate(img_list[i], img_list[i], 1);
}
}
std::cout << "Finish Cls Prediction!" << std::endl;
}
// rec
if (this->recognizer->initialized != 0) {
for (int i = 0; i < img_list.size(); i++) {
this->Recognize(&img_list[i], result);
}
std::cout << "Finish Rec Prediction!" << std::endl;
}
}
return true;
};
} // namesapce ocrsystem
} // namespace application
} // namespace fastdeploy

View File

@@ -0,0 +1,52 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/classifier.h"
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
#include "fastdeploy/vision/ocr/ppocr/recognizer.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
namespace fastdeploy {
namespace application {
namespace ocrsystem {
class FASTDEPLOY_DECL PPOCRSystemv2 : public FastDeployModel {
public:
PPOCRSystemv2(fastdeploy::vision::ocr::DBDetector* ocr_det = nullptr,
fastdeploy::vision::ocr::Classifier* ocr_cls = nullptr,
fastdeploy::vision::ocr::Recognizer* ocr_rec = nullptr);
fastdeploy::vision::ocr::DBDetector* detector = nullptr;
fastdeploy::vision::ocr::Classifier* classifier = nullptr;
fastdeploy::vision::ocr::Recognizer* recognizer = nullptr;
bool Predict(cv::Mat* img, fastdeploy::vision::OCRResult* result);
private:
void Detect(cv::Mat* img, fastdeploy::vision::OCRResult* result);
void Recognize(cv::Mat* img, fastdeploy::vision::OCRResult* result);
void Classify(cv::Mat* img, fastdeploy::vision::OCRResult* result);
};
} // namespace ocrsystem
} // namespace application
} // namespace fastdeploy

View File

@@ -0,0 +1,128 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/ppocr_system_v3.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace application {
namespace ocrsystem {
PPOCRSystemv3::PPOCRSystemv3(fastdeploy::vision::ocr::DBDetector* ocr_det,
fastdeploy::vision::ocr::Classifier* ocr_cls,
fastdeploy::vision::ocr::Recognizer* ocr_rec)
: detector(ocr_det), classifier(ocr_cls), recognizer(ocr_rec) {}
void PPOCRSystemv3::Detect(cv::Mat* img,
fastdeploy::vision::OCRResult* result) {
std::vector<std::vector<std::vector<int>>> boxes;
this->detector->Predict(img, &boxes);
// vector<vector>转array
for (int i = 0; i < boxes.size(); i++) {
std::array<int, 8> new_box;
int k = 0;
for (auto& vec : boxes[i]) {
for (auto& e : vec) {
new_box[k++] = e;
}
}
(result->boxes).push_back(new_box);
}
}
void PPOCRSystemv3::Recognize(cv::Mat* img,
fastdeploy::vision::OCRResult* result) {
std::string rec_texts = "";
float rec_text_scores = 0;
this->recognizer->Predict(img, rec_texts, rec_text_scores);
result->text.push_back(rec_texts);
result->rec_scores.push_back(rec_text_scores);
}
void PPOCRSystemv3::Classify(cv::Mat* img,
fastdeploy::vision::OCRResult* result) {
int cls_label = 0;
float cls_scores = 0;
this->classifier->Predict(img, cls_label, cls_scores);
result->cls_label.push_back(cls_label);
result->cls_scores.push_back(cls_scores);
}
bool PPOCRSystemv3::Predict(cv::Mat* img,
fastdeploy::vision::OCRResult* result) {
if (this->detector->initialized == 0) { //没det
//输入单张“小图片”给分类器
if (this->classifier->initialized != 0) {
this->Classify(img, result);
//摆正单张图像
if ((result->cls_label)[0] % 2 == 1 &&
(result->cls_scores)[0] > this->classifier->cls_thresh) {
cv::rotate(*img, *img, 1);
}
}
//输入单张“小图片”给识别器
if (this->recognizer->initialized != 0) {
this->Recognize(img, result);
}
} else {
//从DET模型开始
//一张图,会输出多个“小图片”,送给后续模型
this->Detect(img, result);
std::cout << "Finish Det Prediction!" << std::endl;
// crop image
std::vector<cv::Mat> img_list;
for (int j = 0; j < (result->boxes).size(); j++) {
cv::Mat crop_img;
crop_img =
fastdeploy::vision::ocr::GetRotateCropImage(*img, (result->boxes)[j]);
img_list.push_back(crop_img);
}
// cls
if (this->classifier->initialized != 0) {
for (int i = 0; i < img_list.size(); i++) {
this->Classify(&img_list[0], result);
}
for (int i = 0; i < img_list.size(); i++) {
if ((result->cls_label)[i] % 2 == 1 &&
(result->cls_scores)[i] > this->classifier->cls_thresh) {
std::cout << "Rotate this image " << std::endl;
cv::rotate(img_list[i], img_list[i], 1);
}
}
std::cout << "Finish Cls Prediction!" << std::endl;
}
// rec
if (this->recognizer->initialized != 0) {
for (int i = 0; i < img_list.size(); i++) {
this->Recognize(&img_list[i], result);
}
std::cout << "Finish Rec Prediction!" << std::endl;
}
}
return true;
};
} // namesapce ocrsystem
} // namespace application
} // namespace fastdeploy

View File

@@ -0,0 +1,52 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/classifier.h"
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
#include "fastdeploy/vision/ocr/ppocr/recognizer.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
namespace fastdeploy {
namespace application {
namespace ocrsystem {
class FASTDEPLOY_DECL PPOCRSystemv3 : public FastDeployModel {
public:
PPOCRSystemv3(fastdeploy::vision::ocr::DBDetector* ocr_det = nullptr,
fastdeploy::vision::ocr::Classifier* ocr_cls = nullptr,
fastdeploy::vision::ocr::Recognizer* ocr_rec = nullptr);
fastdeploy::vision::ocr::DBDetector* detector = nullptr;
fastdeploy::vision::ocr::Classifier* classifier = nullptr;
fastdeploy::vision::ocr::Recognizer* recognizer = nullptr;
bool Predict(cv::Mat* img, fastdeploy::vision::OCRResult* result);
private:
void Detect(cv::Mat* img, fastdeploy::vision::OCRResult* result);
void Recognize(cv::Mat* img, fastdeploy::vision::OCRResult* result);
void Classify(cv::Mat* img, fastdeploy::vision::OCRResult* result);
};
} // namespace ocrsystem
} // namespace application
} // namespace fastdeploy

View File

@@ -0,0 +1,207 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/recognizer.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
std::vector<std::string> ReadDict(const std::string& path) {
std::ifstream in(path);
std::string line;
std::vector<std::string> m_vec;
if (in) {
while (getline(in, line)) {
m_vec.push_back(line);
}
} else {
std::cout << "no such label file: " << path << ", exit the program..."
<< std::endl;
exit(1);
}
return m_vec;
}
Recognizer::Recognizer() {}
Recognizer::Recognizer(const std::string& model_file,
const std::string& params_file,
const std::string& label_path,
const RuntimeOption& custom_option,
const Frontend& model_format) {
if (model_format == Frontend::ONNX) {
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
} else {
// NOTE:此模型暂不支持paddle-inference-Gpu推理
valid_cpu_backends = {Backend::ORT, Backend::PDINFER};
valid_gpu_backends = {Backend::ORT, Backend::TRT};
}
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
// Recognizer在使用CPU推理并把PaddleInference作为推理后端时,需要删除以下2个pass//
runtime_option.DeletePaddleBackendPass("matmul_transpose_reshape_fuse_pass");
runtime_option.DeletePaddleBackendPass(
"matmul_transpose_reshape_mkldnn_fuse_pass");
initialized = Initialize();
// init label_lsit
label_list = ReadDict(label_path);
label_list.insert(label_list.begin(), "#"); // blank char for ctc
label_list.push_back(" ");
}
// Init
bool Recognizer::Initialize() {
// pre&post process parameters
rec_batch_num = 1;
rec_img_h = 48;
rec_img_w = 320;
rec_image_shape = {3, rec_img_h, rec_img_w};
mean = {0.5f, 0.5f, 0.5f};
scale = {0.5f, 0.5f, 0.5f};
is_scale = true;
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
void OcrRecognizerResizeImage(Mat* mat, float wh_ratio,
const std::vector<int>& rec_image_shape) {
int imgC, imgH, imgW;
imgC = rec_image_shape[0];
imgH = rec_image_shape[1];
imgW = rec_image_shape[2];
imgW = int(imgH * wh_ratio);
float ratio = float(mat->Width()) / float(mat->Height());
int resize_w;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
Resize::Run(mat, resize_w, imgH);
std::vector<float> value = {127, 127, 127};
Pad::Run(mat, 0, 0, 0, int(imgW - mat->Width()), value);
}
//预处理
bool Recognizer::Preprocess(Mat* mat, FDTensor* output,
const std::vector<int>& rec_image_shape) {
int imgH = rec_image_shape[1];
int imgW = rec_image_shape[2];
float wh_ratio = imgW * 1.0 / imgH;
float ori_wh_ratio = mat->Width() * 1.0 / mat->Height();
wh_ratio = std::max(wh_ratio, ori_wh_ratio);
OcrRecognizerResizeImage(mat, wh_ratio, rec_image_shape);
Normalize::Run(mat, mean, scale, true);
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1);
return true;
}
//后处理
bool Recognizer::Postprocess(FDTensor& infer_result, std::string& rec_texts,
float& rec_text_scores) {
std::vector<int64_t> output_shape = infer_result.shape;
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
float* out_data = static_cast<float*>(infer_result.Data());
std::string str_res;
int argmax_idx;
int last_index = 0;
float score = 0.f;
int count = 0;
float max_value = 0.0f;
for (int n = 0; n < output_shape[1]; n++) {
argmax_idx = int(
std::distance(&out_data[n * output_shape[2]],
std::max_element(&out_data[n * output_shape[2]],
&out_data[(n + 1) * output_shape[2]])));
max_value = float(*std::max_element(&out_data[n * output_shape[2]],
&out_data[(n + 1) * output_shape[2]]));
if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
score += max_value;
count += 1;
str_res += label_list[argmax_idx];
}
last_index = argmax_idx;
}
score /= count;
rec_texts = str_res;
rec_text_scores = score;
return true;
}
//预测
bool Recognizer::Predict(cv::Mat* img, std::string& rec_texts,
float& rec_text_scores) {
Mat mat(*img);
std::vector<FDTensor> input_tensors(1);
if (!Preprocess(&mat, &input_tensors[0], rec_image_shape)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
input_tensors[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
return false;
}
if (!Postprocess(output_tensors[0], rec_texts, rec_text_scores)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}
return true;
}
} // namesapce ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,69 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
public:
Recognizer();
// 当model_format为ONNX时无需指定params_file
// 当model_format为Paddle时则需同时指定model_file & params_file
Recognizer(const std::string& model_file, const std::string& params_file = "",
const std::string& label_path = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::PADDLE);
// 定义模型的名称
std::string ModelName() const { return "ppocr/ocr_rec"; }
// 模型预测接口,即用户调用的接口
virtual bool Predict(cv::Mat* img, std::string& rec_texts,
float& rec_text_scores);
// pre & post parameters
std::vector<std::string> label_list;
int rec_batch_num;
int rec_img_h;
int rec_img_w;
std::vector<int> rec_image_shape;
std::vector<float> mean;
std::vector<float> scale;
bool is_scale;
private:
// 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作
bool Initialize();
// 输入图像预处理操作
bool Preprocess(Mat* img, FDTensor* outputs,
const std::vector<int>& rec_image_shape);
// 后端推理结果后处理,输出给用户
// infer_result 为后端推理后的输出Tensor
bool Postprocess(FDTensor& infer_result, std::string& rec_texts,
float& rec_text_scores);
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,425 @@
/*******************************************************************************
* *
* Author : Angus Johnson *
* Version : 6.4.2 *
* Date : 27 February 2017 *
* Website : http://www.angusj.com *
* Copyright : Angus Johnson 2010-2017 *
* *
* License: *
* Use, modification & distribution is subject to Boost Software License Ver 1. *
* http://www.boost.org/LICENSE_1_0.txt *
* *
* Attributions: *
* The code in this library is an extension of Bala Vatti's clipping algorithm: *
* "A generic solution to polygon clipping" *
* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. *
* http://portal.acm.org/citation.cfm?id=129906 *
* *
* Computer graphics and geometric modeling: implementation and algorithms *
* By Max K. Agoston *
* Springer; 1 edition (January 4, 2005) *
* http://books.google.com/books?q=vatti+clipping+agoston *
* *
* See also: *
* "Polygon Offsetting by Computing Winding Numbers" *
* Paper no. DETC2005-85513 pp. 565-575 *
* ASME 2005 International Design Engineering Technical Conferences *
* and Computers and Information in Engineering Conference (IDETC/CIE2005) *
* September 24-28, 2005 , Long Beach, California, USA *
* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf *
* *
*******************************************************************************/
#pragma once
#ifndef clipper_hpp
#define clipper_hpp
#define CLIPPER_VERSION "6.4.2"
// use_int32: When enabled 32bit ints are used instead of 64bit ints. This
// improve performance but coordinate values are limited to the range +/- 46340
//#define use_int32
// use_xyz: adds a Z member to IntPoint. Adds a minor cost to perfomance.
//#define use_xyz
// use_lines: Enables line clipping. Adds a very minor cost to performance.
#define use_lines
// use_deprecated: Enables temporary support for the obsolete functions
//#define use_deprecated
#include <cstdlib>
#include <cstring>
#include <functional>
#include <list>
#include <ostream>
#include <queue>
#include <set>
#include <stdexcept>
#include <vector>
namespace ClipperLib {
enum ClipType { ctIntersection, ctUnion, ctDifference, ctXor };
enum PolyType { ptSubject, ptClip };
// By far the most widely used winding rules for polygon filling are
// EvenOdd & NonZero (GDI, GDI+, XLib, OpenGL, Cairo, AGG, Quartz, SVG, Gr32)
// Others rules include Positive, Negative and ABS_GTR_EQ_TWO (only in OpenGL)
// see http://glprogramming.com/red/chapter11.html
enum PolyFillType { pftEvenOdd, pftNonZero, pftPositive, pftNegative };
#ifdef use_int32
typedef int cInt;
static cInt const loRange = 0x7FFF;
static cInt const hiRange = 0x7FFF;
#else
typedef signed long long cInt;
static cInt const loRange = 0x3FFFFFFF;
static cInt const hiRange = 0x3FFFFFFFFFFFFFFFLL;
typedef signed long long long64; // used by Int128 class
typedef unsigned long long ulong64;
#endif
struct IntPoint {
cInt X;
cInt Y;
#ifdef use_xyz
cInt Z;
IntPoint(cInt x = 0, cInt y = 0, cInt z = 0) : X(x), Y(y), Z(z){};
#else
IntPoint(cInt x = 0, cInt y = 0) : X(x), Y(y){};
#endif
friend inline bool operator==(const IntPoint &a, const IntPoint &b) {
return a.X == b.X && a.Y == b.Y;
}
friend inline bool operator!=(const IntPoint &a, const IntPoint &b) {
return a.X != b.X || a.Y != b.Y;
}
};
//------------------------------------------------------------------------------
typedef std::vector<IntPoint> Path;
typedef std::vector<Path> Paths;
inline Path &operator<<(Path &poly, const IntPoint &p) {
poly.push_back(p);
return poly;
}
inline Paths &operator<<(Paths &polys, const Path &p) {
polys.push_back(p);
return polys;
}
std::ostream &operator<<(std::ostream &s, const IntPoint &p);
std::ostream &operator<<(std::ostream &s, const Path &p);
std::ostream &operator<<(std::ostream &s, const Paths &p);
struct DoublePoint {
double X;
double Y;
DoublePoint(double x = 0, double y = 0) : X(x), Y(y) {}
DoublePoint(IntPoint ip) : X((double)ip.X), Y((double)ip.Y) {}
};
//------------------------------------------------------------------------------
#ifdef use_xyz
typedef void (*ZFillCallback)(IntPoint &e1bot, IntPoint &e1top, IntPoint &e2bot,
IntPoint &e2top, IntPoint &pt);
#endif
enum InitOptions {
ioReverseSolution = 1,
ioStrictlySimple = 2,
ioPreserveCollinear = 4
};
enum JoinType { jtSquare, jtRound, jtMiter };
enum EndType {
etClosedPolygon,
etClosedLine,
etOpenButt,
etOpenSquare,
etOpenRound
};
class PolyNode;
typedef std::vector<PolyNode *> PolyNodes;
class PolyNode {
public:
PolyNode();
virtual ~PolyNode(){};
Path Contour;
PolyNodes Childs;
PolyNode *Parent;
PolyNode *GetNext() const;
bool IsHole() const;
bool IsOpen() const;
int ChildCount() const;
private:
// PolyNode& operator =(PolyNode& other);
unsigned Index; // node index in Parent.Childs
bool m_IsOpen;
JoinType m_jointype;
EndType m_endtype;
PolyNode *GetNextSiblingUp() const;
void AddChild(PolyNode &child);
friend class Clipper; // to access Index
friend class ClipperOffset;
};
class PolyTree : public PolyNode {
public:
~PolyTree() { Clear(); };
PolyNode *GetFirst() const;
void Clear();
int Total() const;
private:
// PolyTree& operator =(PolyTree& other);
PolyNodes AllNodes;
friend class Clipper; // to access AllNodes
};
bool Orientation(const Path &poly);
double Area(const Path &poly);
int PointInPolygon(const IntPoint &pt, const Path &path);
void SimplifyPolygon(const Path &in_poly, Paths &out_polys,
PolyFillType fillType = pftEvenOdd);
void SimplifyPolygons(const Paths &in_polys, Paths &out_polys,
PolyFillType fillType = pftEvenOdd);
void SimplifyPolygons(Paths &polys, PolyFillType fillType = pftEvenOdd);
void CleanPolygon(const Path &in_poly, Path &out_poly, double distance = 1.415);
void CleanPolygon(Path &poly, double distance = 1.415);
void CleanPolygons(const Paths &in_polys, Paths &out_polys,
double distance = 1.415);
void CleanPolygons(Paths &polys, double distance = 1.415);
void MinkowskiSum(const Path &pattern, const Path &path, Paths &solution,
bool pathIsClosed);
void MinkowskiSum(const Path &pattern, const Paths &paths, Paths &solution,
bool pathIsClosed);
void MinkowskiDiff(const Path &poly1, const Path &poly2, Paths &solution);
void PolyTreeToPaths(const PolyTree &polytree, Paths &paths);
void ClosedPathsFromPolyTree(const PolyTree &polytree, Paths &paths);
void OpenPathsFromPolyTree(PolyTree &polytree, Paths &paths);
void ReversePath(Path &p);
void ReversePaths(Paths &p);
struct IntRect {
cInt left;
cInt top;
cInt right;
cInt bottom;
};
// enums that are used internally ...
enum EdgeSide { esLeft = 1, esRight = 2 };
// forward declarations (for stuff used internally) ...
struct TEdge;
struct IntersectNode;
struct LocalMinimum;
struct OutPt;
struct OutRec;
struct Join;
typedef std::vector<OutRec *> PolyOutList;
typedef std::vector<TEdge *> EdgeList;
typedef std::vector<Join *> JoinList;
typedef std::vector<IntersectNode *> IntersectList;
//------------------------------------------------------------------------------
// ClipperBase is the ancestor to the Clipper class. It should not be
// instantiated directly. This class simply abstracts the conversion of sets of
// polygon coordinates into edge objects that are stored in a LocalMinima list.
class ClipperBase {
public:
ClipperBase();
virtual ~ClipperBase();
virtual bool AddPath(const Path &pg, PolyType PolyTyp, bool Closed);
bool AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed);
virtual void Clear();
IntRect GetBounds();
bool PreserveCollinear() { return m_PreserveCollinear; };
void PreserveCollinear(bool value) { m_PreserveCollinear = value; };
protected:
void DisposeLocalMinimaList();
TEdge *AddBoundsToLML(TEdge *e, bool IsClosed);
virtual void Reset();
TEdge *ProcessBound(TEdge *E, bool IsClockwise);
void InsertScanbeam(const cInt Y);
bool PopScanbeam(cInt &Y);
bool LocalMinimaPending();
bool PopLocalMinima(cInt Y, const LocalMinimum *&locMin);
OutRec *CreateOutRec();
void DisposeAllOutRecs();
void DisposeOutRec(PolyOutList::size_type index);
void SwapPositionsInAEL(TEdge *edge1, TEdge *edge2);
void DeleteFromAEL(TEdge *e);
void UpdateEdgeIntoAEL(TEdge *&e);
typedef std::vector<LocalMinimum> MinimaList;
MinimaList::iterator m_CurrentLM;
MinimaList m_MinimaList;
bool m_UseFullRange;
EdgeList m_edges;
bool m_PreserveCollinear;
bool m_HasOpenPaths;
PolyOutList m_PolyOuts;
TEdge *m_ActiveEdges;
typedef std::priority_queue<cInt> ScanbeamList;
ScanbeamList m_Scanbeam;
};
//------------------------------------------------------------------------------
class Clipper : public virtual ClipperBase {
public:
Clipper(int initOptions = 0);
bool Execute(ClipType clipType, Paths &solution,
PolyFillType fillType = pftEvenOdd);
bool Execute(ClipType clipType, Paths &solution, PolyFillType subjFillType,
PolyFillType clipFillType);
bool Execute(ClipType clipType, PolyTree &polytree,
PolyFillType fillType = pftEvenOdd);
bool Execute(ClipType clipType, PolyTree &polytree, PolyFillType subjFillType,
PolyFillType clipFillType);
bool ReverseSolution() { return m_ReverseOutput; };
void ReverseSolution(bool value) { m_ReverseOutput = value; };
bool StrictlySimple() { return m_StrictSimple; };
void StrictlySimple(bool value) { m_StrictSimple = value; };
// set the callback function for z value filling on intersections (otherwise Z
// is 0)
#ifdef use_xyz
void ZFillFunction(ZFillCallback zFillFunc);
#endif
protected:
virtual bool ExecuteInternal();
private:
JoinList m_Joins;
JoinList m_GhostJoins;
IntersectList m_IntersectList;
ClipType m_ClipType;
typedef std::list<cInt> MaximaList;
MaximaList m_Maxima;
TEdge *m_SortedEdges;
bool m_ExecuteLocked;
PolyFillType m_ClipFillType;
PolyFillType m_SubjFillType;
bool m_ReverseOutput;
bool m_UsingPolyTree;
bool m_StrictSimple;
#ifdef use_xyz
ZFillCallback m_ZFill; // custom callback
#endif
void SetWindingCount(TEdge &edge);
bool IsEvenOddFillType(const TEdge &edge) const;
bool IsEvenOddAltFillType(const TEdge &edge) const;
void InsertLocalMinimaIntoAEL(const cInt botY);
void InsertEdgeIntoAEL(TEdge *edge, TEdge *startEdge);
void AddEdgeToSEL(TEdge *edge);
bool PopEdgeFromSEL(TEdge *&edge);
void CopyAELToSEL();
void DeleteFromSEL(TEdge *e);
void SwapPositionsInSEL(TEdge *edge1, TEdge *edge2);
bool IsContributing(const TEdge &edge) const;
bool IsTopHorz(const cInt XPos);
void DoMaxima(TEdge *e);
void ProcessHorizontals();
void ProcessHorizontal(TEdge *horzEdge);
void AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
OutPt *AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
OutRec *GetOutRec(int idx);
void AppendPolygon(TEdge *e1, TEdge *e2);
void IntersectEdges(TEdge *e1, TEdge *e2, IntPoint &pt);
OutPt *AddOutPt(TEdge *e, const IntPoint &pt);
OutPt *GetLastOutPt(TEdge *e);
bool ProcessIntersections(const cInt topY);
void BuildIntersectList(const cInt topY);
void ProcessIntersectList();
void ProcessEdgesAtTopOfScanbeam(const cInt topY);
void BuildResult(Paths &polys);
void BuildResult2(PolyTree &polytree);
void SetHoleState(TEdge *e, OutRec *outrec);
void DisposeIntersectNodes();
bool FixupIntersectionOrder();
void FixupOutPolygon(OutRec &outrec);
void FixupOutPolyline(OutRec &outrec);
bool IsHole(TEdge *e);
bool FindOwnerFromSplitRecs(OutRec &outRec, OutRec *&currOrfl);
void FixHoleLinkage(OutRec &outrec);
void AddJoin(OutPt *op1, OutPt *op2, const IntPoint offPt);
void ClearJoins();
void ClearGhostJoins();
void AddGhostJoin(OutPt *op, const IntPoint offPt);
bool JoinPoints(Join *j, OutRec *outRec1, OutRec *outRec2);
void JoinCommonEdges();
void DoSimplePolygons();
void FixupFirstLefts1(OutRec *OldOutRec, OutRec *NewOutRec);
void FixupFirstLefts2(OutRec *InnerOutRec, OutRec *OuterOutRec);
void FixupFirstLefts3(OutRec *OldOutRec, OutRec *NewOutRec);
#ifdef use_xyz
void SetZ(IntPoint &pt, TEdge &e1, TEdge &e2);
#endif
};
//------------------------------------------------------------------------------
class ClipperOffset {
public:
ClipperOffset(double miterLimit = 2.0, double roundPrecision = 0.25);
~ClipperOffset();
void AddPath(const Path &path, JoinType joinType, EndType endType);
void AddPaths(const Paths &paths, JoinType joinType, EndType endType);
void Execute(Paths &solution, double delta);
void Execute(PolyTree &solution, double delta);
void Clear();
double MiterLimit;
double ArcTolerance;
private:
Paths m_destPolys;
Path m_srcPoly;
Path m_destPoly;
std::vector<DoublePoint> m_normals;
double m_delta, m_sinA, m_sin, m_cos;
double m_miterLim, m_StepsPerRad;
IntPoint m_lowest;
PolyNode m_polyNodes;
void FixOrientations();
void DoOffset(double delta);
void OffsetPoint(int j, int &k, JoinType jointype);
void DoSquare(int j, int k);
void DoMiter(int j, int k, double r);
void DoRound(int j, int k);
};
//------------------------------------------------------------------------------
class clipperException : public std::exception {
public:
clipperException(const char *description) : m_descr(description) {}
virtual ~clipperException() throw() {}
virtual const char *what() const throw() { return m_descr.c_str(); }
private:
std::string m_descr;
};
//------------------------------------------------------------------------------
} // ClipperLib namespace
#endif // clipper_hpp

View File

@@ -0,0 +1,89 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
cv::Mat GetRotateCropImage(const cv::Mat& srcimage,
const std::array<int, 8>& box) {
cv::Mat image;
srcimage.copyTo(image);
std::vector<std::vector<int>> points;
for (int i = 0; i < 4; ++i) {
std::vector<int> tmp;
tmp.push_back(box[2 * i]);
tmp.push_back(box[2 * i + 1]);
points.push_back(tmp);
}
// box转points
int x_collect[4] = {box[0], box[2], box[4], box[6]};
int y_collect[4] = {box[1], box[3], box[5], box[7]};
int left = int(*std::min_element(x_collect, x_collect + 4));
int right = int(*std::max_element(x_collect, x_collect + 4));
int top = int(*std::min_element(y_collect, y_collect + 4));
int bottom = int(*std::max_element(y_collect, y_collect + 4));
//得到rect矩形
cv::Mat img_crop;
image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop);
for (int i = 0; i < points.size(); i++) {
points[i][0] -= left;
points[i][1] -= top;
}
int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) +
pow(points[0][1] - points[1][1], 2)));
int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) +
pow(points[0][1] - points[3][1], 2)));
cv::Point2f pts_std[4];
pts_std[0] = cv::Point2f(0., 0.);
pts_std[1] = cv::Point2f(img_crop_width, 0.);
pts_std[2] = cv::Point2f(img_crop_width, img_crop_height);
pts_std[3] = cv::Point2f(0.f, img_crop_height);
cv::Point2f pointsf[4];
pointsf[0] = cv::Point2f(points[0][0], points[0][1]);
pointsf[1] = cv::Point2f(points[1][0], points[1][1]);
pointsf[2] = cv::Point2f(points[2][0], points[2][1]);
pointsf[3] = cv::Point2f(points[3][0], points[3][1]);
//透视变换矩阵
cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std);
cv::Mat dst_img;
cv::warpPerspective(img_crop, dst_img, M,
cv::Size(img_crop_width, img_crop_height),
cv::BORDER_REPLICATE);
//完成透视变换
if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) {
cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth());
cv::transpose(dst_img, srcCopy);
cv::flip(srcCopy, srcCopy, 0);
return srcCopy;
} else {
return dst_img;
}
}
} // namesoace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,365 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ocr_postprocess_op.h"
#include <map>
#include "clipper.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
//获取轮廓区域
void PostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
float unclip_ratio, float &distance) {
int pts_num = 4;
float area = 0.0f;
float dist = 0.0f;
for (int i = 0; i < pts_num; i++) {
area += box[i][0] * box[(i + 1) % pts_num][1] -
box[i][1] * box[(i + 1) % pts_num][0];
dist += sqrtf((box[i][0] - box[(i + 1) % pts_num][0]) *
(box[i][0] - box[(i + 1) % pts_num][0]) +
(box[i][1] - box[(i + 1) % pts_num][1]) *
(box[i][1] - box[(i + 1) % pts_num][1]));
}
area = fabs(float(area / 2.0));
distance = area * unclip_ratio / dist;
}
cv::RotatedRect PostProcessor::UnClip(std::vector<std::vector<float>> box,
const float &unclip_ratio) {
float distance = 1.0;
GetContourArea(box, unclip_ratio, distance);
ClipperLib::ClipperOffset offset;
ClipperLib::Path p;
p << ClipperLib::IntPoint(int(box[0][0]), int(box[0][1]))
<< ClipperLib::IntPoint(int(box[1][0]), int(box[1][1]))
<< ClipperLib::IntPoint(int(box[2][0]), int(box[2][1]))
<< ClipperLib::IntPoint(int(box[3][0]), int(box[3][1]));
offset.AddPath(p, ClipperLib::jtRound, ClipperLib::etClosedPolygon);
ClipperLib::Paths soln;
offset.Execute(soln, distance);
std::vector<cv::Point2f> points;
for (int j = 0; j < soln.size(); j++) {
for (int i = 0; i < soln[soln.size() - 1].size(); i++) {
points.emplace_back(soln[j][i].X, soln[j][i].Y);
}
}
cv::RotatedRect res;
if (points.size() <= 0) {
res = cv::RotatedRect(cv::Point2f(0, 0), cv::Size2f(1, 1), 0);
} else {
res = cv::minAreaRect(points);
}
return res;
}
//将图像的矩阵转换为float类型的array数组返回
float **PostProcessor::Mat2Vec(cv::Mat mat) {
auto **array = new float *[mat.rows];
for (int i = 0; i < mat.rows; ++i) array[i] = new float[mat.cols];
for (int i = 0; i < mat.rows; ++i) {
for (int j = 0; j < mat.cols; ++j) {
array[i][j] = mat.at<float>(i, j);
}
}
return array;
}
//对点进行顺时针方向的排序(从左到右,从上到下) order points
// clockwise[顺时针方向]
std::vector<std::vector<int>> PostProcessor::OrderPointsClockwise(
std::vector<std::vector<int>> pts) {
std::vector<std::vector<int>> box = pts;
std::sort(box.begin(), box.end(), XsortInt);
std::vector<std::vector<int>> leftmost = {box[0], box[1]};
std::vector<std::vector<int>> rightmost = {box[2], box[3]};
if (leftmost[0][1] > leftmost[1][1]) std::swap(leftmost[0], leftmost[1]);
if (rightmost[0][1] > rightmost[1][1]) std::swap(rightmost[0], rightmost[1]);
std::vector<std::vector<int>> rect = {leftmost[0], rightmost[0], rightmost[1],
leftmost[1]};
return rect;
}
//将图像的矩阵转换为float类型的vector数组返回
std::vector<std::vector<float>> PostProcessor::Mat2Vector(cv::Mat mat) {
std::vector<std::vector<float>> img_vec;
std::vector<float> tmp;
for (int i = 0; i < mat.rows; ++i) {
tmp.clear();
for (int j = 0; j < mat.cols; ++j) {
tmp.push_back(mat.at<float>(i, j));
}
img_vec.push_back(tmp);
}
return img_vec;
}
//判断元素为浮点数float的vector的精度如果a中元素的精度不等于b中元素的精度则返回false
bool PostProcessor::XsortFp32(std::vector<float> a, std::vector<float> b) {
if (a[0] != b[0]) return a[0] < b[0];
return false;
}
bool PostProcessor::XsortInt(std::vector<int> a, std::vector<int> b) {
if (a[0] != b[0]) return a[0] < b[0];
return false;
}
std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
float &ssid) {
ssid = std::max(box.size.width, box.size.height);
cv::Mat points;
cv::boxPoints(box, points);
auto array = Mat2Vector(points);
std::sort(array.begin(), array.end(), XsortFp32);
std::vector<float> idx1 = array[0], idx2 = array[1], idx3 = array[2],
idx4 = array[3];
if (array[3][1] <= array[2][1]) {
idx2 = array[3];
idx3 = array[2];
} else {
idx2 = array[2];
idx3 = array[3];
}
if (array[1][1] <= array[0][1]) {
idx1 = array[1];
idx4 = array[0];
} else {
idx1 = array[0];
idx4 = array[1];
}
array[0] = idx1;
array[1] = idx2;
array[2] = idx3;
array[3] = idx4;
return array;
}
float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
cv::Mat pred) {
int width = pred.cols;
int height = pred.rows;
std::vector<float> box_x;
std::vector<float> box_y;
for (int i = 0; i < contour.size(); ++i) {
box_x.push_back(contour[i].x);
box_y.push_back(contour[i].y);
}
int xmin =
clamp(int(std::floor(*(std::min_element(box_x.begin(), box_x.end())))), 0,
width - 1);
int xmax =
clamp(int(std::ceil(*(std::max_element(box_x.begin(), box_x.end())))), 0,
width - 1);
int ymin =
clamp(int(std::floor(*(std::min_element(box_y.begin(), box_y.end())))), 0,
height - 1);
int ymax =
clamp(int(std::ceil(*(std::max_element(box_y.begin(), box_y.end())))), 0,
height - 1);
cv::Mat mask;
mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);
cv::Point *rook_point = new cv::Point[contour.size()];
for (int i = 0; i < contour.size(); ++i) {
rook_point[i] = cv::Point(int(box_x[i]) - xmin, int(box_y[i]) - ymin);
}
const cv::Point *ppt[1] = {rook_point};
int npt[] = {int(contour.size())};
cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));
cv::Mat croppedImg;
pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
.copyTo(croppedImg);
float score = cv::mean(croppedImg, mask)[0];
delete[] rook_point;
return score;
}
float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
cv::Mat pred) {
auto array = box_array;
int width = pred.cols;
int height = pred.rows;
float box_x[4] = {array[0][0], array[1][0], array[2][0], array[3][0]};
float box_y[4] = {array[0][1], array[1][1], array[2][1], array[3][1]};
int xmin = clamp(int(std::floor(*(std::min_element(box_x, box_x + 4)))), 0,
width - 1);
int xmax = clamp(int(std::ceil(*(std::max_element(box_x, box_x + 4)))), 0,
width - 1);
int ymin = clamp(int(std::floor(*(std::min_element(box_y, box_y + 4)))), 0,
height - 1);
int ymax = clamp(int(std::ceil(*(std::max_element(box_y, box_y + 4)))), 0,
height - 1);
cv::Mat mask;
mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);
cv::Point root_point[4];
root_point[0] = cv::Point(int(array[0][0]) - xmin, int(array[0][1]) - ymin);
root_point[1] = cv::Point(int(array[1][0]) - xmin, int(array[1][1]) - ymin);
root_point[2] = cv::Point(int(array[2][0]) - xmin, int(array[2][1]) - ymin);
root_point[3] = cv::Point(int(array[3][0]) - xmin, int(array[3][1]) - ymin);
const cv::Point *ppt[1] = {root_point};
int npt[] = {4};
cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));
cv::Mat croppedImg;
pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
.copyTo(croppedImg);
auto score = cv::mean(croppedImg, mask)[0];
return score;
}
//这个应该是DB差分二值化相关的内容方法从 Bitmap 图中获取检测框
//涉及到box_thresh低于这个阈值的boxs不予显示和det_db_unclip_ratio文本框扩张的系数关系到文本框的大小
void PostProcessor::BoxesFromBitmap(
const cv::Mat pred, std::vector<std::vector<std::vector<int>>> *boxes,
const cv::Mat bitmap, const float &box_thresh,
const float &det_db_unclip_ratio, const std::string &det_db_score_mode) {
const int min_size = 3;
const int max_candidates = 1000;
int width = bitmap.cols;
int height = bitmap.rows;
std::vector<std::vector<cv::Point>> contours;
std::vector<cv::Vec4i> hierarchy;
cv::findContours(bitmap, contours, hierarchy, cv::RETR_LIST,
cv::CHAIN_APPROX_SIMPLE);
int num_contours =
contours.size() >= max_candidates ? max_candidates : contours.size();
for (int _i = 0; _i < num_contours; _i++) {
if (contours[_i].size() <= 2) {
continue;
}
float ssid;
cv::RotatedRect box = cv::minAreaRect(contours[_i]);
auto array = GetMiniBoxes(box, ssid);
auto box_for_unclip = array;
// end get_mini_box
if (ssid < min_size) {
continue;
}
float score;
if (det_db_score_mode == "slow") /* compute using polygon*/
score = PolygonScoreAcc(contours[_i], pred);
else
score = BoxScoreFast(array, pred);
if (score < box_thresh) continue;
// start for unclip
cv::RotatedRect points = UnClip(box_for_unclip, det_db_unclip_ratio);
if (points.size.height < 1.001 && points.size.width < 1.001) {
continue;
}
// end for unclip
cv::RotatedRect clipbox = points;
auto cliparray = GetMiniBoxes(clipbox, ssid);
if (ssid < min_size + 2) continue;
int dest_width = pred.cols;
int dest_height = pred.rows;
std::vector<std::vector<int>> intcliparray;
for (int num_pt = 0; num_pt < 4; num_pt++) {
std::vector<int> a{
int(clampf(
roundf(cliparray[num_pt][0] / float(width) * float(dest_width)),
0, float(dest_width))),
int(clampf(
roundf(cliparray[num_pt][1] / float(height) * float(dest_height)),
0, float(dest_height)))};
intcliparray.push_back(a);
}
boxes->push_back(intcliparray);
} // end for
// return true;
}
//方法根据识别结果获取目标框位置
void PostProcessor::FilterTagDetRes(
std::vector<std::vector<std::vector<int>>> *boxes, float ratio_h,
float ratio_w, const std::map<std::string, std::array<float, 2>> &im_info) {
int oriimg_h = im_info.at("input_shape")[0];
int oriimg_w = im_info.at("input_shape")[1];
for (int n = 0; n < boxes->size(); n++) {
(*boxes)[n] = OrderPointsClockwise((*boxes)[n]);
for (int m = 0; m < (*boxes)[0].size(); m++) {
(*boxes)[n][m][0] /= ratio_w;
(*boxes)[n][m][1] /= ratio_h;
(*boxes)[n][m][0] = int(_min(_max((*boxes)[n][m][0], 0), oriimg_w - 1));
(*boxes)[n][m][1] = int(_min(_max((*boxes)[n][m][1], 0), oriimg_h - 1));
}
}
//此时已经拿到所有的点. 再进行下面的筛选
for (int n = (*boxes).size() - 1; n >= 0; n--) {
int rect_width, rect_height;
rect_width = int(sqrt(pow((*boxes)[n][0][0] - (*boxes)[n][1][0], 2) +
pow((*boxes)[n][0][1] - (*boxes)[n][1][1], 2)));
rect_height = int(sqrt(pow((*boxes)[n][0][0] - (*boxes)[n][3][0], 2) +
pow((*boxes)[n][0][1] - (*boxes)[n][3][1], 2)));
//原始实现小于4的跳过只return大于4的
// if (rect_width <= 4 || rect_height <= 4) continue;
// root_points.push_back((*boxes)[n]);
//小于4的删除掉. erase配合逆序遍历.
if (rect_width <= 4 || rect_height <= 4) {
boxes->erase(boxes->begin() + n);
}
}
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,93 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iomanip>
#include <iostream>
#include <map>
#include <ostream>
#include <vector>
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <cstring>
#include <fstream>
#include <numeric>
#include "fastdeploy/vision/ocr/ppocr/utils/clipper.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
class PostProcessor {
public:
void GetContourArea(const std::vector<std::vector<float>> &box,
float unclip_ratio, float &distance);
cv::RotatedRect UnClip(std::vector<std::vector<float>> box,
const float &unclip_ratio);
float **Mat2Vec(cv::Mat mat);
std::vector<std::vector<int>> OrderPointsClockwise(
std::vector<std::vector<int>> pts);
std::vector<std::vector<float>> GetMiniBoxes(cv::RotatedRect box,
float &ssid);
float BoxScoreFast(std::vector<std::vector<float>> box_array, cv::Mat pred);
float PolygonScoreAcc(std::vector<cv::Point> contour, cv::Mat pred);
void BoxesFromBitmap(const cv::Mat pred,
std::vector<std::vector<std::vector<int>>> *boxes,
const cv::Mat bitmap, const float &box_thresh,
const float &det_db_unclip_ratio,
const std::string &det_db_score_mode);
void FilterTagDetRes(
std::vector<std::vector<std::vector<int>>> *boxes, float ratio_h,
float ratio_w,
const std::map<std::string, std::array<float, 2>> &im_info);
private:
static bool XsortInt(std::vector<int> a, std::vector<int> b);
static bool XsortFp32(std::vector<float> a, std::vector<float> b);
std::vector<std::vector<float>> Mat2Vector(cv::Mat mat);
inline int _max(int a, int b) { return a >= b ? a : b; }
inline int _min(int a, int b) { return a >= b ? b : a; }
template <class T>
inline T clamp(T x, T min, T max) {
if (x > max) return max;
if (x < min) return min;
return x;
}
inline float clampf(float x, float min, float max) {
if (x > max) return max;
if (x < min) return min;
return x;
}
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,35 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <set>
#include <vector>
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/utils/utils.h"
#include "fastdeploy/vision/common/result.h"
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
namespace fastdeploy {
namespace vision {
namespace ocr {
cv::Mat GetRotateCropImage(const cv::Mat &srcimage,
const std::array<int, 8> &box);
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -22,6 +22,7 @@ void BindSegmentation(pybind11::module& m);
void BindMatting(pybind11::module& m);
void BindFaceDet(pybind11::module& m);
void BindFaceId(pybind11::module& m);
void BindOcr(pybind11::module& m);
#ifdef ENABLE_VISION_VISUALIZE
void BindVisualize(pybind11::module& m);
#endif
@@ -42,6 +43,15 @@ void BindVision(pybind11::module& m) {
.def("__repr__", &vision::DetectionResult::Str)
.def("__str__", &vision::DetectionResult::Str);
pybind11::class_<vision::OCRResult>(m, "OCRResult")
.def(pybind11::init())
.def_readwrite("boxes", &vision::OCRResult::boxes)
.def_readwrite("text", &vision::OCRResult::text)
.def_readwrite("score", &vision::OCRResult::rec_scores)
.def_readwrite("cls_score", &vision::OCRResult::cls_scores)
.def_readwrite("cls_label", &vision::OCRResult::cls_label)
.def("__repr__", &vision::OCRResult::Str)
.def("__str__", &vision::OCRResult::Str);
pybind11::class_<vision::FaceDetectionResult>(m, "FaceDetectionResult")
.def(pybind11::init())
.def_readwrite("boxes", &vision::FaceDetectionResult::boxes)
@@ -81,6 +91,7 @@ void BindVision(pybind11::module& m) {
BindFaceDet(m);
BindFaceId(m);
BindMatting(m);
BindOcr(m);
#ifdef ENABLE_VISION_VISUALIZE
BindVisualize(m);
#endif

View File

@@ -0,0 +1,46 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef ENABLE_VISION_VISUALIZE
#include "fastdeploy/vision/visualize/visualize.h"
#include "opencv2/imgproc/imgproc.hpp"
namespace fastdeploy {
namespace vision {
cv::Mat Visualize::VisOcr(const cv::Mat &im, const OCRResult &ocr_result) {
auto vis_im = im.clone();
for (int n = 0; n < ocr_result.boxes.size(); n++) {
//遍历每一个盒子
cv::Point rook_points[4];
for (int m = 0; m < 4; m++) {
//对每一个盒子 array<float,8>
rook_points[m] = cv::Point(int(ocr_result.boxes[n][m * 2]),
int(ocr_result.boxes[n][m * 2 + 1]));
}
const cv::Point *ppt[1] = {rook_points};
int npt[] = {4};
cv::polylines(vis_im, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0);
}
return vis_im;
}
} // namespace vision
} // namespace fastdeploy
#endif

View File

@@ -35,6 +35,7 @@ class FASTDEPLOY_DECL Visualize {
const SegmentationResult& result);
static cv::Mat VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
bool remove_small_connected_area = false);
static cv::Mat VisOcr(const cv::Mat& srcimg, const OCRResult& ocr_result);
};
} // namespace vision

View File

@@ -48,6 +48,14 @@ void BindVisualize(pybind11::module& m) {
vision::Mat(vis_im).ShareWithTensor(&out);
return TensorToPyArray(out);
})
.def_static("vis_ppocr",
[](pybind11::array& im_data, vision::OCRResult& result) {
auto im = PyArrayToCvMat(im_data);
auto vis_im = vision::Visualize::VisOcr(im, result);
FDTensor out;
vision::Mat(vis_im).ShareWithTensor(&out);
return TensorToPyArray(out);
})
.def_static("vis_matting_alpha",
[](pybind11::array& im_data, vision::MattingResult& result,
bool remove_small_connected_area) {

View File

@@ -0,0 +1,14 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.12)
# 指定下载解压后的fastdeploy库路径
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# 添加FastDeploy依赖头文件
include_directories(${FASTDEPLOY_INCS})
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
# 添加FastDeploy库依赖
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})

View File

@@ -0,0 +1,139 @@
# PPOCRSystemv2 C++部署示例
本目录下提供`infer.cc`快速完成PPOCRSystemv2在CPU/GPU以及GPU上通过TensorRT加速部署的示例。
在部署前,需确认以下两个步骤
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/the%20software%20and%20hardware%20requirements.md)
- 2. 根据开发环境下载预编译部署库和samples代码参考[FastDeploy预编译库](../../../../../docs/quick_start)
以Linux上CPU推理为例在本目录执行如下命令即可完成编译测试
```
mkdir build
cd build
wget https://https://bj.bcebos.com/paddlehub/fastdeploy/cpp/fastdeploy-linux-x64-gpu-0.2.0.tgz
tar xvf fastdeploy-linux-x64-0.2.0.tgz
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-0.2.0
make -j
# 下载模型,图片和label文件
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar
tar xvf ch_PP-OCRv2_det_infer.tar
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
tar xvf ch_ppocr_mobile_v2.0_cls_infer.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar
tar xvf ch_PP-OCRv2_rec_infer.tar
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/doc/imgs/12.jpg
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/ppocr/utils/ppocr_keys_v1.txt
# CPU推理
./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 0
# GPU推理
./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 1
# GPU上TensorRT推理
./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 2
# OCR还支持det/cls/rec三个模型的组合使用例如当我们不想使用cls模型的时候只需要给cls模型路径的位置传入一个空的字符串, 例子如下
./infer_demo ./ch_PP-OCRv2_det_infer "" ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 0
```
运行完成可视化结果如下图所示
<img width="640" src="https://user-images.githubusercontent.com/109218879/185826024-f7593a0c-1bd2-4a60-b76c-15588484fa08.jpg">
## PPOCRSystemv2 C++接口
### PPOCRSystemv2类
```
fastdeploy::application::ocrsystem::PPOCRSystemv2(fastdeploy::vision::ocr::DBDetector* ocr_det = nullptr,
fastdeploy::vision::ocr::Classifier* ocr_cls = nullptr,
fastdeploy::vision::ocr::Recognizer* ocr_rec = nullptr);
```
PPOCRSystemv2 的初始化,由检测,分类和识别模型串联构成
**参数**
> * **DBDetector**(model): OCR中的检测模型
> * **Classifier**(model): OCR中的分类模型
> * **Recognizer**(model): OCR中的识别模型
#### Predict函数
> ```
> std::vector<std::vector<fastdeploy::vision::OCRResult>> ocr_results =
> PPOCRSystemv2.Predict(std::vector<cv::Mat> cv_all_imgs);
>
> ```
>
> 模型预测接口,输入一个可装入多张图片的图片列表,后可输出检测结果。
>
> **参数**
>
> > * **cv_all_imgs**: 输入图像注意需为HWCBGR格式
> > * **ocr_results**: OCR结果,包括由检测模型输出的检测框位置,分类模型输出的方向分类,以及识别模型输出的识别结果, OCRResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/)
## DBDetector C++接口
### DBDetector类
```
fastdeploy::vision::ocr::DBDetector(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::PADDLE);
```
DBDetector模型加载和初始化其中模型为paddle模型格式。
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径当模型格式为ONNX时此参数传入空字符串即可
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(Frontend): 模型格式默认为Paddle格式
### Classifier类与DBDetector类相同
### Recognizer类
```
Recognizer(const std::string& model_file,
const std::string& params_file = "",
const std::string& label_path = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::PADDLE);
```
Recognizer类初始化时,需要在label_path参数中,输入识别模型所需的label文件其他参数均与DBDetector类相同
**参数**
> * **label_path**(str): 识别模型的label文件路径
### 类成员变量
#### DBDetector预处理参数
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
> > * **max_side_len**(int): 检测算法前向时图片长边的最大尺寸当长边超出这个值时会将长边resize到这个大小短边等比例缩放,默认为960
> > * **det_db_thresh**(double): DB模型输出预测图的二值化阈值默认为0.3
> > * **det_db_box_thresh**(double): DB模型输出框的阈值低于此值的预测框会被丢弃默认为0.6
> > * **det_db_unclip_ratio**(double): DB模型输出框扩大的比例默认为1.5
> > * **det_db_score_mode**(string):DB后处理中计算文本框平均得分的方式,默认为slow即求polygon区域的平均分数的方式
> > * **use_dilation**(bool):是否对检测输出的feature map做膨胀处理,默认为Fasle
#### Classifier预处理参数
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
> > * **cls_thresh**(double): 当分类模型输出的得分超过此阈值输入的图片将被翻转默认为0.9
- [模型介绍](../../)
- [Python部署](../python)
- [视觉模型预测结果](../../../../../docs/api/vision_results/)

View File

@@ -0,0 +1,295 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void CpuInfer(const std::string& det_model_dir,
const std::string& cls_model_dir,
const std::string& rec_model_dir,
const std::string& rec_label_file,
const std::string& image_file) {
auto det_model_file = det_model_dir + sep + "inference.pdmodel";
auto det_params_file = det_model_dir + sep + "inference.pdiparams";
auto cls_model_file = cls_model_dir + sep + "inference.pdmodel";
auto cls_params_file = cls_model_dir + sep + "inference.pdiparams";
auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";
auto rec_label = rec_label_file;
fastdeploy::vision::ocr::DBDetector det_model;
fastdeploy::vision::ocr::Classifier cls_model;
fastdeploy::vision::ocr::Recognizer rec_model;
if (!det_model_dir.empty()) {
auto det_option = fastdeploy::RuntimeOption();
det_option.UseCpu();
det_model = fastdeploy::vision::ocr::DBDetector(
det_model_file, det_params_file, det_option);
if (!det_model.Initialized()) {
std::cerr << "Failed to initialize det_model." << std::endl;
return;
}
}
if (!cls_model_dir.empty()) {
auto cls_option = fastdeploy::RuntimeOption();
cls_option.UseCpu();
cls_model = fastdeploy::vision::ocr::Classifier(
cls_model_file, cls_params_file, cls_option);
if (!cls_model.Initialized()) {
std::cerr << "Failed to initialize cls_model." << std::endl;
return;
}
}
if (!rec_model_dir.empty()) {
auto rec_option = fastdeploy::RuntimeOption();
rec_option.UseCpu();
rec_option.UsePaddleBackend(); // OCRv2的rec模型暂不支持ORT后端
rec_model = fastdeploy::vision::ocr::Recognizer(
rec_model_file, rec_params_file, rec_label, rec_option);
if (!rec_model.Initialized()) {
std::cerr << "Failed to initialize rec_model." << std::endl;
return;
}
}
auto ocrv2_app = fastdeploy::application::ocrsystem::PPOCRSystemv2(
&det_model, &cls_model, &rec_model);
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::OCRResult res;
//开始预测
if (!ocrv2_app.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
//输出预测信息
std::cout << res.Str() << std::endl;
//可视化
auto vis_img = fastdeploy::vision::Visualize::VisOcr(im_bak, res);
cv::imwrite("vis_result.jpg", vis_img);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
void GpuInfer(const std::string& det_model_dir,
const std::string& cls_model_dir,
const std::string& rec_model_dir,
const std::string& rec_label_file,
const std::string& image_file) {
auto det_model_file = det_model_dir + sep + "inference.pdmodel";
auto det_params_file = det_model_dir + sep + "inference.pdiparams";
auto cls_model_file = cls_model_dir + sep + "inference.pdmodel";
auto cls_params_file = cls_model_dir + sep + "inference.pdiparams";
auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";
auto rec_label = rec_label_file;
fastdeploy::vision::ocr::DBDetector det_model;
fastdeploy::vision::ocr::Classifier cls_model;
fastdeploy::vision::ocr::Recognizer rec_model;
//准备模型
if (!det_model_dir.empty()) {
auto det_option = fastdeploy::RuntimeOption();
det_option.UseGpu();
det_model = fastdeploy::vision::ocr::DBDetector(
det_model_file, det_params_file, det_option);
if (!det_model.Initialized()) {
std::cerr << "Failed to initialize det_model." << std::endl;
return;
}
}
if (!cls_model_dir.empty()) {
auto cls_option = fastdeploy::RuntimeOption();
cls_option.UseGpu();
cls_model = fastdeploy::vision::ocr::Classifier(
cls_model_file, cls_params_file, cls_option);
if (!cls_model.Initialized()) {
std::cerr << "Failed to initialize cls_model." << std::endl;
return;
}
}
if (!rec_model_dir.empty()) {
auto rec_option = fastdeploy::RuntimeOption();
rec_option.UseGpu();
rec_option
.UsePaddleBackend(); // OCRv2的rec模型暂不支持ORT后端与PaddleInference
// v2.3.2
rec_model = fastdeploy::vision::ocr::Recognizer(
rec_model_file, rec_params_file, rec_label, rec_option);
if (!rec_model.Initialized()) {
std::cerr << "Failed to initialize rec_model." << std::endl;
return;
}
}
auto ocrv2_app = fastdeploy::application::ocrsystem::PPOCRSystemv2(
&det_model, &cls_model, &rec_model);
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::OCRResult res;
//开始预测
if (!ocrv2_app.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
//输出预测信息
std::cout << res.Str() << std::endl;
//可视化
auto vis_img = fastdeploy::vision::Visualize::VisOcr(im_bak, res);
cv::imwrite("vis_result.jpg", vis_img);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
void TrtInfer(const std::string& det_model_dir,
const std::string& cls_model_dir,
const std::string& rec_model_dir,
const std::string& rec_label_file,
const std::string& image_file) {
auto det_model_file = det_model_dir + sep + "inference.pdmodel";
auto det_params_file = det_model_dir + sep + "inference.pdiparams";
auto cls_model_file = cls_model_dir + sep + "inference.pdmodel";
auto cls_params_file = cls_model_dir + sep + "inference.pdiparams";
auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";
auto rec_label = rec_label_file;
fastdeploy::vision::ocr::DBDetector det_model;
fastdeploy::vision::ocr::Classifier cls_model;
fastdeploy::vision::ocr::Recognizer rec_model;
//准备模型
if (!det_model_dir.empty()) {
auto det_option = fastdeploy::RuntimeOption();
det_option.UseGpu();
det_option.UseTrtBackend();
det_option.SetTrtInputShape("x", {1, 3, 50, 50}, {1, 3, 640, 640},
{1, 3, 960, 960});
det_model = fastdeploy::vision::ocr::DBDetector(
det_model_file, det_params_file, det_option);
if (!det_model.Initialized()) {
std::cerr << "Failed to initialize det_model." << std::endl;
return;
}
}
if (!cls_model_dir.empty()) {
auto cls_option = fastdeploy::RuntimeOption();
cls_option.UseGpu();
cls_option.UseTrtBackend();
cls_option.SetTrtInputShape("x", {1, 3, 48, 192});
cls_model = fastdeploy::vision::ocr::Classifier(
cls_model_file, cls_params_file, cls_option);
if (!cls_model.Initialized()) {
std::cerr << "Failed to initialize cls_model." << std::endl;
return;
}
}
if (!rec_model_dir.empty()) {
auto rec_option = fastdeploy::RuntimeOption();
rec_option.UseGpu();
rec_option.UseTrtBackend();
rec_option.SetTrtInputShape("x", {1, 3, 48, 10}, {1, 3, 48, 320},
{1, 3, 48, 2000});
rec_model = fastdeploy::vision::ocr::Recognizer(
rec_model_file, rec_params_file, rec_label, rec_option);
if (!rec_model.Initialized()) {
std::cerr << "Failed to initialize rec_model." << std::endl;
return;
}
}
auto ocrv2_app = fastdeploy::application::ocrsystem::PPOCRSystemv2(
&det_model, &cls_model, &rec_model);
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::OCRResult res;
//开始预测
if (!ocrv2_app.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
//输出预测信息
std::cout << res.Str() << std::endl;
//可视化
auto vis_img = fastdeploy::vision::Visualize::VisOcr(im_bak, res);
cv::imwrite("vis_result.jpg", vis_img);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
int main(int argc, char* argv[]) {
if (argc < 7) {
std::cout << "Usage: infer_demo path/to/det_model path/to/cls_model "
"path/to/rec_model path/to/rec_label_file path/to/image "
"run_option, "
"e.g ./infer_demo ./ch_PP-OCRv2_det_infer "
"./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer "
"./ppocr_keys_v1.txt ./12.jpg 0"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu; 2: run with gpu and use tensorrt backend."
<< std::endl;
return -1;
}
if (std::atoi(argv[6]) == 0) {
CpuInfer(argv[1], argv[2], argv[3], argv[4], argv[5]);
} else if (std::atoi(argv[6]) == 1) {
GpuInfer(argv[1], argv[2], argv[3], argv[4], argv[5]);
} else if (std::atoi(argv[6]) == 2) {
TrtInfer(argv[1], argv[2], argv[3], argv[4], argv[5]);
}
return 0;
}

View File

@@ -0,0 +1,131 @@
# PPOCRSystemv2 Python部署示例
在部署前,需确认以下两个步骤
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/the%20software%20and%20hardware%20requirements.md)
- 2. FastDeploy Python whl包安装参考[FastDeploy Python安装](../../../../../docs/quick_start)
本目录下提供`infer.py`快速完成PPOCRSystemv2在CPU/GPU以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成
```
# 下载模型,图片和label文件
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar
tar xvf ch_PP-OCRv2_det_infer.tar
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
tar xvf ch_ppocr_mobile_v2.0_cls_infer.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar
tar xvf ch_PP-OCRv2_rec_infer.tar
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/doc/imgs/12.jpg
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/ppocr/utils/ppocr_keys_v1.txt
#下载部署示例代码
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd examples/vison/ocr/PPOCRSystemv2/python/
# CPU推理
python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device cpu
# GPU推理
python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu
# GPU上使用TensorRT推理
python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu --det_use_trt True --cls_use_trt True --rec_use_trt True
# OCR还支持det/cls/rec三个模型的组合使用例如当我们不想使用cls模型的时候只需要给--cls_model传入一个空的字符串, 例子如下:
python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model "" --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device cpu
```
运行完成可视化结果如下图所示
<img width="640" src="https://user-images.githubusercontent.com/109218879/185826024-f7593a0c-1bd2-4a60-b76c-15588484fa08.jpg">
## PPOCRSystemv2 Python接口
```
fastdeploy.vision.ocr.PPOCRSystemv2(ocr_det = det_model._model, ocr_cls = cls_model._model, ocr_rec = rec_model._model)
```
PPOCRSystemv2的初始化,输入的参数是检测模型,分类模型和识别模型
**参数**
> * **ocr_det**(model): OCR中的检测模型
> * **ocr_cls**(model): OCR中的分类模型
> * **ocr_rec**(model): OCR中的识别模型
### predict函数
> ```
> result = PPOCRSystemv2.predict(img_list)
> ```
>
> 模型预测接口输入的是一个可包含多个图像的list
>
> **参数**
>
> > * **img_list**(list[np.ndarray]): 输入数据的list每张图片注意需为HWCBGR格式
> > * **result**(float): OCR结果,包括由检测模型输出的检测框位置,分类模型输出的方向分类,以及识别模型输出的识别结果,
> **返回**
>
> > 返回`fastdeploy.vision.OCRResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
## DBDetector Python接口
### DBDetector类
```
fastdeploy.vision.ocr.DBDetector(model_file, params_file, runtime_option=None, model_format=Frontend.PADDLE)
```
DBDetector模型加载和初始化其中模型为paddle模型格式。
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径当模型格式为ONNX时此参数传入空字符串即可
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(Frontend): 模型格式默认为PADDLE格式
### Classifier类与DBDetector类相同
### Recognizer类
```
fastdeploy.vision.ocr.Recognizer(rec_model_file,rec_params_file,rec_label_file,
runtime_option=rec_runtime_option,model_format=Frontend.PADDLE)
```
Recognizer类初始化时,需要在rec_label_file参数中,输入识别模型所需的label文件路径其他参数均与DBDetector类相同
**参数**
> * **label_path**(str): 识别模型的label文件路径
### 类成员变量
#### DBDetector预处理参数
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
> > * **max_side_len**(int): 检测算法前向时图片长边的最大尺寸当长边超出这个值时会将长边resize到这个大小短边等比例缩放,默认为960
> > * **det_db_thresh**(double): DB模型输出预测图的二值化阈值默认为0.3
> > * **det_db_box_thresh**(double): DB模型输出框的阈值低于此值的预测框会被丢弃默认为0.6
> > * **det_db_unclip_ratio**(double): DB模型输出框扩大的比例默认为1.5
> > * **det_db_score_mode**(string):DB后处理中计算文本框平均得分的方式,默认为slow即求polygon区域的平均分数的方式
> > * **use_dilation**(bool):是否对检测输出的feature map做膨胀处理,默认为Fasle
#### Classifier预处理参数
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
> > * **cls_thresh**(double): 当分类模型输出的得分超过此阈值输入的图片将被翻转默认为0.9
## 其它文档
- [YOLOv5 模型介绍](..)
- [YOLOv5 C++部署](../cpp)
- [模型预测结果说明](../../../../../docs/api/vision_results/)

View File

@@ -0,0 +1,146 @@
import fastdeploy as fd
import cv2
import os
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--det_model", required=True, help="Path of Detection model of PPOCR.")
parser.add_argument(
"--cls_model",
required=True,
help="Path of Classification model of PPOCR.")
parser.add_argument(
"--rec_model",
required=True,
help="Path of Recognization model of PPOCR.")
parser.add_argument(
"--rec_label_file",
required=True,
help="Path of Recognization model of PPOCR.")
parser.add_argument(
"--image", type=str, required=True, help="Path of test image file.")
parser.add_argument(
"--device",
type=str,
default='cpu',
help="Type of inference device, support 'cpu' or 'gpu'.")
parser.add_argument(
"--det_use_trt",
type=ast.literal_eval,
default=False,
help="Wether to use tensorrt.")
parser.add_argument(
"--cls_use_trt",
type=ast.literal_eval,
default=False,
help="Wether to use tensorrt.")
parser.add_argument(
"--rec_use_trt",
type=ast.literal_eval,
default=False,
help="Wether to use tensorrt.")
return parser.parse_args()
def build_det_option(args):
option = fd.RuntimeOption()
if args.device.lower() == "gpu":
option.use_gpu()
if args.det_use_trt:
option.use_trt_backend()
#det_max_side_len 默认为960,当用户更改DET模型的max_side_len参数时请将此参数同时更改
det_max_side_len = 960
option.set_trt_input_shape("x", [1, 3, 50, 50], [1, 3, 640, 640],
[1, 3, det_max_side_len, det_max_side_len])
return option
def build_cls_option(args):
option = fd.RuntimeOption()
option.use_paddle_backend()
if args.device.lower() == "gpu":
option.use_gpu()
if args.cls_use_trt:
option.use_trt_backend()
option.set_trt_input_shape("x", [1, 3, 32, 100])
return option
def build_rec_option(args):
option = fd.RuntimeOption()
option.use_paddle_backend()
if args.device.lower() == "gpu":
option.use_gpu()
if args.rec_use_trt:
option.use_trt_backend()
option.set_trt_input_shape("x", [1, 3, 48, 10], [1, 3, 48, 320],
[1, 3, 48, 2000])
return option
args = parse_arguments()
#Det模型
det_model_file = os.path.join(args.det_model, "inference.pdmodel")
det_params_file = os.path.join(args.det_model, "inference.pdiparams")
#Cls模型
cls_model_file = os.path.join(args.cls_model, "inference.pdmodel")
cls_params_file = os.path.join(args.cls_model, "inference.pdiparams")
#Rec模型
rec_model_file = os.path.join(args.rec_model, "inference.pdmodel")
rec_params_file = os.path.join(args.rec_model, "inference.pdiparams")
rec_label_file = args.rec_label_file
#默认
det_model = fd.vision.ocr.DBDetector("")
cls_model = fd.vision.ocr.Classifier()
rec_model = fd.vision.ocr.Recognizer()
#模型初始化
if (len(args.det_model) != 0):
det_runtime_option = build_det_option(args)
det_model = fd.vision.ocr.DBDetector(
det_model_file, det_params_file, runtime_option=det_runtime_option)
if (len(args.cls_model) != 0):
cls_runtime_option = build_cls_option(args)
cls_model = fd.vision.ocr.Classifier(
cls_model_file, cls_params_file, runtime_option=cls_runtime_option)
if (len(args.rec_model) != 0):
rec_runtime_option = build_rec_option(args)
rec_model = fd.vision.ocr.Recognizer(
rec_model_file,
rec_params_file,
rec_label_file,
runtime_option=rec_runtime_option)
ppocrsysv2 = fd.vision.ocr.PPOCRSystemv2(
ocr_det=det_model._model,
ocr_cls=cls_model._model,
ocr_rec=rec_model._model)
# 预测图片准备
im = cv2.imread(args.image)
#预测并打印结果
result = ppocrsysv2.predict(im)
print(result)
# 可视化结果
vis_im = fd.vision.vis_ppocr(im, result)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg")

View File

@@ -0,0 +1,14 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.12)
# 指定下载解压后的fastdeploy库路径
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# 添加FastDeploy依赖头文件
include_directories(${FASTDEPLOY_INCS})
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
# 添加FastDeploy库依赖
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})

View File

@@ -0,0 +1,139 @@
# PPOCRSystemv3 C++部署示例
本目录下提供`infer.cc`快速完成PPOCRSystemv3在CPU/GPU以及GPU上通过TensorRT加速部署的示例。
在部署前,需确认以下两个步骤
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/the%20software%20and%20hardware%20requirements.md)
- 2. 根据开发环境下载预编译部署库和samples代码参考[FastDeploy预编译库](../../../../../docs/quick_start)
以Linux上CPU推理为例在本目录执行如下命令即可完成编译测试
```
mkdir build
cd build
wget https://https://bj.bcebos.com/paddlehub/fastdeploy/cpp/fastdeploy-linux-x64-gpu-0.2.0.tgz
tar xvf fastdeploy-linux-x64-0.2.0.tgz
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-0.2.0
make -j
# 下载模型,图片和label文件
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
tar xvf ch_PP-OCRv3_det_infer.tar
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
tar xvf ch_ppocr_mobile_v2.0_cls_infer.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar
tar xvf ch_PP-OCRv3_rec_infer.tar
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/doc/imgs/12.jpg
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/ppocr/utils/ppocr_keys_v1.txt
# CPU推理
./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 0
# GPU推理
./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 1
# GPU上TensorRT推理
./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 2
# OCR还支持det/cls/rec三个模型的组合使用例如当我们不想使用cls模型的时候只需要给cls模型路径的位置传入一个空的字符串, 例子如下
./infer_demo ./ch_PP-OCRv3_det_infer "" ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 0
```
运行完成可视化结果如下图所示
<img width="640" src="https://user-images.githubusercontent.com/109218879/185826024-f7593a0c-1bd2-4a60-b76c-15588484fa08.jpg">
## PPOCRSystemv3 C++接口
### PPOCRSystemv3类
```
fastdeploy::application::ocrsystem::PPOCRSystemv3(fastdeploy::vision::ocr::DBDetector* ocr_det = nullptr,
fastdeploy::vision::ocr::Classifier* ocr_cls = nullptr,
fastdeploy::vision::ocr::Recognizer* ocr_rec = nullptr);
```
PPOCRSystemv3 的初始化,由检测,分类和识别模型串联构成
**参数**
> * **DBDetector**(model): OCR中的检测模型
> * **Classifier**(model): OCR中的分类模型
> * **Recognizer**(model): OCR中的识别模型
#### Predict函数
> ```
> std::vector<std::vector<fastdeploy::vision::OCRResult>> ocr_results =
> PPOCRSystemv3.Predict(std::vector<cv::Mat> cv_all_imgs);
>
> ```
>
> 模型预测接口,输入一个可装入多张图片的图片列表,后可输出检测结果。
>
> **参数**
>
> > * **cv_all_imgs**: 输入图像注意需为HWCBGR格式
> > * **ocr_results**: OCR结果,包括由检测模型输出的检测框位置,分类模型输出的方向分类,以及识别模型输出的识别结果, OCRResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/)
## DBDetector C++接口
### DBDetector类
```
fastdeploy::vision::ocr::DBDetector(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::PADDLE);
```
DBDetector模型加载和初始化其中模型为paddle模型格式。
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径当模型格式为ONNX时此参数传入空字符串即可
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(Frontend): 模型格式默认为Paddle格式
### Classifier类与DBDetector类相同
### Recognizer类
```
Recognizer(const std::string& model_file,
const std::string& params_file = "",
const std::string& label_path = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::PADDLE);
```
Recognizer类初始化时,需要在label_path参数中,输入识别模型所需的label文件其他参数均与DBDetector类相同
**参数**
> * **label_path**(str): 识别模型的label文件路径
### 类成员变量
#### DBDetector预处理参数
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
> > * **max_side_len**(int): 检测算法前向时图片长边的最大尺寸当长边超出这个值时会将长边resize到这个大小短边等比例缩放,默认为960
> > * **det_db_thresh**(double): DB模型输出预测图的二值化阈值默认为0.3
> > * **det_db_box_thresh**(double): DB模型输出框的阈值低于此值的预测框会被丢弃默认为0.6
> > * **det_db_unclip_ratio**(double): DB模型输出框扩大的比例默认为1.5
> > * **det_db_score_mode**(string):DB后处理中计算文本框平均得分的方式,默认为slow即求polygon区域的平均分数的方式
> > * **use_dilation**(bool):是否对检测输出的feature map做膨胀处理,默认为Fasle
#### Classifier预处理参数
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
> > * **cls_thresh**(double): 当分类模型输出的得分超过此阈值输入的图片将被翻转默认为0.9
- [模型介绍](../../)
- [Python部署](../python)
- [视觉模型预测结果](../../../../../docs/api/vision_results/)

View File

@@ -0,0 +1,290 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void CpuInfer(const std::string& det_model_dir,
const std::string& cls_model_dir,
const std::string& rec_model_dir,
const std::string& rec_label_file,
const std::string& image_file) {
auto det_model_file = det_model_dir + sep + "inference.pdmodel";
auto det_params_file = det_model_dir + sep + "inference.pdiparams";
auto cls_model_file = cls_model_dir + sep + "inference.pdmodel";
auto cls_params_file = cls_model_dir + sep + "inference.pdiparams";
auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";
auto rec_label = rec_label_file;
fastdeploy::vision::ocr::DBDetector det_model;
fastdeploy::vision::ocr::Classifier cls_model;
fastdeploy::vision::ocr::Recognizer rec_model;
if (!det_model_dir.empty()) {
auto det_option = fastdeploy::RuntimeOption();
det_option.UseCpu();
det_model = fastdeploy::vision::ocr::DBDetector(
det_model_file, det_params_file, det_option);
if (!det_model.Initialized()) {
std::cerr << "Failed to initialize det_model." << std::endl;
return;
}
}
if (!cls_model_dir.empty()) {
auto cls_option = fastdeploy::RuntimeOption();
cls_option.UseCpu();
cls_model = fastdeploy::vision::ocr::Classifier(
cls_model_file, cls_params_file, cls_option);
if (!cls_model.Initialized()) {
std::cerr << "Failed to initialize cls_model." << std::endl;
return;
}
}
if (!rec_model_dir.empty()) {
auto rec_option = fastdeploy::RuntimeOption();
rec_option.UseCpu();
rec_model = fastdeploy::vision::ocr::Recognizer(
rec_model_file, rec_params_file, rec_label, rec_option);
if (!rec_model.Initialized()) {
std::cerr << "Failed to initialize rec_model." << std::endl;
return;
}
}
auto ocrv3_app = fastdeploy::application::ocrsystem::PPOCRSystemv3(
&det_model, &cls_model, &rec_model);
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::OCRResult res;
//开始预测
if (!ocrv3_app.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
//输出预测信息
std::cout << res.Str() << std::endl;
//可视化
auto vis_img = fastdeploy::vision::Visualize::VisOcr(im_bak, res);
cv::imwrite("vis_result.jpg", vis_img);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
void GpuInfer(const std::string& det_model_dir,
const std::string& cls_model_dir,
const std::string& rec_model_dir,
const std::string& rec_label_file,
const std::string& image_file) {
auto det_model_file = det_model_dir + sep + "inference.pdmodel";
auto det_params_file = det_model_dir + sep + "inference.pdiparams";
auto cls_model_file = cls_model_dir + sep + "inference.pdmodel";
auto cls_params_file = cls_model_dir + sep + "inference.pdiparams";
auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";
auto rec_label = rec_label_file;
fastdeploy::vision::ocr::DBDetector det_model;
fastdeploy::vision::ocr::Classifier cls_model;
fastdeploy::vision::ocr::Recognizer rec_model;
//准备模型
if (!det_model_dir.empty()) {
auto det_option = fastdeploy::RuntimeOption();
det_option.UseGpu();
det_model = fastdeploy::vision::ocr::DBDetector(
det_model_file, det_params_file, det_option);
if (!det_model.Initialized()) {
std::cerr << "Failed to initialize det_model." << std::endl;
return;
}
}
if (!cls_model_dir.empty()) {
auto cls_option = fastdeploy::RuntimeOption();
cls_option.UseGpu();
cls_model = fastdeploy::vision::ocr::Classifier(
cls_model_file, cls_params_file, cls_option);
if (!cls_model.Initialized()) {
std::cerr << "Failed to initialize cls_model." << std::endl;
return;
}
}
if (!rec_model_dir.empty()) {
auto rec_option = fastdeploy::RuntimeOption();
rec_option.UseGpu();
rec_model = fastdeploy::vision::ocr::Recognizer(
rec_model_file, rec_params_file, rec_label, rec_option);
if (!rec_model.Initialized()) {
std::cerr << "Failed to initialize rec_model." << std::endl;
return;
}
}
auto ocrv3_app = fastdeploy::application::ocrsystem::PPOCRSystemv3(
&det_model, &cls_model, &rec_model);
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::OCRResult res;
//开始预测
if (!ocrv3_app.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
//输出预测信息
std::cout << res.Str() << std::endl;
//可视化
auto vis_img = fastdeploy::vision::Visualize::VisOcr(im_bak, res);
cv::imwrite("vis_result.jpg", vis_img);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
void TrtInfer(const std::string& det_model_dir,
const std::string& cls_model_dir,
const std::string& rec_model_dir,
const std::string& rec_label_file,
const std::string& image_file) {
auto det_model_file = det_model_dir + sep + "inference.pdmodel";
auto det_params_file = det_model_dir + sep + "inference.pdiparams";
auto cls_model_file = cls_model_dir + sep + "inference.pdmodel";
auto cls_params_file = cls_model_dir + sep + "inference.pdiparams";
auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";
auto rec_label = rec_label_file;
fastdeploy::vision::ocr::DBDetector det_model;
fastdeploy::vision::ocr::Classifier cls_model;
fastdeploy::vision::ocr::Recognizer rec_model;
//准备模型
if (!det_model_dir.empty()) {
auto det_option = fastdeploy::RuntimeOption();
det_option.UseGpu();
det_option.UseTrtBackend();
det_option.SetTrtInputShape("x", {1, 3, 50, 50}, {1, 3, 640, 640},
{1, 3, 960, 960});
det_model = fastdeploy::vision::ocr::DBDetector(
det_model_file, det_params_file, det_option);
if (!det_model.Initialized()) {
std::cerr << "Failed to initialize det_model." << std::endl;
return;
}
}
if (!cls_model_dir.empty()) {
auto cls_option = fastdeploy::RuntimeOption();
cls_option.UseGpu();
cls_option.UseTrtBackend();
cls_option.SetTrtInputShape("x", {1, 3, 48, 192});
cls_model = fastdeploy::vision::ocr::Classifier(
cls_model_file, cls_params_file, cls_option);
if (!cls_model.Initialized()) {
std::cerr << "Failed to initialize cls_model." << std::endl;
return;
}
}
if (!rec_model_dir.empty()) {
auto rec_option = fastdeploy::RuntimeOption();
rec_option.UseGpu();
rec_option.UseTrtBackend();
rec_option.SetTrtInputShape("x", {1, 3, 48, 10}, {1, 3, 48, 320},
{1, 3, 48, 2000});
rec_model = fastdeploy::vision::ocr::Recognizer(
rec_model_file, rec_params_file, rec_label, rec_option);
if (!rec_model.Initialized()) {
std::cerr << "Failed to initialize rec_model." << std::endl;
return;
}
}
auto ocrv3_app = fastdeploy::application::ocrsystem::PPOCRSystemv3(
&det_model, &cls_model, &rec_model);
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::OCRResult res;
//开始预测
if (!ocrv3_app.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
//输出预测信息
std::cout << res.Str() << std::endl;
//可视化
auto vis_img = fastdeploy::vision::Visualize::VisOcr(im_bak, res);
cv::imwrite("vis_result.jpg", vis_img);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
int main(int argc, char* argv[]) {
if (argc < 7) {
std::cout << "Usage: infer_demo path/to/det_model path/to/cls_model "
"path/to/rec_model path/to/rec_label_file path/to/image "
"run_option, "
"e.g ./infer_demo ./ch_PP-OCRv3_det_infer "
"./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer "
"./ppocr_keys_v1.txt ./12.jpg 0"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu; 2: run with gpu and use tensorrt backend."
<< std::endl;
return -1;
}
if (std::atoi(argv[6]) == 0) {
CpuInfer(argv[1], argv[2], argv[3], argv[4], argv[5]);
} else if (std::atoi(argv[6]) == 1) {
GpuInfer(argv[1], argv[2], argv[3], argv[4], argv[5]);
} else if (std::atoi(argv[6]) == 2) {
TrtInfer(argv[1], argv[2], argv[3], argv[4], argv[5]);
}
return 0;
}

View File

@@ -0,0 +1,131 @@
# PPOCRSystemv3 Python部署示例
在部署前,需确认以下两个步骤
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/the%20software%20and%20hardware%20requirements.md)
- 2. FastDeploy Python whl包安装参考[FastDeploy Python安装](../../../../../docs/quick_start)
本目录下提供`infer.py`快速完成PPOCRSystemv3在CPU/GPU以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成
```
# 下载模型,图片和label文件
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
tar xvf ch_PP-OCRv3_det_infer.tar
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
tar xvf ch_ppocr_mobile_v2.0_cls_infer.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar
tar xvf ch_PP-OCRv3_rec_infer.tar
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/doc/imgs/12.jpg
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/ppocr/utils/ppocr_keys_v1.txt
#下载部署示例代码
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd examples/vison/ocr/PPOCRSystemv3/python/
# CPU推理
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device cpu
# GPU推理
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu
# GPU上使用TensorRT推理
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu --det_use_trt True --cls_use_trt True --rec_use_trt True
# OCR还支持det/cls/rec三个模型的组合使用例如当我们不想使用cls模型的时候只需要给--cls_model传入一个空的字符串, 例子如下:
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model "" --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device cpu
```
运行完成可视化结果如下图所示
<img width="640" src="https://user-images.githubusercontent.com/109218879/185826024-f7593a0c-1bd2-4a60-b76c-15588484fa08.jpg">
## PPOCRSystemv3 Python接口
```
fastdeploy.vision.ocr.PPOCRSystemv3(ocr_det = det_model._model, ocr_cls = cls_model._model, ocr_rec = rec_model._model)
```
PPOCRSystemv3的初始化,输入的参数是检测模型,分类模型和识别模型
**参数**
> * **ocr_det**(model): OCR中的检测模型
> * **ocr_cls**(model): OCR中的分类模型
> * **ocr_rec**(model): OCR中的识别模型
### predict函数
> ```
> result = PPOCRSystemv3.predict(img_list)
> ```
>
> 模型预测接口输入的是一个可包含多个图像的list
>
> **参数**
>
> > * **img_list**(list[np.ndarray]): 输入数据的list每张图片注意需为HWCBGR格式
> > * **result**(float): OCR结果,包括由检测模型输出的检测框位置,分类模型输出的方向分类,以及识别模型输出的识别结果,
> **返回**
>
> > 返回`fastdeploy.vision.OCRResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
## DBDetector Python接口
### DBDetector类
```
fastdeploy.vision.ocr.DBDetector(model_file, params_file, runtime_option=None, model_format=Frontend.PADDLE)
```
DBDetector模型加载和初始化其中模型为paddle模型格式。
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径当模型格式为ONNX时此参数传入空字符串即可
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(Frontend): 模型格式默认为PADDLE格式
### Classifier类与DBDetector类相同
### Recognizer类
```
fastdeploy.vision.ocr.Recognizer(rec_model_file,rec_params_file,rec_label_file,
runtime_option=rec_runtime_option,model_format=Frontend.PADDLE)
```
Recognizer类初始化时,需要在rec_label_file参数中,输入识别模型所需的label文件路径其他参数均与DBDetector类相同
**参数**
> * **label_path**(str): 识别模型的label文件路径
### 类成员变量
#### DBDetector预处理参数
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
> > * **max_side_len**(int): 检测算法前向时图片长边的最大尺寸当长边超出这个值时会将长边resize到这个大小短边等比例缩放,默认为960
> > * **det_db_thresh**(double): DB模型输出预测图的二值化阈值默认为0.3
> > * **det_db_box_thresh**(double): DB模型输出框的阈值低于此值的预测框会被丢弃默认为0.6
> > * **det_db_unclip_ratio**(double): DB模型输出框扩大的比例默认为1.5
> > * **det_db_score_mode**(string):DB后处理中计算文本框平均得分的方式,默认为slow即求polygon区域的平均分数的方式
> > * **use_dilation**(bool):是否对检测输出的feature map做膨胀处理,默认为Fasle
#### Classifier预处理参数
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
> > * **cls_thresh**(double): 当分类模型输出的得分超过此阈值输入的图片将被翻转默认为0.9
## 其它文档
- [YOLOv5 模型介绍](..)
- [YOLOv5 C++部署](../cpp)
- [模型预测结果说明](../../../../../docs/api/vision_results/)

View File

@@ -0,0 +1,145 @@
import fastdeploy as fd
import cv2
import os
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--det_model", required=True, help="Path of Detection model of PPOCR.")
parser.add_argument(
"--cls_model",
required=True,
help="Path of Classification model of PPOCR.")
parser.add_argument(
"--rec_model",
required=True,
help="Path of Recognization model of PPOCR.")
parser.add_argument(
"--rec_label_file",
required=True,
help="Path of Recognization model of PPOCR.")
parser.add_argument(
"--image", type=str, required=True, help="Path of test image file.")
parser.add_argument(
"--device",
type=str,
default='cpu',
help="Type of inference device, support 'cpu' or 'gpu'.")
parser.add_argument(
"--det_use_trt",
type=ast.literal_eval,
default=False,
help="Wether to use tensorrt.")
parser.add_argument(
"--cls_use_trt",
type=ast.literal_eval,
default=False,
help="Wether to use tensorrt.")
parser.add_argument(
"--rec_use_trt",
type=ast.literal_eval,
default=False,
help="Wether to use tensorrt.")
return parser.parse_args()
def build_det_option(args):
option = fd.RuntimeOption()
if args.device.lower() == "gpu":
option.use_gpu()
if args.det_use_trt:
option.use_trt_backend()
#det_max_side_len 默认为960,当用户更改DET模型的max_side_len参数时请将此参数同时更改
det_max_side_len = 960
option.set_trt_input_shape("x", [1, 3, 50, 50], [1, 3, 640, 640],
[1, 3, det_max_side_len, det_max_side_len])
return option
def build_cls_option(args):
option = fd.RuntimeOption()
option.use_paddle_backend()
if args.device.lower() == "gpu":
option.use_gpu()
if args.cls_use_trt:
option.use_trt_backend()
option.set_trt_input_shape("x", [1, 3, 32, 100])
return option
def build_rec_option(args):
option = fd.RuntimeOption()
if args.device.lower() == "gpu":
option.use_gpu()
if args.rec_use_trt:
option.use_trt_backend()
option.set_trt_input_shape("x", [1, 3, 48, 10], [1, 3, 48, 320],
[1, 3, 48, 2000])
return option
args = parse_arguments()
#Det模型
det_model_file = os.path.join(args.det_model, "inference.pdmodel")
det_params_file = os.path.join(args.det_model, "inference.pdiparams")
#Cls模型
cls_model_file = os.path.join(args.cls_model, "inference.pdmodel")
cls_params_file = os.path.join(args.cls_model, "inference.pdiparams")
#Rec模型
rec_model_file = os.path.join(args.rec_model, "inference.pdmodel")
rec_params_file = os.path.join(args.rec_model, "inference.pdiparams")
rec_label_file = args.rec_label_file
#默认
det_model = fd.vision.ocr.DBDetector("")
cls_model = fd.vision.ocr.Classifier()
rec_model = fd.vision.ocr.Recognizer()
#模型初始化
if (len(args.det_model) != 0):
det_runtime_option = build_det_option(args)
det_model = fd.vision.ocr.DBDetector(
det_model_file, det_params_file, runtime_option=det_runtime_option)
if (len(args.cls_model) != 0):
cls_runtime_option = build_cls_option(args)
cls_model = fd.vision.ocr.Classifier(
cls_model_file, cls_params_file, runtime_option=cls_runtime_option)
if (len(args.rec_model) != 0):
rec_runtime_option = build_rec_option(args)
rec_model = fd.vision.ocr.Recognizer(
rec_model_file,
rec_params_file,
rec_label_file,
runtime_option=rec_runtime_option)
ppocrsysv3 = fd.vision.ocr.PPOCRSystemv3(
ocr_det=det_model._model,
ocr_cls=cls_model._model,
ocr_rec=rec_model._model)
# 预测图片准备
im = cv2.imread(args.image)
#预测并打印结果
result = ppocrsysv3.predict(im)
print(result)
# 可视化结果
vis_im = fd.vision.vis_ppocr(im, result)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg")

View File

@@ -0,0 +1,17 @@
# PaddleOCR 模型部署
## 模型版本说明
- [PaddleOCR Release/2.5](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.5)
目前FastDeploy支持如下模型的部署
- [PaddleOCRv3系列模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/models_list.md)
## 准备PaddleOCRv3部署模型
用户在[PP-OCR系列模型列表](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/models_list.md)下载相应的的OCRv3系列推理模型即可.
## 详细部署文档
- [Python部署](python)
- [C++部署](cpp)

View File

@@ -20,6 +20,6 @@ from . import segmentation
from . import matting
from . import facedet
from . import faceid
from . import ocr
from . import evaluation
from .visualize import *

View File

@@ -0,0 +1,20 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .ppocr import PPOCRSystemv3
from .ppocr import PPOCRSystemv2
from .ppocr import DBDetector
from .ppocr import Classifier
from .ppocr import Recognizer

View File

@@ -0,0 +1,234 @@
# # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# # Licensed under the Apache License, Version 2.0 (the "License");
# # you may not use this file except in compliance with the License.
# # You may obtain a copy of the License at
# #
# # http://www.apache.org/licenses/LICENSE-2.0
# #
# # Unless required by applicable law or agreed to in writing, software
# # distributed under the License is distributed on an "AS IS" BASIS,
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# # See the License for the specific language governing permissions and
# # limitations under the License.
from __future__ import absolute_import
import logging
from .... import FastDeployModel, Frontend
from .... import c_lib_wrap as C
class DBDetector(FastDeployModel):
def __init__(self,
model_file="",
params_file="",
runtime_option=None,
model_format=Frontend.PADDLE):
# 调用基函数进行backend_option的初始化
# 初始化后的option保存在self._runtime_option
super(DBDetector, self).__init__(runtime_option)
if (len(model_file) == 0):
self._model = C.vision.ocr.DBDetector()
else:
self._model = C.vision.ocr.DBDetector(
model_file, params_file, self._runtime_option, model_format)
# 通过self.initialized判断整个模型的初始化是否成功
assert self.initialized, "DBDetector initialize failed."
# 一些跟DBDetector模型有关的属性封装
@property
def max_side_len(self):
return self._model.max_side_len
@property
def det_db_thresh(self):
return self._model.det_db_thresh
@property
def det_db_box_thresh(self):
return self._model.det_db_box_thresh
@property
def det_db_unclip_ratio(self):
return self._model.det_db_unclip_ratio
@property
def det_db_score_mode(self):
return self._model.det_db_score_mode
@property
def use_dilation(self):
return self._model.use_dilation
@property
def is_scale(self):
return self._model.max_wh
@max_side_len.setter
def max_side_len(self, value):
assert isinstance(
value, int), "The value to set `max_side_len` must be type of int."
self._model.max_side_len = value
@det_db_thresh.setter
def det_db_thresh(self, value):
assert isinstance(
value,
float), "The value to set `det_db_thresh` must be type of float."
self._model.det_db_thresh = value
@det_db_box_thresh.setter
def det_db_box_thresh(self, value):
assert isinstance(
value, float
), "The value to set `det_db_box_thresh` must be type of float."
self._model.det_db_box_thresh = value
@det_db_unclip_ratio.setter
def det_db_unclip_ratio(self, value):
assert isinstance(
value, float
), "The value to set `det_db_unclip_ratio` must be type of float."
self._model.det_db_unclip_ratio = value
@det_db_score_mode.setter
def det_db_score_mode(self, value):
assert isinstance(
value,
str), "The value to set `det_db_score_mode` must be type of str."
self._model.det_db_score_mode = value
@use_dilation.setter
def use_dilation(self, value):
assert isinstance(
value,
bool), "The value to set `use_dilation` must be type of bool."
self._model.use_dilation = value
@is_scale.setter
def is_scale(self, value):
assert isinstance(
value, bool), "The value to set `is_scale` must be type of bool."
self._model.is_scale = value
class Classifier(FastDeployModel):
def __init__(self,
model_file="",
params_file="",
runtime_option=None,
model_format=Frontend.PADDLE):
# 调用基函数进行backend_option的初始化
# 初始化后的option保存在self._runtime_option
super(Classifier, self).__init__(runtime_option)
if (len(model_file) == 0):
self._model = C.vision.ocr.Classifier()
else:
self._model = C.vision.ocr.Classifier(
model_file, params_file, self._runtime_option, model_format)
# 通过self.initialized判断整个模型的初始化是否成功
assert self.initialized, "Classifier initialize failed."
@property
def cls_thresh(self):
return self._model.cls_thresh
@property
def cls_image_shape(self):
return self._model.cls_image_shape
@property
def cls_batch_num(self):
return self._model.cls_batch_num
@cls_thresh.setter
def cls_thresh(self, value):
assert isinstance(
value,
float), "The value to set `cls_thresh` must be type of float."
self._model.cls_thresh = value
@cls_image_shape.setter
def cls_image_shape(self, value):
assert isinstance(
value, list), "The value to set `cls_thresh` must be type of list."
self._model.cls_image_shape = value
@cls_batch_num.setter
def cls_batch_num(self, value):
assert isinstance(
value,
int), "The value to set `cls_batch_num` must be type of int."
self._model.cls_batch_num = value
class Recognizer(FastDeployModel):
def __init__(self,
model_file="",
params_file="",
label_path="",
runtime_option=None,
model_format=Frontend.PADDLE):
# 调用基函数进行backend_option的初始化
# 初始化后的option保存在self._runtime_option
super(Recognizer, self).__init__(runtime_option)
if (len(model_file) == 0):
self._model = C.vision.ocr.Recognizer()
else:
self._model = C.vision.ocr.Recognizer(
model_file, params_file, label_path, self._runtime_option,
model_format)
# 通过self.initialized判断整个模型的初始化是否成功
assert self.initialized, "Recognizer initialize failed."
@property
def rec_img_h(self):
return self._model.rec_img_h
@property
def rec_img_w(self):
return self._model.rec_img_w
@property
def rec_batch_num(self):
return self._model.rec_batch_num
@rec_img_h.setter
def rec_img_h(self, value):
assert isinstance(
value, int), "The value to set `rec_img_h` must be type of int."
self._model.rec_img_h = value
@rec_img_w.setter
def rec_img_w(self, value):
assert isinstance(
value, int), "The value to set `rec_img_w` must be type of int."
self._model.rec_img_w = value
@rec_batch_num.setter
def rec_batch_num(self, value):
assert isinstance(
value,
int), "The value to set `rec_batch_num` must be type of int."
self._model.rec_batch_num = value
class PPOCRSystemv3(FastDeployModel):
def __init__(self, ocr_det=None, ocr_cls=None, ocr_rec=None):
self._model = C.vision.ocr.PPOCRSystemv3(ocr_det, ocr_cls, ocr_rec)
def predict(self, input_image):
return self._model.predict(input_image)
class PPOCRSystemv2(FastDeployModel):
def __init__(self, ocr_det=None, ocr_cls=None, ocr_rec=None):
self._model = C.vision.ocr.PPOCRSystemv2(ocr_det, ocr_cls, ocr_rec)
def predict(self, input_image):
return self._model.predict(input_image)

View File

@@ -40,3 +40,7 @@ def vis_matting_alpha(im_data,
remove_small_connected_area=False):
return C.vision.Visualize.vis_matting_alpha(im_data, matting_result,
remove_small_connected_area)
def vis_ppocr(im_data, det_result):
return C.vision.Visualize.vis_ppocr(im_data, det_result)