mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Add PaddleOCRv3 & PaddleOCRv2 Support (#139)
* Add PaddleOCR Support * Add PaddleOCR Support * Add PaddleOCRv3 Support * Add PaddleOCRv3 Support * Update README.md * Update README.md * Update README.md * Update README.md * Add PaddleOCRv3 Support * Add PaddleOCRv3 Supports * Add PaddleOCRv3 Suport * Fix Rec diff * Remove useless functions * Remove useless comments * Add PaddleOCRv2 Support
This commit is contained in:
@@ -29,6 +29,13 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
|
||||
if (!option.enable_log_info) {
|
||||
config_.DisableGlogInfo();
|
||||
}
|
||||
if (!option.delete_pass_names.empty()) {
|
||||
auto pass_builder = config_.pass_builder();
|
||||
for (int i = 0; i < option.delete_pass_names.size(); i++) {
|
||||
FDINFO << "Delete pass : " << option.delete_pass_names[i] << std::endl;
|
||||
pass_builder->DeletePass(option.delete_pass_names[i]);
|
||||
}
|
||||
}
|
||||
if (option.cpu_thread_num <= 0) {
|
||||
config_.SetCpuMathLibraryNumThreads(8);
|
||||
} else {
|
||||
|
@@ -40,6 +40,8 @@ struct PaddleBackendOption {
|
||||
int gpu_mem_init_size = 100;
|
||||
// gpu device id
|
||||
int gpu_id = 0;
|
||||
|
||||
std::vector<std::string> delete_pass_names = {};
|
||||
};
|
||||
|
||||
// Share memory buffer with paddle_infer::Tensor from fastdeploy::FDTensor
|
||||
|
@@ -197,6 +197,9 @@ void RuntimeOption::EnablePaddleMKLDNN() { pd_enable_mkldnn = true; }
|
||||
|
||||
void RuntimeOption::DisablePaddleMKLDNN() { pd_enable_mkldnn = false; }
|
||||
|
||||
void RuntimeOption::DeletePaddleBackendPass(const std::string& pass_name) {
|
||||
pd_delete_pass_names.push_back(pass_name);
|
||||
}
|
||||
void RuntimeOption::EnablePaddleLogInfo() { pd_enable_log_info = true; }
|
||||
|
||||
void RuntimeOption::DisablePaddleLogInfo() { pd_enable_log_info = false; }
|
||||
@@ -307,6 +310,7 @@ void Runtime::CreatePaddleBackend() {
|
||||
pd_option.mkldnn_cache_size = option.pd_mkldnn_cache_size;
|
||||
pd_option.use_gpu = (option.device == Device::GPU) ? true : false;
|
||||
pd_option.gpu_id = option.device_id;
|
||||
pd_option.delete_pass_names = option.pd_delete_pass_names;
|
||||
pd_option.cpu_thread_num = option.cpu_thread_num;
|
||||
FDASSERT(option.model_format == Frontend::PADDLE,
|
||||
"PaddleBackend only support model format of Frontend::PADDLE.");
|
||||
|
@@ -70,6 +70,8 @@ struct FASTDEPLOY_DECL RuntimeOption {
|
||||
void EnablePaddleMKLDNN();
|
||||
// disable mkldnn while use paddle inference in CPU
|
||||
void DisablePaddleMKLDNN();
|
||||
// Enable delete in pass
|
||||
void DeletePaddleBackendPass(const std::string& delete_pass_name);
|
||||
|
||||
// enable debug information of paddle backend
|
||||
void EnablePaddleLogInfo();
|
||||
@@ -119,6 +121,7 @@ struct FASTDEPLOY_DECL RuntimeOption {
|
||||
bool pd_enable_mkldnn = true;
|
||||
bool pd_enable_log_info = false;
|
||||
int pd_mkldnn_cache_size = 1;
|
||||
std::vector<std::string> pd_delete_pass_names;
|
||||
|
||||
// ======Only for Trt Backend=======
|
||||
std::map<std::string, std::vector<int32_t>> trt_max_shape;
|
||||
|
@@ -35,6 +35,11 @@
|
||||
#include "fastdeploy/vision/faceid/contrib/partial_fc.h"
|
||||
#include "fastdeploy/vision/faceid/contrib/vpl.h"
|
||||
#include "fastdeploy/vision/matting/contrib/modnet.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/classifier.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/ppocr_system_v2.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/ppocr_system_v3.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/recognizer.h"
|
||||
#include "fastdeploy/vision/segmentation/ppseg/model.h"
|
||||
#endif
|
||||
|
||||
|
@@ -302,5 +302,68 @@ std::string MattingResult::Str() {
|
||||
return out;
|
||||
}
|
||||
|
||||
std::string OCRResult::Str() {
|
||||
std::string no_result;
|
||||
if (boxes.size() > 0) {
|
||||
std::string out;
|
||||
for (int n = 0; n < boxes.size(); n++) {
|
||||
out = out + "det boxes: [";
|
||||
for (int i = 0; i < 4; i++) {
|
||||
out = out + "[" + std::to_string(boxes[n][i * 2]) + "," +
|
||||
std::to_string(boxes[n][i * 2 + 1]) + "]";
|
||||
|
||||
if (i != 3) {
|
||||
out = out + ",";
|
||||
}
|
||||
}
|
||||
out = out + "]";
|
||||
|
||||
if (rec_scores.size() > 0) {
|
||||
out = out + "rec text: " + text[n] + " rec scores:" +
|
||||
std::to_string(rec_scores[n]) + " ";
|
||||
}
|
||||
if (cls_label.size() > 0) {
|
||||
out = out + "cls label: " + std::to_string(cls_label[n]) +
|
||||
" cls score: " + std::to_string(cls_scores[n]);
|
||||
}
|
||||
out = out + "\n";
|
||||
}
|
||||
return out;
|
||||
|
||||
} else if (boxes.size() == 0 && rec_scores.size() > 0 &&
|
||||
cls_scores.size() > 0) {
|
||||
std::string out;
|
||||
for (int i = 0; i < rec_scores.size(); i++) {
|
||||
out = out + "rec text: " + text[i] + " rec scores:" +
|
||||
std::to_string(rec_scores[i]) + " ";
|
||||
out = out + "cls label: " + std::to_string(cls_label[i]) +
|
||||
" cls score: " + std::to_string(cls_scores[i]);
|
||||
out = out + "\n";
|
||||
}
|
||||
return out;
|
||||
} else if (boxes.size() == 0 && rec_scores.size() == 0 &&
|
||||
cls_scores.size() > 0) {
|
||||
std::string out;
|
||||
for (int i = 0; i < cls_scores.size(); i++) {
|
||||
out = out + "cls label: " + std::to_string(cls_label[i]) +
|
||||
" cls score: " + std::to_string(cls_scores[i]);
|
||||
out = out + "\n";
|
||||
}
|
||||
return out;
|
||||
} else if (boxes.size() == 0 && rec_scores.size() > 0 &&
|
||||
cls_scores.size() == 0) {
|
||||
std::string out;
|
||||
for (int i = 0; i < rec_scores.size(); i++) {
|
||||
out = out + "rec text: " + text[i] + " rec scores:" +
|
||||
std::to_string(rec_scores[i]) + " ";
|
||||
out = out + "\n";
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
no_result = no_result + "No Results!";
|
||||
return no_result;
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
||||
|
@@ -22,6 +22,7 @@ enum FASTDEPLOY_DECL ResultType {
|
||||
CLASSIFY,
|
||||
DETECTION,
|
||||
SEGMENTATION,
|
||||
OCR,
|
||||
FACE_DETECTION,
|
||||
FACE_RECOGNITION,
|
||||
MATTING
|
||||
@@ -59,6 +60,20 @@ struct FASTDEPLOY_DECL DetectionResult : public BaseResult {
|
||||
std::string Str();
|
||||
};
|
||||
|
||||
struct FASTDEPLOY_DECL OCRResult : public BaseResult {
|
||||
std::vector<std::array<int, 8>> boxes;
|
||||
|
||||
std::vector<std::string> text;
|
||||
std::vector<float> rec_scores;
|
||||
|
||||
std::vector<float> cls_scores;
|
||||
std::vector<int32_t> cls_label;
|
||||
|
||||
ResultType type = ResultType::OCR;
|
||||
|
||||
std::string Str();
|
||||
};
|
||||
|
||||
struct FASTDEPLOY_DECL FaceDetectionResult : public BaseResult {
|
||||
// box: xmin, ymin, xmax, ymax
|
||||
std::vector<std::array<float, 4>> boxes;
|
||||
|
29
csrc/fastdeploy/vision/ocr/ocr_pybind.cc
Normal file
29
csrc/fastdeploy/vision/ocr/ocr_pybind.cc
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/pybind/main.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
void BindPPOCRModel(pybind11::module& m);
|
||||
void BindPPOCRSystemv3(pybind11::module& m);
|
||||
void BindPPOCRSystemv2(pybind11::module& m);
|
||||
|
||||
void BindOcr(pybind11::module& m) {
|
||||
auto ocr_module = m.def_submodule("ocr", "Module to deploy OCR models");
|
||||
BindPPOCRModel(ocr_module);
|
||||
BindPPOCRSystemv3(ocr_module);
|
||||
BindPPOCRSystemv2(ocr_module);
|
||||
}
|
||||
} // namespace fastdeploy
|
148
csrc/fastdeploy/vision/ocr/ppocr/classifier.cc
Normal file
148
csrc/fastdeploy/vision/ocr/ppocr/classifier.cc
Normal file
@@ -0,0 +1,148 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/ocr/ppocr/classifier.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace ocr {
|
||||
|
||||
Classifier::Classifier() {}
|
||||
Classifier::Classifier(const std::string& model_file,
|
||||
const std::string& params_file,
|
||||
const RuntimeOption& custom_option,
|
||||
const Frontend& model_format) {
|
||||
if (model_format == Frontend::ONNX) {
|
||||
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
|
||||
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
|
||||
} else {
|
||||
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
|
||||
valid_gpu_backends = {Backend::PDINFER, Backend::TRT, Backend::ORT};
|
||||
}
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = model_format;
|
||||
runtime_option.model_file = model_file;
|
||||
runtime_option.params_file = params_file;
|
||||
|
||||
initialized = Initialize();
|
||||
}
|
||||
|
||||
// Init
|
||||
bool Classifier::Initialize() {
|
||||
// pre&post process parameters
|
||||
cls_thresh = 0.9;
|
||||
cls_image_shape = {3, 48, 192};
|
||||
cls_batch_num = 1;
|
||||
mean = {0.485f, 0.456f, 0.406f};
|
||||
scale = {0.5f, 0.5f, 0.5f};
|
||||
is_scale = true;
|
||||
|
||||
if (!InitRuntime()) {
|
||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void OcrClassifierResizeImage(Mat* mat,
|
||||
const std::vector<int>& rec_image_shape) {
|
||||
int imgC = rec_image_shape[0];
|
||||
int imgH = rec_image_shape[1];
|
||||
int imgW = rec_image_shape[2];
|
||||
|
||||
float ratio = float(mat->Width()) / float(mat->Height());
|
||||
|
||||
int resize_w;
|
||||
if (ceilf(imgH * ratio) > imgW)
|
||||
resize_w = imgW;
|
||||
else
|
||||
resize_w = int(ceilf(imgH * ratio));
|
||||
|
||||
Resize::Run(mat, resize_w, imgH);
|
||||
|
||||
std::vector<float> value = {0, 0, 0};
|
||||
if (resize_w < imgW) {
|
||||
Pad::Run(mat, 0, 0, 0, imgW - resize_w, value);
|
||||
}
|
||||
}
|
||||
|
||||
//预处理
|
||||
bool Classifier::Preprocess(Mat* mat, FDTensor* output) {
|
||||
// 1. cls resizes
|
||||
// 2. normalize
|
||||
// 3. batch_permute
|
||||
OcrClassifierResizeImage(mat, cls_image_shape);
|
||||
|
||||
Normalize::Run(mat, mean, scale, true);
|
||||
|
||||
HWC2CHW::Run(mat);
|
||||
Cast::Run(mat, "float");
|
||||
|
||||
mat->ShareWithTensor(output);
|
||||
output->shape.insert(output->shape.begin(), 1);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
//后处理
|
||||
bool Classifier::Postprocess(FDTensor& infer_result, int& cls_labels,
|
||||
float& cls_scores) {
|
||||
std::vector<int64_t> output_shape = infer_result.shape;
|
||||
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
|
||||
|
||||
float* out_data = static_cast<float*>(infer_result.Data());
|
||||
|
||||
int label = std::distance(
|
||||
&out_data[0], std::max_element(&out_data[0], &out_data[output_shape[1]]));
|
||||
|
||||
float score =
|
||||
float(*std::max_element(&out_data[0], &out_data[output_shape[1]]));
|
||||
|
||||
cls_labels = label;
|
||||
cls_scores = score;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
//预测
|
||||
bool Classifier::Predict(cv::Mat* img, int& cls_labels, float& cls_socres) {
|
||||
Mat mat(*img);
|
||||
std::vector<FDTensor> input_tensors(1);
|
||||
|
||||
if (!Preprocess(&mat, &input_tensors[0])) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||
std::vector<FDTensor> output_tensors;
|
||||
if (!Infer(input_tensors, &output_tensors)) {
|
||||
FDERROR << "Failed to inference." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!Postprocess(output_tensors[0], cls_labels, cls_socres)) {
|
||||
FDERROR << "Failed to post process." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namesapce ocr
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
64
csrc/fastdeploy/vision/ocr/ppocr/classifier.h
Normal file
64
csrc/fastdeploy/vision/ocr/ppocr/classifier.h
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/fastdeploy_model.h"
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace ocr {
|
||||
|
||||
class FASTDEPLOY_DECL Classifier : public FastDeployModel {
|
||||
public:
|
||||
Classifier();
|
||||
// 当model_format为ONNX时,无需指定params_file
|
||||
// 当model_format为Paddle时,则需同时指定model_file & params_file
|
||||
Classifier(const std::string& model_file, const std::string& params_file = "",
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const Frontend& model_format = Frontend::PADDLE);
|
||||
|
||||
// 定义模型的名称
|
||||
std::string ModelName() const { return "ppocr/ocr_cls"; }
|
||||
|
||||
// 模型预测接口,即用户调用的接口
|
||||
virtual bool Predict(cv::Mat* img, int& cls_labels, float& cls_socres);
|
||||
|
||||
// pre & post parameters
|
||||
float cls_thresh;
|
||||
std::vector<int> cls_image_shape;
|
||||
int cls_batch_num;
|
||||
|
||||
std::vector<float> mean;
|
||||
std::vector<float> scale;
|
||||
bool is_scale;
|
||||
|
||||
private:
|
||||
// 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作
|
||||
bool Initialize();
|
||||
|
||||
// 输入图像预处理操作
|
||||
// FDTensor为预处理后的Tensor数据,传给后端进行推理
|
||||
bool Preprocess(Mat* img, FDTensor* output);
|
||||
|
||||
// 后端推理结果后处理,输出给用户
|
||||
// infer_result 为后端推理后的输出Tensor
|
||||
bool Postprocess(FDTensor& infer_result, int& cls_labels, float& cls_scores);
|
||||
};
|
||||
|
||||
} // namespace ocr
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
190
csrc/fastdeploy/vision/ocr/ppocr/dbdetector.cc
Normal file
190
csrc/fastdeploy/vision/ocr/ppocr/dbdetector.cc
Normal file
@@ -0,0 +1,190 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace ocr {
|
||||
|
||||
DBDetector::DBDetector() {}
|
||||
DBDetector::DBDetector(const std::string& model_file,
|
||||
const std::string& params_file,
|
||||
const RuntimeOption& custom_option,
|
||||
const Frontend& model_format) {
|
||||
if (model_format == Frontend::ONNX) {
|
||||
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
|
||||
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
|
||||
} else {
|
||||
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
|
||||
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
|
||||
}
|
||||
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = model_format;
|
||||
runtime_option.model_file = model_file;
|
||||
runtime_option.params_file = params_file;
|
||||
initialized = Initialize();
|
||||
}
|
||||
|
||||
// Init
|
||||
bool DBDetector::Initialize() {
|
||||
// pre&post process parameters
|
||||
max_side_len = 960;
|
||||
|
||||
det_db_thresh = 0.3;
|
||||
det_db_box_thresh = 0.6;
|
||||
det_db_unclip_ratio = 1.5;
|
||||
det_db_score_mode = "slow";
|
||||
use_dilation = false;
|
||||
|
||||
mean = {0.485f, 0.456f, 0.406f};
|
||||
scale = {0.229f, 0.224f, 0.225f};
|
||||
is_scale = true;
|
||||
|
||||
if (!InitRuntime()) {
|
||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void OcrDetectorResizeImage(Mat* img, int max_size_len, float& ratio_h,
|
||||
float& ratio_w) {
|
||||
int w = img->Width();
|
||||
int h = img->Height();
|
||||
|
||||
float ratio = 1.f;
|
||||
int max_wh = w >= h ? w : h;
|
||||
if (max_wh > max_size_len) {
|
||||
if (h > w) {
|
||||
ratio = float(max_size_len) / float(h);
|
||||
} else {
|
||||
ratio = float(max_size_len) / float(w);
|
||||
}
|
||||
}
|
||||
|
||||
int resize_h = int(float(h) * ratio);
|
||||
int resize_w = int(float(w) * ratio);
|
||||
|
||||
resize_h = std::max(int(std::round(float(resize_h) / 32) * 32), 32);
|
||||
resize_w = std::max(int(std::round(float(resize_w) / 32) * 32), 32);
|
||||
|
||||
Resize::Run(img, resize_w, resize_h);
|
||||
|
||||
ratio_h = float(resize_h) / float(h);
|
||||
ratio_w = float(resize_w) / float(w);
|
||||
}
|
||||
|
||||
//预处理
|
||||
bool DBDetector::Preprocess(
|
||||
Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info) {
|
||||
// Resize
|
||||
OcrDetectorResizeImage(mat, max_side_len, ratio_h, ratio_w);
|
||||
// Normalize
|
||||
Normalize::Run(mat, mean, scale, true);
|
||||
|
||||
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
|
||||
static_cast<float>(mat->Width())};
|
||||
//-CHW
|
||||
HWC2CHW::Run(mat);
|
||||
Cast::Run(mat, "float");
|
||||
|
||||
mat->ShareWithTensor(output);
|
||||
output->shape.insert(output->shape.begin(), 1);
|
||||
return true;
|
||||
}
|
||||
|
||||
//后处理
|
||||
bool DBDetector::Postprocess(
|
||||
FDTensor& infer_result, std::vector<std::vector<std::vector<int>>>* boxes,
|
||||
const std::map<std::string, std::array<float, 2>>& im_info) {
|
||||
std::vector<int64_t> output_shape = infer_result.shape;
|
||||
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
|
||||
int n2 = output_shape[2];
|
||||
int n3 = output_shape[3];
|
||||
int n = n2 * n3;
|
||||
|
||||
float* out_data = static_cast<float*>(infer_result.Data());
|
||||
// prepare bitmap
|
||||
std::vector<float> pred(n, 0.0);
|
||||
std::vector<unsigned char> cbuf(n, ' ');
|
||||
|
||||
for (int i = 0; i < n; i++) {
|
||||
pred[i] = float(out_data[i]);
|
||||
cbuf[i] = (unsigned char)((out_data[i]) * 255);
|
||||
}
|
||||
cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char*)cbuf.data());
|
||||
cv::Mat pred_map(n2, n3, CV_32F, (float*)pred.data());
|
||||
|
||||
const double threshold = det_db_thresh * 255;
|
||||
const double maxvalue = 255;
|
||||
cv::Mat bit_map;
|
||||
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
|
||||
if (use_dilation) {
|
||||
cv::Mat dila_ele =
|
||||
cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
|
||||
cv::dilate(bit_map, bit_map, dila_ele);
|
||||
}
|
||||
|
||||
post_processor_.BoxesFromBitmap(pred_map, boxes, bit_map, det_db_box_thresh,
|
||||
det_db_unclip_ratio, det_db_score_mode);
|
||||
|
||||
post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, im_info);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
//预测
|
||||
bool DBDetector::Predict(
|
||||
cv::Mat* img, std::vector<std::vector<std::vector<int>>>* boxes_result) {
|
||||
Mat mat(*img);
|
||||
|
||||
std::vector<FDTensor> input_tensors(1);
|
||||
|
||||
std::map<std::string, std::array<float, 2>> im_info;
|
||||
|
||||
// Record the shape of image and the shape of preprocessed image
|
||||
im_info["input_shape"] = {static_cast<float>(mat.Height()),
|
||||
static_cast<float>(mat.Width())};
|
||||
im_info["output_shape"] = {static_cast<float>(mat.Height()),
|
||||
static_cast<float>(mat.Width())};
|
||||
|
||||
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||
std::vector<FDTensor> output_tensors;
|
||||
if (!Infer(input_tensors, &output_tensors)) {
|
||||
FDERROR << "Failed to inference." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!Postprocess(output_tensors[0], boxes_result, im_info)) {
|
||||
FDERROR << "Failed to post process." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namesapce ocr
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
76
csrc/fastdeploy/vision/ocr/ppocr/dbdetector.h
Normal file
76
csrc/fastdeploy/vision/ocr/ppocr/dbdetector.h
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/fastdeploy_model.h"
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace ocr {
|
||||
|
||||
class FASTDEPLOY_DECL DBDetector : public FastDeployModel {
|
||||
public:
|
||||
DBDetector();
|
||||
|
||||
DBDetector(const std::string& model_file, const std::string& params_file = "",
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const Frontend& model_format = Frontend::PADDLE);
|
||||
|
||||
// 定义模型的名称
|
||||
std::string ModelName() const { return "ppocr/ocr_det"; }
|
||||
|
||||
// 模型预测接口,即用户调用的接口
|
||||
virtual bool Predict(cv::Mat* im,
|
||||
std::vector<std::vector<std::vector<int>>>* boxes);
|
||||
|
||||
// pre&post process parameters
|
||||
int max_side_len;
|
||||
|
||||
float ratio_h{};
|
||||
float ratio_w{};
|
||||
|
||||
double det_db_thresh;
|
||||
double det_db_box_thresh;
|
||||
double det_db_unclip_ratio;
|
||||
std::string det_db_score_mode;
|
||||
bool use_dilation;
|
||||
|
||||
std::vector<float> mean;
|
||||
std::vector<float> scale;
|
||||
bool is_scale;
|
||||
|
||||
private:
|
||||
// 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作
|
||||
bool Initialize();
|
||||
|
||||
// FDTensor为预处理后的Tensor数据,传给后端进行推理
|
||||
// im_info为预处理过程保存的数据,在后处理中需要用到
|
||||
bool Preprocess(Mat* mat, FDTensor* outputs,
|
||||
std::map<std::string, std::array<float, 2>>* im_info);
|
||||
|
||||
// 后端推理结果后处理,输出给用户
|
||||
bool Postprocess(FDTensor& infer_result,
|
||||
std::vector<std::vector<std::vector<int>>>* boxes,
|
||||
const std::map<std::string, std::array<float, 2>>& im_info);
|
||||
|
||||
// OCR后处理类
|
||||
PostProcessor post_processor_;
|
||||
};
|
||||
|
||||
} // namespace ocr
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
58
csrc/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc
Normal file
58
csrc/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include <pybind11/stl.h>
|
||||
#include "fastdeploy/pybind/main.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
void BindPPOCRModel(pybind11::module& m) {
|
||||
// DBDetector
|
||||
pybind11::class_<vision::ocr::DBDetector, FastDeployModel>(m, "DBDetector")
|
||||
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
||||
.def(pybind11::init<>())
|
||||
|
||||
.def_readwrite("max_side_len", &vision::ocr::DBDetector::max_side_len)
|
||||
.def_readwrite("det_db_thresh", &vision::ocr::DBDetector::det_db_thresh)
|
||||
.def_readwrite("det_db_box_thresh",
|
||||
&vision::ocr::DBDetector::det_db_box_thresh)
|
||||
.def_readwrite("det_db_unclip_ratio",
|
||||
&vision::ocr::DBDetector::det_db_unclip_ratio)
|
||||
.def_readwrite("det_db_score_mode",
|
||||
&vision::ocr::DBDetector::det_db_score_mode)
|
||||
.def_readwrite("use_dilation", &vision::ocr::DBDetector::use_dilation)
|
||||
.def_readwrite("mean", &vision::ocr::DBDetector::mean)
|
||||
.def_readwrite("scale", &vision::ocr::DBDetector::scale)
|
||||
.def_readwrite("is_scale", &vision::ocr::DBDetector::is_scale);
|
||||
|
||||
// Classifier
|
||||
pybind11::class_<vision::ocr::Classifier, FastDeployModel>(m, "Classifier")
|
||||
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
||||
.def(pybind11::init<>())
|
||||
|
||||
.def_readwrite("cls_thresh", &vision::ocr::Classifier::cls_thresh)
|
||||
.def_readwrite("cls_image_shape",
|
||||
&vision::ocr::Classifier::cls_image_shape)
|
||||
.def_readwrite("cls_batch_num", &vision::ocr::Classifier::cls_batch_num);
|
||||
|
||||
// Recognizer
|
||||
pybind11::class_<vision::ocr::Recognizer, FastDeployModel>(m, "Recognizer")
|
||||
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
Frontend>())
|
||||
.def(pybind11::init<>())
|
||||
|
||||
.def_readwrite("rec_img_h", &vision::ocr::Recognizer::rec_img_h)
|
||||
.def_readwrite("rec_img_w", &vision::ocr::Recognizer::rec_img_w)
|
||||
.def_readwrite("rec_batch_num", &vision::ocr::Recognizer::rec_batch_num);
|
||||
}
|
||||
} // namespace fastdeploy
|
56
csrc/fastdeploy/vision/ocr/ppocr/ocrsys_pybind.cc
Normal file
56
csrc/fastdeploy/vision/ocr/ppocr/ocrsys_pybind.cc
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include <pybind11/stl.h>
|
||||
#include "fastdeploy/pybind/main.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
void BindPPOCRSystemv3(pybind11::module& m) {
|
||||
// OCRSys
|
||||
pybind11::class_<application::ocrsystem::PPOCRSystemv3, FastDeployModel>(
|
||||
m, "PPOCRSystemv3")
|
||||
|
||||
.def(pybind11::init<fastdeploy::vision::ocr::DBDetector*,
|
||||
fastdeploy::vision::ocr::Classifier*,
|
||||
fastdeploy::vision::ocr::Recognizer*>())
|
||||
|
||||
.def("predict", [](application::ocrsystem::PPOCRSystemv3& self,
|
||||
pybind11::array& data) {
|
||||
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::OCRResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
}
|
||||
|
||||
void BindPPOCRSystemv2(pybind11::module& m) {
|
||||
// OCRSys
|
||||
pybind11::class_<application::ocrsystem::PPOCRSystemv2, FastDeployModel>(
|
||||
m, "PPOCRSystemv2")
|
||||
|
||||
.def(pybind11::init<fastdeploy::vision::ocr::DBDetector*,
|
||||
fastdeploy::vision::ocr::Classifier*,
|
||||
fastdeploy::vision::ocr::Recognizer*>())
|
||||
|
||||
.def("predict", [](application::ocrsystem::PPOCRSystemv2& self,
|
||||
pybind11::array& data) {
|
||||
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::OCRResult res;
|
||||
self.Predict(&mat, &res);
|
||||
return res;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
130
csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v2.cc
Normal file
130
csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v2.cc
Normal file
@@ -0,0 +1,130 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/ocr/ppocr/ppocr_system_v2.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace application {
|
||||
namespace ocrsystem {
|
||||
PPOCRSystemv2::PPOCRSystemv2(fastdeploy::vision::ocr::DBDetector* ocr_det,
|
||||
fastdeploy::vision::ocr::Classifier* ocr_cls,
|
||||
fastdeploy::vision::ocr::Recognizer* ocr_rec)
|
||||
: detector(ocr_det), classifier(ocr_cls), recognizer(ocr_rec) {}
|
||||
|
||||
void PPOCRSystemv2::Detect(cv::Mat* img,
|
||||
fastdeploy::vision::OCRResult* result) {
|
||||
std::vector<std::vector<std::vector<int>>> boxes;
|
||||
|
||||
this->detector->Predict(img, &boxes);
|
||||
|
||||
// vector<vector>转array
|
||||
for (int i = 0; i < boxes.size(); i++) {
|
||||
std::array<int, 8> new_box;
|
||||
int k = 0;
|
||||
for (auto& vec : boxes[i]) {
|
||||
for (auto& e : vec) {
|
||||
new_box[k++] = e;
|
||||
}
|
||||
}
|
||||
(result->boxes).push_back(new_box);
|
||||
}
|
||||
}
|
||||
|
||||
void PPOCRSystemv2::Recognize(cv::Mat* img,
|
||||
fastdeploy::vision::OCRResult* result) {
|
||||
std::string rec_texts = "";
|
||||
float rec_text_scores = 0;
|
||||
|
||||
this->recognizer->rec_image_shape[1] =
|
||||
32; // OCRv2模型此处需要设置为32,其他与OCRv3一致
|
||||
this->recognizer->Predict(img, rec_texts, rec_text_scores);
|
||||
|
||||
result->text.push_back(rec_texts);
|
||||
result->rec_scores.push_back(rec_text_scores);
|
||||
}
|
||||
|
||||
void PPOCRSystemv2::Classify(cv::Mat* img,
|
||||
fastdeploy::vision::OCRResult* result) {
|
||||
int cls_label = 0;
|
||||
float cls_scores = 0;
|
||||
|
||||
this->classifier->Predict(img, cls_label, cls_scores);
|
||||
|
||||
result->cls_label.push_back(cls_label);
|
||||
result->cls_scores.push_back(cls_scores);
|
||||
}
|
||||
|
||||
bool PPOCRSystemv2::Predict(cv::Mat* img,
|
||||
fastdeploy::vision::OCRResult* result) {
|
||||
if (this->detector->initialized == 0) { //没det
|
||||
//输入单张“小图片”给分类器
|
||||
if (this->classifier->initialized != 0) {
|
||||
this->Classify(img, result);
|
||||
//摆正单张图像
|
||||
if ((result->cls_label)[0] % 2 == 1 &&
|
||||
(result->cls_scores)[0] > this->classifier->cls_thresh) {
|
||||
cv::rotate(*img, *img, 1);
|
||||
}
|
||||
}
|
||||
//输入单张“小图片”给识别器
|
||||
if (this->recognizer->initialized != 0) {
|
||||
this->Recognize(img, result);
|
||||
}
|
||||
|
||||
} else {
|
||||
//从DET模型开始
|
||||
//一张图,会输出多个“小图片”,送给后续模型
|
||||
this->Detect(img, result);
|
||||
std::cout << "Finish Det Prediction!" << std::endl;
|
||||
// crop image
|
||||
std::vector<cv::Mat> img_list;
|
||||
|
||||
for (int j = 0; j < (result->boxes).size(); j++) {
|
||||
cv::Mat crop_img;
|
||||
crop_img =
|
||||
fastdeploy::vision::ocr::GetRotateCropImage(*img, (result->boxes)[j]);
|
||||
img_list.push_back(crop_img);
|
||||
}
|
||||
// cls
|
||||
if (this->classifier->initialized != 0) {
|
||||
for (int i = 0; i < img_list.size(); i++) {
|
||||
this->Classify(&img_list[0], result);
|
||||
}
|
||||
|
||||
for (int i = 0; i < img_list.size(); i++) {
|
||||
if ((result->cls_label)[i] % 2 == 1 &&
|
||||
(result->cls_scores)[i] > this->classifier->cls_thresh) {
|
||||
std::cout << "Rotate this image " << std::endl;
|
||||
cv::rotate(img_list[i], img_list[i], 1);
|
||||
}
|
||||
}
|
||||
std::cout << "Finish Cls Prediction!" << std::endl;
|
||||
}
|
||||
// rec
|
||||
if (this->recognizer->initialized != 0) {
|
||||
for (int i = 0; i < img_list.size(); i++) {
|
||||
this->Recognize(&img_list[i], result);
|
||||
}
|
||||
std::cout << "Finish Rec Prediction!" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
} // namesapce ocrsystem
|
||||
} // namespace application
|
||||
} // namespace fastdeploy
|
52
csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v2.h
Normal file
52
csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v2.h
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "fastdeploy/fastdeploy_model.h"
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
|
||||
#include "fastdeploy/vision/ocr/ppocr/classifier.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/recognizer.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace application {
|
||||
namespace ocrsystem {
|
||||
|
||||
class FASTDEPLOY_DECL PPOCRSystemv2 : public FastDeployModel {
|
||||
public:
|
||||
PPOCRSystemv2(fastdeploy::vision::ocr::DBDetector* ocr_det = nullptr,
|
||||
fastdeploy::vision::ocr::Classifier* ocr_cls = nullptr,
|
||||
fastdeploy::vision::ocr::Recognizer* ocr_rec = nullptr);
|
||||
|
||||
fastdeploy::vision::ocr::DBDetector* detector = nullptr;
|
||||
fastdeploy::vision::ocr::Classifier* classifier = nullptr;
|
||||
fastdeploy::vision::ocr::Recognizer* recognizer = nullptr;
|
||||
|
||||
bool Predict(cv::Mat* img, fastdeploy::vision::OCRResult* result);
|
||||
|
||||
private:
|
||||
void Detect(cv::Mat* img, fastdeploy::vision::OCRResult* result);
|
||||
void Recognize(cv::Mat* img, fastdeploy::vision::OCRResult* result);
|
||||
void Classify(cv::Mat* img, fastdeploy::vision::OCRResult* result);
|
||||
};
|
||||
|
||||
} // namespace ocrsystem
|
||||
} // namespace application
|
||||
} // namespace fastdeploy
|
128
csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v3.cc
Normal file
128
csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v3.cc
Normal file
@@ -0,0 +1,128 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/ocr/ppocr/ppocr_system_v3.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace application {
|
||||
namespace ocrsystem {
|
||||
PPOCRSystemv3::PPOCRSystemv3(fastdeploy::vision::ocr::DBDetector* ocr_det,
|
||||
fastdeploy::vision::ocr::Classifier* ocr_cls,
|
||||
fastdeploy::vision::ocr::Recognizer* ocr_rec)
|
||||
: detector(ocr_det), classifier(ocr_cls), recognizer(ocr_rec) {}
|
||||
|
||||
void PPOCRSystemv3::Detect(cv::Mat* img,
|
||||
fastdeploy::vision::OCRResult* result) {
|
||||
std::vector<std::vector<std::vector<int>>> boxes;
|
||||
|
||||
this->detector->Predict(img, &boxes);
|
||||
|
||||
// vector<vector>转array
|
||||
for (int i = 0; i < boxes.size(); i++) {
|
||||
std::array<int, 8> new_box;
|
||||
int k = 0;
|
||||
for (auto& vec : boxes[i]) {
|
||||
for (auto& e : vec) {
|
||||
new_box[k++] = e;
|
||||
}
|
||||
}
|
||||
(result->boxes).push_back(new_box);
|
||||
}
|
||||
}
|
||||
|
||||
void PPOCRSystemv3::Recognize(cv::Mat* img,
|
||||
fastdeploy::vision::OCRResult* result) {
|
||||
std::string rec_texts = "";
|
||||
float rec_text_scores = 0;
|
||||
|
||||
this->recognizer->Predict(img, rec_texts, rec_text_scores);
|
||||
|
||||
result->text.push_back(rec_texts);
|
||||
result->rec_scores.push_back(rec_text_scores);
|
||||
}
|
||||
|
||||
void PPOCRSystemv3::Classify(cv::Mat* img,
|
||||
fastdeploy::vision::OCRResult* result) {
|
||||
int cls_label = 0;
|
||||
float cls_scores = 0;
|
||||
|
||||
this->classifier->Predict(img, cls_label, cls_scores);
|
||||
|
||||
result->cls_label.push_back(cls_label);
|
||||
result->cls_scores.push_back(cls_scores);
|
||||
}
|
||||
|
||||
bool PPOCRSystemv3::Predict(cv::Mat* img,
|
||||
fastdeploy::vision::OCRResult* result) {
|
||||
if (this->detector->initialized == 0) { //没det
|
||||
//输入单张“小图片”给分类器
|
||||
if (this->classifier->initialized != 0) {
|
||||
this->Classify(img, result);
|
||||
//摆正单张图像
|
||||
if ((result->cls_label)[0] % 2 == 1 &&
|
||||
(result->cls_scores)[0] > this->classifier->cls_thresh) {
|
||||
cv::rotate(*img, *img, 1);
|
||||
}
|
||||
}
|
||||
//输入单张“小图片”给识别器
|
||||
if (this->recognizer->initialized != 0) {
|
||||
this->Recognize(img, result);
|
||||
}
|
||||
|
||||
} else {
|
||||
//从DET模型开始
|
||||
//一张图,会输出多个“小图片”,送给后续模型
|
||||
this->Detect(img, result);
|
||||
std::cout << "Finish Det Prediction!" << std::endl;
|
||||
// crop image
|
||||
std::vector<cv::Mat> img_list;
|
||||
|
||||
for (int j = 0; j < (result->boxes).size(); j++) {
|
||||
cv::Mat crop_img;
|
||||
crop_img =
|
||||
fastdeploy::vision::ocr::GetRotateCropImage(*img, (result->boxes)[j]);
|
||||
img_list.push_back(crop_img);
|
||||
}
|
||||
// cls
|
||||
if (this->classifier->initialized != 0) {
|
||||
for (int i = 0; i < img_list.size(); i++) {
|
||||
this->Classify(&img_list[0], result);
|
||||
}
|
||||
|
||||
for (int i = 0; i < img_list.size(); i++) {
|
||||
if ((result->cls_label)[i] % 2 == 1 &&
|
||||
(result->cls_scores)[i] > this->classifier->cls_thresh) {
|
||||
std::cout << "Rotate this image " << std::endl;
|
||||
cv::rotate(img_list[i], img_list[i], 1);
|
||||
}
|
||||
}
|
||||
std::cout << "Finish Cls Prediction!" << std::endl;
|
||||
}
|
||||
// rec
|
||||
if (this->recognizer->initialized != 0) {
|
||||
for (int i = 0; i < img_list.size(); i++) {
|
||||
this->Recognize(&img_list[i], result);
|
||||
}
|
||||
std::cout << "Finish Rec Prediction!" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
} // namesapce ocrsystem
|
||||
} // namespace application
|
||||
} // namespace fastdeploy
|
52
csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v3.h
Normal file
52
csrc/fastdeploy/vision/ocr/ppocr/ppocr_system_v3.h
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "fastdeploy/fastdeploy_model.h"
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
|
||||
#include "fastdeploy/vision/ocr/ppocr/classifier.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/recognizer.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace application {
|
||||
namespace ocrsystem {
|
||||
|
||||
class FASTDEPLOY_DECL PPOCRSystemv3 : public FastDeployModel {
|
||||
public:
|
||||
PPOCRSystemv3(fastdeploy::vision::ocr::DBDetector* ocr_det = nullptr,
|
||||
fastdeploy::vision::ocr::Classifier* ocr_cls = nullptr,
|
||||
fastdeploy::vision::ocr::Recognizer* ocr_rec = nullptr);
|
||||
|
||||
fastdeploy::vision::ocr::DBDetector* detector = nullptr;
|
||||
fastdeploy::vision::ocr::Classifier* classifier = nullptr;
|
||||
fastdeploy::vision::ocr::Recognizer* recognizer = nullptr;
|
||||
|
||||
bool Predict(cv::Mat* img, fastdeploy::vision::OCRResult* result);
|
||||
|
||||
private:
|
||||
void Detect(cv::Mat* img, fastdeploy::vision::OCRResult* result);
|
||||
void Recognize(cv::Mat* img, fastdeploy::vision::OCRResult* result);
|
||||
void Classify(cv::Mat* img, fastdeploy::vision::OCRResult* result);
|
||||
};
|
||||
|
||||
} // namespace ocrsystem
|
||||
} // namespace application
|
||||
} // namespace fastdeploy
|
207
csrc/fastdeploy/vision/ocr/ppocr/recognizer.cc
Normal file
207
csrc/fastdeploy/vision/ocr/ppocr/recognizer.cc
Normal file
@@ -0,0 +1,207 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/ocr/ppocr/recognizer.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace ocr {
|
||||
|
||||
std::vector<std::string> ReadDict(const std::string& path) {
|
||||
std::ifstream in(path);
|
||||
std::string line;
|
||||
std::vector<std::string> m_vec;
|
||||
if (in) {
|
||||
while (getline(in, line)) {
|
||||
m_vec.push_back(line);
|
||||
}
|
||||
} else {
|
||||
std::cout << "no such label file: " << path << ", exit the program..."
|
||||
<< std::endl;
|
||||
exit(1);
|
||||
}
|
||||
return m_vec;
|
||||
}
|
||||
|
||||
Recognizer::Recognizer() {}
|
||||
|
||||
Recognizer::Recognizer(const std::string& model_file,
|
||||
const std::string& params_file,
|
||||
const std::string& label_path,
|
||||
const RuntimeOption& custom_option,
|
||||
const Frontend& model_format) {
|
||||
if (model_format == Frontend::ONNX) {
|
||||
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
|
||||
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
|
||||
} else {
|
||||
// NOTE:此模型暂不支持paddle-inference-Gpu推理
|
||||
valid_cpu_backends = {Backend::ORT, Backend::PDINFER};
|
||||
valid_gpu_backends = {Backend::ORT, Backend::TRT};
|
||||
}
|
||||
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = model_format;
|
||||
runtime_option.model_file = model_file;
|
||||
runtime_option.params_file = params_file;
|
||||
// Recognizer在使用CPU推理,并把PaddleInference作为推理后端时,需要删除以下2个pass//
|
||||
runtime_option.DeletePaddleBackendPass("matmul_transpose_reshape_fuse_pass");
|
||||
runtime_option.DeletePaddleBackendPass(
|
||||
"matmul_transpose_reshape_mkldnn_fuse_pass");
|
||||
|
||||
initialized = Initialize();
|
||||
|
||||
// init label_lsit
|
||||
label_list = ReadDict(label_path);
|
||||
label_list.insert(label_list.begin(), "#"); // blank char for ctc
|
||||
label_list.push_back(" ");
|
||||
}
|
||||
|
||||
// Init
|
||||
bool Recognizer::Initialize() {
|
||||
// pre&post process parameters
|
||||
rec_batch_num = 1;
|
||||
rec_img_h = 48;
|
||||
rec_img_w = 320;
|
||||
rec_image_shape = {3, rec_img_h, rec_img_w};
|
||||
|
||||
mean = {0.5f, 0.5f, 0.5f};
|
||||
scale = {0.5f, 0.5f, 0.5f};
|
||||
is_scale = true;
|
||||
|
||||
if (!InitRuntime()) {
|
||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void OcrRecognizerResizeImage(Mat* mat, float wh_ratio,
|
||||
const std::vector<int>& rec_image_shape) {
|
||||
int imgC, imgH, imgW;
|
||||
imgC = rec_image_shape[0];
|
||||
imgH = rec_image_shape[1];
|
||||
imgW = rec_image_shape[2];
|
||||
|
||||
imgW = int(imgH * wh_ratio);
|
||||
|
||||
float ratio = float(mat->Width()) / float(mat->Height());
|
||||
int resize_w;
|
||||
if (ceilf(imgH * ratio) > imgW)
|
||||
resize_w = imgW;
|
||||
else
|
||||
resize_w = int(ceilf(imgH * ratio));
|
||||
|
||||
Resize::Run(mat, resize_w, imgH);
|
||||
|
||||
std::vector<float> value = {127, 127, 127};
|
||||
Pad::Run(mat, 0, 0, 0, int(imgW - mat->Width()), value);
|
||||
}
|
||||
|
||||
//预处理
|
||||
bool Recognizer::Preprocess(Mat* mat, FDTensor* output,
|
||||
const std::vector<int>& rec_image_shape) {
|
||||
int imgH = rec_image_shape[1];
|
||||
int imgW = rec_image_shape[2];
|
||||
float wh_ratio = imgW * 1.0 / imgH;
|
||||
|
||||
float ori_wh_ratio = mat->Width() * 1.0 / mat->Height();
|
||||
wh_ratio = std::max(wh_ratio, ori_wh_ratio);
|
||||
|
||||
OcrRecognizerResizeImage(mat, wh_ratio, rec_image_shape);
|
||||
|
||||
Normalize::Run(mat, mean, scale, true);
|
||||
|
||||
HWC2CHW::Run(mat);
|
||||
Cast::Run(mat, "float");
|
||||
|
||||
mat->ShareWithTensor(output);
|
||||
output->shape.insert(output->shape.begin(), 1);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
//后处理
|
||||
bool Recognizer::Postprocess(FDTensor& infer_result, std::string& rec_texts,
|
||||
float& rec_text_scores) {
|
||||
std::vector<int64_t> output_shape = infer_result.shape;
|
||||
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
|
||||
|
||||
float* out_data = static_cast<float*>(infer_result.Data());
|
||||
|
||||
std::string str_res;
|
||||
int argmax_idx;
|
||||
int last_index = 0;
|
||||
float score = 0.f;
|
||||
int count = 0;
|
||||
float max_value = 0.0f;
|
||||
|
||||
for (int n = 0; n < output_shape[1]; n++) {
|
||||
argmax_idx = int(
|
||||
std::distance(&out_data[n * output_shape[2]],
|
||||
std::max_element(&out_data[n * output_shape[2]],
|
||||
&out_data[(n + 1) * output_shape[2]])));
|
||||
|
||||
max_value = float(*std::max_element(&out_data[n * output_shape[2]],
|
||||
&out_data[(n + 1) * output_shape[2]]));
|
||||
|
||||
if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
|
||||
score += max_value;
|
||||
count += 1;
|
||||
str_res += label_list[argmax_idx];
|
||||
}
|
||||
last_index = argmax_idx;
|
||||
}
|
||||
|
||||
score /= count;
|
||||
|
||||
rec_texts = str_res;
|
||||
rec_text_scores = score;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
//预测
|
||||
bool Recognizer::Predict(cv::Mat* img, std::string& rec_texts,
|
||||
float& rec_text_scores) {
|
||||
Mat mat(*img);
|
||||
|
||||
std::vector<FDTensor> input_tensors(1);
|
||||
|
||||
if (!Preprocess(&mat, &input_tensors[0], rec_image_shape)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||
std::vector<FDTensor> output_tensors;
|
||||
|
||||
if (!Infer(input_tensors, &output_tensors)) {
|
||||
FDERROR << "Failed to inference." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!Postprocess(output_tensors[0], rec_texts, rec_text_scores)) {
|
||||
FDERROR << "Failed to post process." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namesapce ocr
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
69
csrc/fastdeploy/vision/ocr/ppocr/recognizer.h
Normal file
69
csrc/fastdeploy/vision/ocr/ppocr/recognizer.h
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/fastdeploy_model.h"
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace ocr {
|
||||
|
||||
class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
|
||||
public:
|
||||
Recognizer();
|
||||
// 当model_format为ONNX时,无需指定params_file
|
||||
// 当model_format为Paddle时,则需同时指定model_file & params_file
|
||||
Recognizer(const std::string& model_file, const std::string& params_file = "",
|
||||
const std::string& label_path = "",
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const Frontend& model_format = Frontend::PADDLE);
|
||||
|
||||
// 定义模型的名称
|
||||
std::string ModelName() const { return "ppocr/ocr_rec"; }
|
||||
|
||||
// 模型预测接口,即用户调用的接口
|
||||
virtual bool Predict(cv::Mat* img, std::string& rec_texts,
|
||||
float& rec_text_scores);
|
||||
|
||||
// pre & post parameters
|
||||
std::vector<std::string> label_list;
|
||||
int rec_batch_num;
|
||||
int rec_img_h;
|
||||
int rec_img_w;
|
||||
std::vector<int> rec_image_shape;
|
||||
|
||||
std::vector<float> mean;
|
||||
std::vector<float> scale;
|
||||
bool is_scale;
|
||||
|
||||
private:
|
||||
// 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作
|
||||
bool Initialize();
|
||||
|
||||
// 输入图像预处理操作
|
||||
bool Preprocess(Mat* img, FDTensor* outputs,
|
||||
const std::vector<int>& rec_image_shape);
|
||||
|
||||
// 后端推理结果后处理,输出给用户
|
||||
// infer_result 为后端推理后的输出Tensor
|
||||
bool Postprocess(FDTensor& infer_result, std::string& rec_texts,
|
||||
float& rec_text_scores);
|
||||
};
|
||||
|
||||
} // namespace ocr
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
4171
csrc/fastdeploy/vision/ocr/ppocr/utils/clipper.cc
Normal file
4171
csrc/fastdeploy/vision/ocr/ppocr/utils/clipper.cc
Normal file
File diff suppressed because it is too large
Load Diff
425
csrc/fastdeploy/vision/ocr/ppocr/utils/clipper.h
Normal file
425
csrc/fastdeploy/vision/ocr/ppocr/utils/clipper.h
Normal file
@@ -0,0 +1,425 @@
|
||||
/*******************************************************************************
|
||||
* *
|
||||
* Author : Angus Johnson *
|
||||
* Version : 6.4.2 *
|
||||
* Date : 27 February 2017 *
|
||||
* Website : http://www.angusj.com *
|
||||
* Copyright : Angus Johnson 2010-2017 *
|
||||
* *
|
||||
* License: *
|
||||
* Use, modification & distribution is subject to Boost Software License Ver 1. *
|
||||
* http://www.boost.org/LICENSE_1_0.txt *
|
||||
* *
|
||||
* Attributions: *
|
||||
* The code in this library is an extension of Bala Vatti's clipping algorithm: *
|
||||
* "A generic solution to polygon clipping" *
|
||||
* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. *
|
||||
* http://portal.acm.org/citation.cfm?id=129906 *
|
||||
* *
|
||||
* Computer graphics and geometric modeling: implementation and algorithms *
|
||||
* By Max K. Agoston *
|
||||
* Springer; 1 edition (January 4, 2005) *
|
||||
* http://books.google.com/books?q=vatti+clipping+agoston *
|
||||
* *
|
||||
* See also: *
|
||||
* "Polygon Offsetting by Computing Winding Numbers" *
|
||||
* Paper no. DETC2005-85513 pp. 565-575 *
|
||||
* ASME 2005 International Design Engineering Technical Conferences *
|
||||
* and Computers and Information in Engineering Conference (IDETC/CIE2005) *
|
||||
* September 24-28, 2005 , Long Beach, California, USA *
|
||||
* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf *
|
||||
* *
|
||||
*******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef clipper_hpp
|
||||
#define clipper_hpp
|
||||
|
||||
#define CLIPPER_VERSION "6.4.2"
|
||||
|
||||
// use_int32: When enabled 32bit ints are used instead of 64bit ints. This
|
||||
// improve performance but coordinate values are limited to the range +/- 46340
|
||||
//#define use_int32
|
||||
|
||||
// use_xyz: adds a Z member to IntPoint. Adds a minor cost to perfomance.
|
||||
//#define use_xyz
|
||||
|
||||
// use_lines: Enables line clipping. Adds a very minor cost to performance.
|
||||
#define use_lines
|
||||
|
||||
// use_deprecated: Enables temporary support for the obsolete functions
|
||||
//#define use_deprecated
|
||||
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <list>
|
||||
#include <ostream>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
namespace ClipperLib {
|
||||
|
||||
enum ClipType { ctIntersection, ctUnion, ctDifference, ctXor };
|
||||
enum PolyType { ptSubject, ptClip };
|
||||
// By far the most widely used winding rules for polygon filling are
|
||||
// EvenOdd & NonZero (GDI, GDI+, XLib, OpenGL, Cairo, AGG, Quartz, SVG, Gr32)
|
||||
// Others rules include Positive, Negative and ABS_GTR_EQ_TWO (only in OpenGL)
|
||||
// see http://glprogramming.com/red/chapter11.html
|
||||
enum PolyFillType { pftEvenOdd, pftNonZero, pftPositive, pftNegative };
|
||||
|
||||
#ifdef use_int32
|
||||
typedef int cInt;
|
||||
static cInt const loRange = 0x7FFF;
|
||||
static cInt const hiRange = 0x7FFF;
|
||||
#else
|
||||
typedef signed long long cInt;
|
||||
static cInt const loRange = 0x3FFFFFFF;
|
||||
static cInt const hiRange = 0x3FFFFFFFFFFFFFFFLL;
|
||||
typedef signed long long long64; // used by Int128 class
|
||||
typedef unsigned long long ulong64;
|
||||
|
||||
#endif
|
||||
|
||||
struct IntPoint {
|
||||
cInt X;
|
||||
cInt Y;
|
||||
#ifdef use_xyz
|
||||
cInt Z;
|
||||
IntPoint(cInt x = 0, cInt y = 0, cInt z = 0) : X(x), Y(y), Z(z){};
|
||||
#else
|
||||
IntPoint(cInt x = 0, cInt y = 0) : X(x), Y(y){};
|
||||
#endif
|
||||
|
||||
friend inline bool operator==(const IntPoint &a, const IntPoint &b) {
|
||||
return a.X == b.X && a.Y == b.Y;
|
||||
}
|
||||
friend inline bool operator!=(const IntPoint &a, const IntPoint &b) {
|
||||
return a.X != b.X || a.Y != b.Y;
|
||||
}
|
||||
};
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
typedef std::vector<IntPoint> Path;
|
||||
typedef std::vector<Path> Paths;
|
||||
|
||||
inline Path &operator<<(Path &poly, const IntPoint &p) {
|
||||
poly.push_back(p);
|
||||
return poly;
|
||||
}
|
||||
inline Paths &operator<<(Paths &polys, const Path &p) {
|
||||
polys.push_back(p);
|
||||
return polys;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &s, const IntPoint &p);
|
||||
std::ostream &operator<<(std::ostream &s, const Path &p);
|
||||
std::ostream &operator<<(std::ostream &s, const Paths &p);
|
||||
|
||||
struct DoublePoint {
|
||||
double X;
|
||||
double Y;
|
||||
DoublePoint(double x = 0, double y = 0) : X(x), Y(y) {}
|
||||
DoublePoint(IntPoint ip) : X((double)ip.X), Y((double)ip.Y) {}
|
||||
};
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
#ifdef use_xyz
|
||||
typedef void (*ZFillCallback)(IntPoint &e1bot, IntPoint &e1top, IntPoint &e2bot,
|
||||
IntPoint &e2top, IntPoint &pt);
|
||||
#endif
|
||||
|
||||
enum InitOptions {
|
||||
ioReverseSolution = 1,
|
||||
ioStrictlySimple = 2,
|
||||
ioPreserveCollinear = 4
|
||||
};
|
||||
enum JoinType { jtSquare, jtRound, jtMiter };
|
||||
enum EndType {
|
||||
etClosedPolygon,
|
||||
etClosedLine,
|
||||
etOpenButt,
|
||||
etOpenSquare,
|
||||
etOpenRound
|
||||
};
|
||||
|
||||
class PolyNode;
|
||||
typedef std::vector<PolyNode *> PolyNodes;
|
||||
|
||||
class PolyNode {
|
||||
public:
|
||||
PolyNode();
|
||||
virtual ~PolyNode(){};
|
||||
Path Contour;
|
||||
PolyNodes Childs;
|
||||
PolyNode *Parent;
|
||||
PolyNode *GetNext() const;
|
||||
bool IsHole() const;
|
||||
bool IsOpen() const;
|
||||
int ChildCount() const;
|
||||
|
||||
private:
|
||||
// PolyNode& operator =(PolyNode& other);
|
||||
unsigned Index; // node index in Parent.Childs
|
||||
bool m_IsOpen;
|
||||
JoinType m_jointype;
|
||||
EndType m_endtype;
|
||||
PolyNode *GetNextSiblingUp() const;
|
||||
void AddChild(PolyNode &child);
|
||||
friend class Clipper; // to access Index
|
||||
friend class ClipperOffset;
|
||||
};
|
||||
|
||||
class PolyTree : public PolyNode {
|
||||
public:
|
||||
~PolyTree() { Clear(); };
|
||||
PolyNode *GetFirst() const;
|
||||
void Clear();
|
||||
int Total() const;
|
||||
|
||||
private:
|
||||
// PolyTree& operator =(PolyTree& other);
|
||||
PolyNodes AllNodes;
|
||||
friend class Clipper; // to access AllNodes
|
||||
};
|
||||
|
||||
bool Orientation(const Path &poly);
|
||||
double Area(const Path &poly);
|
||||
int PointInPolygon(const IntPoint &pt, const Path &path);
|
||||
|
||||
void SimplifyPolygon(const Path &in_poly, Paths &out_polys,
|
||||
PolyFillType fillType = pftEvenOdd);
|
||||
void SimplifyPolygons(const Paths &in_polys, Paths &out_polys,
|
||||
PolyFillType fillType = pftEvenOdd);
|
||||
void SimplifyPolygons(Paths &polys, PolyFillType fillType = pftEvenOdd);
|
||||
|
||||
void CleanPolygon(const Path &in_poly, Path &out_poly, double distance = 1.415);
|
||||
void CleanPolygon(Path &poly, double distance = 1.415);
|
||||
void CleanPolygons(const Paths &in_polys, Paths &out_polys,
|
||||
double distance = 1.415);
|
||||
void CleanPolygons(Paths &polys, double distance = 1.415);
|
||||
|
||||
void MinkowskiSum(const Path &pattern, const Path &path, Paths &solution,
|
||||
bool pathIsClosed);
|
||||
void MinkowskiSum(const Path &pattern, const Paths &paths, Paths &solution,
|
||||
bool pathIsClosed);
|
||||
void MinkowskiDiff(const Path &poly1, const Path &poly2, Paths &solution);
|
||||
|
||||
void PolyTreeToPaths(const PolyTree &polytree, Paths &paths);
|
||||
void ClosedPathsFromPolyTree(const PolyTree &polytree, Paths &paths);
|
||||
void OpenPathsFromPolyTree(PolyTree &polytree, Paths &paths);
|
||||
|
||||
void ReversePath(Path &p);
|
||||
void ReversePaths(Paths &p);
|
||||
|
||||
struct IntRect {
|
||||
cInt left;
|
||||
cInt top;
|
||||
cInt right;
|
||||
cInt bottom;
|
||||
};
|
||||
|
||||
// enums that are used internally ...
|
||||
enum EdgeSide { esLeft = 1, esRight = 2 };
|
||||
|
||||
// forward declarations (for stuff used internally) ...
|
||||
struct TEdge;
|
||||
struct IntersectNode;
|
||||
struct LocalMinimum;
|
||||
struct OutPt;
|
||||
struct OutRec;
|
||||
struct Join;
|
||||
|
||||
typedef std::vector<OutRec *> PolyOutList;
|
||||
typedef std::vector<TEdge *> EdgeList;
|
||||
typedef std::vector<Join *> JoinList;
|
||||
typedef std::vector<IntersectNode *> IntersectList;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// ClipperBase is the ancestor to the Clipper class. It should not be
|
||||
// instantiated directly. This class simply abstracts the conversion of sets of
|
||||
// polygon coordinates into edge objects that are stored in a LocalMinima list.
|
||||
class ClipperBase {
|
||||
public:
|
||||
ClipperBase();
|
||||
virtual ~ClipperBase();
|
||||
virtual bool AddPath(const Path &pg, PolyType PolyTyp, bool Closed);
|
||||
bool AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed);
|
||||
virtual void Clear();
|
||||
IntRect GetBounds();
|
||||
bool PreserveCollinear() { return m_PreserveCollinear; };
|
||||
void PreserveCollinear(bool value) { m_PreserveCollinear = value; };
|
||||
|
||||
protected:
|
||||
void DisposeLocalMinimaList();
|
||||
TEdge *AddBoundsToLML(TEdge *e, bool IsClosed);
|
||||
virtual void Reset();
|
||||
TEdge *ProcessBound(TEdge *E, bool IsClockwise);
|
||||
void InsertScanbeam(const cInt Y);
|
||||
bool PopScanbeam(cInt &Y);
|
||||
bool LocalMinimaPending();
|
||||
bool PopLocalMinima(cInt Y, const LocalMinimum *&locMin);
|
||||
OutRec *CreateOutRec();
|
||||
void DisposeAllOutRecs();
|
||||
void DisposeOutRec(PolyOutList::size_type index);
|
||||
void SwapPositionsInAEL(TEdge *edge1, TEdge *edge2);
|
||||
void DeleteFromAEL(TEdge *e);
|
||||
void UpdateEdgeIntoAEL(TEdge *&e);
|
||||
|
||||
typedef std::vector<LocalMinimum> MinimaList;
|
||||
MinimaList::iterator m_CurrentLM;
|
||||
MinimaList m_MinimaList;
|
||||
|
||||
bool m_UseFullRange;
|
||||
EdgeList m_edges;
|
||||
bool m_PreserveCollinear;
|
||||
bool m_HasOpenPaths;
|
||||
PolyOutList m_PolyOuts;
|
||||
TEdge *m_ActiveEdges;
|
||||
|
||||
typedef std::priority_queue<cInt> ScanbeamList;
|
||||
ScanbeamList m_Scanbeam;
|
||||
};
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
class Clipper : public virtual ClipperBase {
|
||||
public:
|
||||
Clipper(int initOptions = 0);
|
||||
bool Execute(ClipType clipType, Paths &solution,
|
||||
PolyFillType fillType = pftEvenOdd);
|
||||
bool Execute(ClipType clipType, Paths &solution, PolyFillType subjFillType,
|
||||
PolyFillType clipFillType);
|
||||
bool Execute(ClipType clipType, PolyTree &polytree,
|
||||
PolyFillType fillType = pftEvenOdd);
|
||||
bool Execute(ClipType clipType, PolyTree &polytree, PolyFillType subjFillType,
|
||||
PolyFillType clipFillType);
|
||||
bool ReverseSolution() { return m_ReverseOutput; };
|
||||
void ReverseSolution(bool value) { m_ReverseOutput = value; };
|
||||
bool StrictlySimple() { return m_StrictSimple; };
|
||||
void StrictlySimple(bool value) { m_StrictSimple = value; };
|
||||
// set the callback function for z value filling on intersections (otherwise Z
|
||||
// is 0)
|
||||
#ifdef use_xyz
|
||||
void ZFillFunction(ZFillCallback zFillFunc);
|
||||
#endif
|
||||
protected:
|
||||
virtual bool ExecuteInternal();
|
||||
|
||||
private:
|
||||
JoinList m_Joins;
|
||||
JoinList m_GhostJoins;
|
||||
IntersectList m_IntersectList;
|
||||
ClipType m_ClipType;
|
||||
typedef std::list<cInt> MaximaList;
|
||||
MaximaList m_Maxima;
|
||||
TEdge *m_SortedEdges;
|
||||
bool m_ExecuteLocked;
|
||||
PolyFillType m_ClipFillType;
|
||||
PolyFillType m_SubjFillType;
|
||||
bool m_ReverseOutput;
|
||||
bool m_UsingPolyTree;
|
||||
bool m_StrictSimple;
|
||||
#ifdef use_xyz
|
||||
ZFillCallback m_ZFill; // custom callback
|
||||
#endif
|
||||
void SetWindingCount(TEdge &edge);
|
||||
bool IsEvenOddFillType(const TEdge &edge) const;
|
||||
bool IsEvenOddAltFillType(const TEdge &edge) const;
|
||||
void InsertLocalMinimaIntoAEL(const cInt botY);
|
||||
void InsertEdgeIntoAEL(TEdge *edge, TEdge *startEdge);
|
||||
void AddEdgeToSEL(TEdge *edge);
|
||||
bool PopEdgeFromSEL(TEdge *&edge);
|
||||
void CopyAELToSEL();
|
||||
void DeleteFromSEL(TEdge *e);
|
||||
void SwapPositionsInSEL(TEdge *edge1, TEdge *edge2);
|
||||
bool IsContributing(const TEdge &edge) const;
|
||||
bool IsTopHorz(const cInt XPos);
|
||||
void DoMaxima(TEdge *e);
|
||||
void ProcessHorizontals();
|
||||
void ProcessHorizontal(TEdge *horzEdge);
|
||||
void AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
|
||||
OutPt *AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
|
||||
OutRec *GetOutRec(int idx);
|
||||
void AppendPolygon(TEdge *e1, TEdge *e2);
|
||||
void IntersectEdges(TEdge *e1, TEdge *e2, IntPoint &pt);
|
||||
OutPt *AddOutPt(TEdge *e, const IntPoint &pt);
|
||||
OutPt *GetLastOutPt(TEdge *e);
|
||||
bool ProcessIntersections(const cInt topY);
|
||||
void BuildIntersectList(const cInt topY);
|
||||
void ProcessIntersectList();
|
||||
void ProcessEdgesAtTopOfScanbeam(const cInt topY);
|
||||
void BuildResult(Paths &polys);
|
||||
void BuildResult2(PolyTree &polytree);
|
||||
void SetHoleState(TEdge *e, OutRec *outrec);
|
||||
void DisposeIntersectNodes();
|
||||
bool FixupIntersectionOrder();
|
||||
void FixupOutPolygon(OutRec &outrec);
|
||||
void FixupOutPolyline(OutRec &outrec);
|
||||
bool IsHole(TEdge *e);
|
||||
bool FindOwnerFromSplitRecs(OutRec &outRec, OutRec *&currOrfl);
|
||||
void FixHoleLinkage(OutRec &outrec);
|
||||
void AddJoin(OutPt *op1, OutPt *op2, const IntPoint offPt);
|
||||
void ClearJoins();
|
||||
void ClearGhostJoins();
|
||||
void AddGhostJoin(OutPt *op, const IntPoint offPt);
|
||||
bool JoinPoints(Join *j, OutRec *outRec1, OutRec *outRec2);
|
||||
void JoinCommonEdges();
|
||||
void DoSimplePolygons();
|
||||
void FixupFirstLefts1(OutRec *OldOutRec, OutRec *NewOutRec);
|
||||
void FixupFirstLefts2(OutRec *InnerOutRec, OutRec *OuterOutRec);
|
||||
void FixupFirstLefts3(OutRec *OldOutRec, OutRec *NewOutRec);
|
||||
#ifdef use_xyz
|
||||
void SetZ(IntPoint &pt, TEdge &e1, TEdge &e2);
|
||||
#endif
|
||||
};
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
class ClipperOffset {
|
||||
public:
|
||||
ClipperOffset(double miterLimit = 2.0, double roundPrecision = 0.25);
|
||||
~ClipperOffset();
|
||||
void AddPath(const Path &path, JoinType joinType, EndType endType);
|
||||
void AddPaths(const Paths &paths, JoinType joinType, EndType endType);
|
||||
void Execute(Paths &solution, double delta);
|
||||
void Execute(PolyTree &solution, double delta);
|
||||
void Clear();
|
||||
double MiterLimit;
|
||||
double ArcTolerance;
|
||||
|
||||
private:
|
||||
Paths m_destPolys;
|
||||
Path m_srcPoly;
|
||||
Path m_destPoly;
|
||||
std::vector<DoublePoint> m_normals;
|
||||
double m_delta, m_sinA, m_sin, m_cos;
|
||||
double m_miterLim, m_StepsPerRad;
|
||||
IntPoint m_lowest;
|
||||
PolyNode m_polyNodes;
|
||||
|
||||
void FixOrientations();
|
||||
void DoOffset(double delta);
|
||||
void OffsetPoint(int j, int &k, JoinType jointype);
|
||||
void DoSquare(int j, int k);
|
||||
void DoMiter(int j, int k, double r);
|
||||
void DoRound(int j, int k);
|
||||
};
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
class clipperException : public std::exception {
|
||||
public:
|
||||
clipperException(const char *description) : m_descr(description) {}
|
||||
virtual ~clipperException() throw() {}
|
||||
virtual const char *what() const throw() { return m_descr.c_str(); }
|
||||
|
||||
private:
|
||||
std::string m_descr;
|
||||
};
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
} // ClipperLib namespace
|
||||
|
||||
#endif // clipper_hpp
|
@@ -0,0 +1,89 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace ocr {
|
||||
|
||||
cv::Mat GetRotateCropImage(const cv::Mat& srcimage,
|
||||
const std::array<int, 8>& box) {
|
||||
cv::Mat image;
|
||||
srcimage.copyTo(image);
|
||||
|
||||
std::vector<std::vector<int>> points;
|
||||
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
std::vector<int> tmp;
|
||||
tmp.push_back(box[2 * i]);
|
||||
tmp.push_back(box[2 * i + 1]);
|
||||
points.push_back(tmp);
|
||||
}
|
||||
// box转points
|
||||
int x_collect[4] = {box[0], box[2], box[4], box[6]};
|
||||
int y_collect[4] = {box[1], box[3], box[5], box[7]};
|
||||
int left = int(*std::min_element(x_collect, x_collect + 4));
|
||||
int right = int(*std::max_element(x_collect, x_collect + 4));
|
||||
int top = int(*std::min_element(y_collect, y_collect + 4));
|
||||
int bottom = int(*std::max_element(y_collect, y_collect + 4));
|
||||
|
||||
//得到rect矩形
|
||||
cv::Mat img_crop;
|
||||
image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop);
|
||||
|
||||
for (int i = 0; i < points.size(); i++) {
|
||||
points[i][0] -= left;
|
||||
points[i][1] -= top;
|
||||
}
|
||||
|
||||
int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) +
|
||||
pow(points[0][1] - points[1][1], 2)));
|
||||
int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) +
|
||||
pow(points[0][1] - points[3][1], 2)));
|
||||
|
||||
cv::Point2f pts_std[4];
|
||||
pts_std[0] = cv::Point2f(0., 0.);
|
||||
pts_std[1] = cv::Point2f(img_crop_width, 0.);
|
||||
pts_std[2] = cv::Point2f(img_crop_width, img_crop_height);
|
||||
pts_std[3] = cv::Point2f(0.f, img_crop_height);
|
||||
|
||||
cv::Point2f pointsf[4];
|
||||
pointsf[0] = cv::Point2f(points[0][0], points[0][1]);
|
||||
pointsf[1] = cv::Point2f(points[1][0], points[1][1]);
|
||||
pointsf[2] = cv::Point2f(points[2][0], points[2][1]);
|
||||
pointsf[3] = cv::Point2f(points[3][0], points[3][1]);
|
||||
|
||||
//透视变换矩阵
|
||||
cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std);
|
||||
|
||||
cv::Mat dst_img;
|
||||
cv::warpPerspective(img_crop, dst_img, M,
|
||||
cv::Size(img_crop_width, img_crop_height),
|
||||
cv::BORDER_REPLICATE);
|
||||
//完成透视变换
|
||||
|
||||
if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) {
|
||||
cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth());
|
||||
cv::transpose(dst_img, srcCopy);
|
||||
cv::flip(srcCopy, srcCopy, 0);
|
||||
return srcCopy;
|
||||
} else {
|
||||
return dst_img;
|
||||
}
|
||||
}
|
||||
|
||||
} // namesoace ocr
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
365
csrc/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.cc
Normal file
365
csrc/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.cc
Normal file
@@ -0,0 +1,365 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ocr_postprocess_op.h"
|
||||
#include <map>
|
||||
#include "clipper.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace ocr {
|
||||
|
||||
//获取轮廓区域
|
||||
void PostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
|
||||
float unclip_ratio, float &distance) {
|
||||
int pts_num = 4;
|
||||
float area = 0.0f;
|
||||
float dist = 0.0f;
|
||||
for (int i = 0; i < pts_num; i++) {
|
||||
area += box[i][0] * box[(i + 1) % pts_num][1] -
|
||||
box[i][1] * box[(i + 1) % pts_num][0];
|
||||
dist += sqrtf((box[i][0] - box[(i + 1) % pts_num][0]) *
|
||||
(box[i][0] - box[(i + 1) % pts_num][0]) +
|
||||
(box[i][1] - box[(i + 1) % pts_num][1]) *
|
||||
(box[i][1] - box[(i + 1) % pts_num][1]));
|
||||
}
|
||||
area = fabs(float(area / 2.0));
|
||||
|
||||
distance = area * unclip_ratio / dist;
|
||||
}
|
||||
|
||||
cv::RotatedRect PostProcessor::UnClip(std::vector<std::vector<float>> box,
|
||||
const float &unclip_ratio) {
|
||||
float distance = 1.0;
|
||||
|
||||
GetContourArea(box, unclip_ratio, distance);
|
||||
|
||||
ClipperLib::ClipperOffset offset;
|
||||
ClipperLib::Path p;
|
||||
p << ClipperLib::IntPoint(int(box[0][0]), int(box[0][1]))
|
||||
<< ClipperLib::IntPoint(int(box[1][0]), int(box[1][1]))
|
||||
<< ClipperLib::IntPoint(int(box[2][0]), int(box[2][1]))
|
||||
<< ClipperLib::IntPoint(int(box[3][0]), int(box[3][1]));
|
||||
offset.AddPath(p, ClipperLib::jtRound, ClipperLib::etClosedPolygon);
|
||||
|
||||
ClipperLib::Paths soln;
|
||||
offset.Execute(soln, distance);
|
||||
std::vector<cv::Point2f> points;
|
||||
|
||||
for (int j = 0; j < soln.size(); j++) {
|
||||
for (int i = 0; i < soln[soln.size() - 1].size(); i++) {
|
||||
points.emplace_back(soln[j][i].X, soln[j][i].Y);
|
||||
}
|
||||
}
|
||||
cv::RotatedRect res;
|
||||
if (points.size() <= 0) {
|
||||
res = cv::RotatedRect(cv::Point2f(0, 0), cv::Size2f(1, 1), 0);
|
||||
} else {
|
||||
res = cv::minAreaRect(points);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
//将图像的矩阵转换为float类型的array数组返回
|
||||
float **PostProcessor::Mat2Vec(cv::Mat mat) {
|
||||
auto **array = new float *[mat.rows];
|
||||
for (int i = 0; i < mat.rows; ++i) array[i] = new float[mat.cols];
|
||||
for (int i = 0; i < mat.rows; ++i) {
|
||||
for (int j = 0; j < mat.cols; ++j) {
|
||||
array[i][j] = mat.at<float>(i, j);
|
||||
}
|
||||
}
|
||||
|
||||
return array;
|
||||
}
|
||||
|
||||
//对点进行顺时针方向的排序(从左到右,从上到下) (order points
|
||||
// clockwise[顺时针方向])
|
||||
std::vector<std::vector<int>> PostProcessor::OrderPointsClockwise(
|
||||
std::vector<std::vector<int>> pts) {
|
||||
std::vector<std::vector<int>> box = pts;
|
||||
std::sort(box.begin(), box.end(), XsortInt);
|
||||
|
||||
std::vector<std::vector<int>> leftmost = {box[0], box[1]};
|
||||
std::vector<std::vector<int>> rightmost = {box[2], box[3]};
|
||||
|
||||
if (leftmost[0][1] > leftmost[1][1]) std::swap(leftmost[0], leftmost[1]);
|
||||
|
||||
if (rightmost[0][1] > rightmost[1][1]) std::swap(rightmost[0], rightmost[1]);
|
||||
|
||||
std::vector<std::vector<int>> rect = {leftmost[0], rightmost[0], rightmost[1],
|
||||
leftmost[1]};
|
||||
return rect;
|
||||
}
|
||||
|
||||
//将图像的矩阵转换为float类型的vector数组返回
|
||||
std::vector<std::vector<float>> PostProcessor::Mat2Vector(cv::Mat mat) {
|
||||
std::vector<std::vector<float>> img_vec;
|
||||
std::vector<float> tmp;
|
||||
|
||||
for (int i = 0; i < mat.rows; ++i) {
|
||||
tmp.clear();
|
||||
for (int j = 0; j < mat.cols; ++j) {
|
||||
tmp.push_back(mat.at<float>(i, j));
|
||||
}
|
||||
img_vec.push_back(tmp);
|
||||
}
|
||||
return img_vec;
|
||||
}
|
||||
|
||||
//判断元素为浮点数float的vector的精度,如果a中元素的精度不等于b中元素的精度,则返回false
|
||||
bool PostProcessor::XsortFp32(std::vector<float> a, std::vector<float> b) {
|
||||
if (a[0] != b[0]) return a[0] < b[0];
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PostProcessor::XsortInt(std::vector<int> a, std::vector<int> b) {
|
||||
if (a[0] != b[0]) return a[0] < b[0];
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
|
||||
float &ssid) {
|
||||
ssid = std::max(box.size.width, box.size.height);
|
||||
|
||||
cv::Mat points;
|
||||
cv::boxPoints(box, points);
|
||||
|
||||
auto array = Mat2Vector(points);
|
||||
std::sort(array.begin(), array.end(), XsortFp32);
|
||||
|
||||
std::vector<float> idx1 = array[0], idx2 = array[1], idx3 = array[2],
|
||||
idx4 = array[3];
|
||||
if (array[3][1] <= array[2][1]) {
|
||||
idx2 = array[3];
|
||||
idx3 = array[2];
|
||||
} else {
|
||||
idx2 = array[2];
|
||||
idx3 = array[3];
|
||||
}
|
||||
if (array[1][1] <= array[0][1]) {
|
||||
idx1 = array[1];
|
||||
idx4 = array[0];
|
||||
} else {
|
||||
idx1 = array[0];
|
||||
idx4 = array[1];
|
||||
}
|
||||
|
||||
array[0] = idx1;
|
||||
array[1] = idx2;
|
||||
array[2] = idx3;
|
||||
array[3] = idx4;
|
||||
|
||||
return array;
|
||||
}
|
||||
|
||||
float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
|
||||
cv::Mat pred) {
|
||||
int width = pred.cols;
|
||||
int height = pred.rows;
|
||||
std::vector<float> box_x;
|
||||
std::vector<float> box_y;
|
||||
for (int i = 0; i < contour.size(); ++i) {
|
||||
box_x.push_back(contour[i].x);
|
||||
box_y.push_back(contour[i].y);
|
||||
}
|
||||
|
||||
int xmin =
|
||||
clamp(int(std::floor(*(std::min_element(box_x.begin(), box_x.end())))), 0,
|
||||
width - 1);
|
||||
int xmax =
|
||||
clamp(int(std::ceil(*(std::max_element(box_x.begin(), box_x.end())))), 0,
|
||||
width - 1);
|
||||
int ymin =
|
||||
clamp(int(std::floor(*(std::min_element(box_y.begin(), box_y.end())))), 0,
|
||||
height - 1);
|
||||
int ymax =
|
||||
clamp(int(std::ceil(*(std::max_element(box_y.begin(), box_y.end())))), 0,
|
||||
height - 1);
|
||||
|
||||
cv::Mat mask;
|
||||
mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);
|
||||
|
||||
cv::Point *rook_point = new cv::Point[contour.size()];
|
||||
|
||||
for (int i = 0; i < contour.size(); ++i) {
|
||||
rook_point[i] = cv::Point(int(box_x[i]) - xmin, int(box_y[i]) - ymin);
|
||||
}
|
||||
const cv::Point *ppt[1] = {rook_point};
|
||||
int npt[] = {int(contour.size())};
|
||||
|
||||
cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));
|
||||
|
||||
cv::Mat croppedImg;
|
||||
pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
|
||||
.copyTo(croppedImg);
|
||||
float score = cv::mean(croppedImg, mask)[0];
|
||||
|
||||
delete[] rook_point;
|
||||
return score;
|
||||
}
|
||||
|
||||
float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
|
||||
cv::Mat pred) {
|
||||
auto array = box_array;
|
||||
int width = pred.cols;
|
||||
int height = pred.rows;
|
||||
|
||||
float box_x[4] = {array[0][0], array[1][0], array[2][0], array[3][0]};
|
||||
float box_y[4] = {array[0][1], array[1][1], array[2][1], array[3][1]};
|
||||
|
||||
int xmin = clamp(int(std::floor(*(std::min_element(box_x, box_x + 4)))), 0,
|
||||
width - 1);
|
||||
int xmax = clamp(int(std::ceil(*(std::max_element(box_x, box_x + 4)))), 0,
|
||||
width - 1);
|
||||
int ymin = clamp(int(std::floor(*(std::min_element(box_y, box_y + 4)))), 0,
|
||||
height - 1);
|
||||
int ymax = clamp(int(std::ceil(*(std::max_element(box_y, box_y + 4)))), 0,
|
||||
height - 1);
|
||||
|
||||
cv::Mat mask;
|
||||
mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);
|
||||
|
||||
cv::Point root_point[4];
|
||||
root_point[0] = cv::Point(int(array[0][0]) - xmin, int(array[0][1]) - ymin);
|
||||
root_point[1] = cv::Point(int(array[1][0]) - xmin, int(array[1][1]) - ymin);
|
||||
root_point[2] = cv::Point(int(array[2][0]) - xmin, int(array[2][1]) - ymin);
|
||||
root_point[3] = cv::Point(int(array[3][0]) - xmin, int(array[3][1]) - ymin);
|
||||
const cv::Point *ppt[1] = {root_point};
|
||||
int npt[] = {4};
|
||||
cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));
|
||||
|
||||
cv::Mat croppedImg;
|
||||
pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
|
||||
.copyTo(croppedImg);
|
||||
|
||||
auto score = cv::mean(croppedImg, mask)[0];
|
||||
return score;
|
||||
}
|
||||
|
||||
//这个应该是DB(差分二值化)相关的内容,方法从 Bitmap 图中获取检测框
|
||||
//涉及到box_thresh(低于这个阈值的boxs不予显示)和det_db_unclip_ratio(文本框扩张的系数,关系到文本框的大小)
|
||||
void PostProcessor::BoxesFromBitmap(
|
||||
const cv::Mat pred, std::vector<std::vector<std::vector<int>>> *boxes,
|
||||
const cv::Mat bitmap, const float &box_thresh,
|
||||
const float &det_db_unclip_ratio, const std::string &det_db_score_mode) {
|
||||
const int min_size = 3;
|
||||
const int max_candidates = 1000;
|
||||
|
||||
int width = bitmap.cols;
|
||||
int height = bitmap.rows;
|
||||
|
||||
std::vector<std::vector<cv::Point>> contours;
|
||||
std::vector<cv::Vec4i> hierarchy;
|
||||
|
||||
cv::findContours(bitmap, contours, hierarchy, cv::RETR_LIST,
|
||||
cv::CHAIN_APPROX_SIMPLE);
|
||||
|
||||
int num_contours =
|
||||
contours.size() >= max_candidates ? max_candidates : contours.size();
|
||||
|
||||
for (int _i = 0; _i < num_contours; _i++) {
|
||||
if (contours[_i].size() <= 2) {
|
||||
continue;
|
||||
}
|
||||
float ssid;
|
||||
cv::RotatedRect box = cv::minAreaRect(contours[_i]);
|
||||
auto array = GetMiniBoxes(box, ssid);
|
||||
|
||||
auto box_for_unclip = array;
|
||||
// end get_mini_box
|
||||
|
||||
if (ssid < min_size) {
|
||||
continue;
|
||||
}
|
||||
|
||||
float score;
|
||||
if (det_db_score_mode == "slow") /* compute using polygon*/
|
||||
score = PolygonScoreAcc(contours[_i], pred);
|
||||
else
|
||||
score = BoxScoreFast(array, pred);
|
||||
|
||||
if (score < box_thresh) continue;
|
||||
|
||||
// start for unclip
|
||||
cv::RotatedRect points = UnClip(box_for_unclip, det_db_unclip_ratio);
|
||||
if (points.size.height < 1.001 && points.size.width < 1.001) {
|
||||
continue;
|
||||
}
|
||||
// end for unclip
|
||||
|
||||
cv::RotatedRect clipbox = points;
|
||||
auto cliparray = GetMiniBoxes(clipbox, ssid);
|
||||
|
||||
if (ssid < min_size + 2) continue;
|
||||
|
||||
int dest_width = pred.cols;
|
||||
int dest_height = pred.rows;
|
||||
std::vector<std::vector<int>> intcliparray;
|
||||
|
||||
for (int num_pt = 0; num_pt < 4; num_pt++) {
|
||||
std::vector<int> a{
|
||||
int(clampf(
|
||||
roundf(cliparray[num_pt][0] / float(width) * float(dest_width)),
|
||||
0, float(dest_width))),
|
||||
int(clampf(
|
||||
roundf(cliparray[num_pt][1] / float(height) * float(dest_height)),
|
||||
0, float(dest_height)))};
|
||||
intcliparray.push_back(a);
|
||||
}
|
||||
boxes->push_back(intcliparray);
|
||||
|
||||
} // end for
|
||||
// return true;
|
||||
}
|
||||
|
||||
//方法根据识别结果获取目标框位置
|
||||
void PostProcessor::FilterTagDetRes(
|
||||
std::vector<std::vector<std::vector<int>>> *boxes, float ratio_h,
|
||||
float ratio_w, const std::map<std::string, std::array<float, 2>> &im_info) {
|
||||
int oriimg_h = im_info.at("input_shape")[0];
|
||||
int oriimg_w = im_info.at("input_shape")[1];
|
||||
|
||||
for (int n = 0; n < boxes->size(); n++) {
|
||||
(*boxes)[n] = OrderPointsClockwise((*boxes)[n]);
|
||||
for (int m = 0; m < (*boxes)[0].size(); m++) {
|
||||
(*boxes)[n][m][0] /= ratio_w;
|
||||
(*boxes)[n][m][1] /= ratio_h;
|
||||
|
||||
(*boxes)[n][m][0] = int(_min(_max((*boxes)[n][m][0], 0), oriimg_w - 1));
|
||||
(*boxes)[n][m][1] = int(_min(_max((*boxes)[n][m][1], 0), oriimg_h - 1));
|
||||
}
|
||||
}
|
||||
|
||||
//此时已经拿到所有的点. 再进行下面的筛选
|
||||
for (int n = (*boxes).size() - 1; n >= 0; n--) {
|
||||
int rect_width, rect_height;
|
||||
rect_width = int(sqrt(pow((*boxes)[n][0][0] - (*boxes)[n][1][0], 2) +
|
||||
pow((*boxes)[n][0][1] - (*boxes)[n][1][1], 2)));
|
||||
rect_height = int(sqrt(pow((*boxes)[n][0][0] - (*boxes)[n][3][0], 2) +
|
||||
pow((*boxes)[n][0][1] - (*boxes)[n][3][1], 2)));
|
||||
|
||||
//原始实现,小于4的跳过,只return大于4的
|
||||
// if (rect_width <= 4 || rect_height <= 4) continue;
|
||||
// root_points.push_back((*boxes)[n]);
|
||||
|
||||
//小于4的删除掉. erase配合逆序遍历.
|
||||
if (rect_width <= 4 || rect_height <= 4) {
|
||||
boxes->erase(boxes->begin() + n);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace ocr
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
93
csrc/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h
Normal file
93
csrc/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <ostream>
|
||||
#include <vector>
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
#include "opencv2/imgproc.hpp"
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
|
||||
#include "fastdeploy/vision/ocr/ppocr/utils/clipper.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace ocr {
|
||||
|
||||
class PostProcessor {
|
||||
public:
|
||||
void GetContourArea(const std::vector<std::vector<float>> &box,
|
||||
float unclip_ratio, float &distance);
|
||||
|
||||
cv::RotatedRect UnClip(std::vector<std::vector<float>> box,
|
||||
const float &unclip_ratio);
|
||||
|
||||
float **Mat2Vec(cv::Mat mat);
|
||||
|
||||
std::vector<std::vector<int>> OrderPointsClockwise(
|
||||
std::vector<std::vector<int>> pts);
|
||||
|
||||
std::vector<std::vector<float>> GetMiniBoxes(cv::RotatedRect box,
|
||||
float &ssid);
|
||||
|
||||
float BoxScoreFast(std::vector<std::vector<float>> box_array, cv::Mat pred);
|
||||
float PolygonScoreAcc(std::vector<cv::Point> contour, cv::Mat pred);
|
||||
|
||||
void BoxesFromBitmap(const cv::Mat pred,
|
||||
std::vector<std::vector<std::vector<int>>> *boxes,
|
||||
const cv::Mat bitmap, const float &box_thresh,
|
||||
const float &det_db_unclip_ratio,
|
||||
const std::string &det_db_score_mode);
|
||||
|
||||
void FilterTagDetRes(
|
||||
std::vector<std::vector<std::vector<int>>> *boxes, float ratio_h,
|
||||
float ratio_w,
|
||||
const std::map<std::string, std::array<float, 2>> &im_info);
|
||||
|
||||
private:
|
||||
static bool XsortInt(std::vector<int> a, std::vector<int> b);
|
||||
|
||||
static bool XsortFp32(std::vector<float> a, std::vector<float> b);
|
||||
|
||||
std::vector<std::vector<float>> Mat2Vector(cv::Mat mat);
|
||||
|
||||
inline int _max(int a, int b) { return a >= b ? a : b; }
|
||||
|
||||
inline int _min(int a, int b) { return a >= b ? b : a; }
|
||||
|
||||
template <class T>
|
||||
inline T clamp(T x, T min, T max) {
|
||||
if (x > max) return max;
|
||||
if (x < min) return min;
|
||||
return x;
|
||||
}
|
||||
|
||||
inline float clampf(float x, float min, float max) {
|
||||
if (x > max) return max;
|
||||
if (x < min) return min;
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ocr
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
35
csrc/fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h
Normal file
35
csrc/fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "fastdeploy/core/fd_tensor.h"
|
||||
#include "fastdeploy/utils/utils.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
#include "opencv2/imgproc.hpp"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace ocr {
|
||||
|
||||
cv::Mat GetRotateCropImage(const cv::Mat &srcimage,
|
||||
const std::array<int, 8> &box);
|
||||
} // namespace ocr
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -22,6 +22,7 @@ void BindSegmentation(pybind11::module& m);
|
||||
void BindMatting(pybind11::module& m);
|
||||
void BindFaceDet(pybind11::module& m);
|
||||
void BindFaceId(pybind11::module& m);
|
||||
void BindOcr(pybind11::module& m);
|
||||
#ifdef ENABLE_VISION_VISUALIZE
|
||||
void BindVisualize(pybind11::module& m);
|
||||
#endif
|
||||
@@ -42,6 +43,15 @@ void BindVision(pybind11::module& m) {
|
||||
.def("__repr__", &vision::DetectionResult::Str)
|
||||
.def("__str__", &vision::DetectionResult::Str);
|
||||
|
||||
pybind11::class_<vision::OCRResult>(m, "OCRResult")
|
||||
.def(pybind11::init())
|
||||
.def_readwrite("boxes", &vision::OCRResult::boxes)
|
||||
.def_readwrite("text", &vision::OCRResult::text)
|
||||
.def_readwrite("score", &vision::OCRResult::rec_scores)
|
||||
.def_readwrite("cls_score", &vision::OCRResult::cls_scores)
|
||||
.def_readwrite("cls_label", &vision::OCRResult::cls_label)
|
||||
.def("__repr__", &vision::OCRResult::Str)
|
||||
.def("__str__", &vision::OCRResult::Str);
|
||||
pybind11::class_<vision::FaceDetectionResult>(m, "FaceDetectionResult")
|
||||
.def(pybind11::init())
|
||||
.def_readwrite("boxes", &vision::FaceDetectionResult::boxes)
|
||||
@@ -81,6 +91,7 @@ void BindVision(pybind11::module& m) {
|
||||
BindFaceDet(m);
|
||||
BindFaceId(m);
|
||||
BindMatting(m);
|
||||
BindOcr(m);
|
||||
#ifdef ENABLE_VISION_VISUALIZE
|
||||
BindVisualize(m);
|
||||
#endif
|
||||
|
46
csrc/fastdeploy/vision/visualize/ocr.cc
Normal file
46
csrc/fastdeploy/vision/visualize/ocr.cc
Normal file
@@ -0,0 +1,46 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifdef ENABLE_VISION_VISUALIZE
|
||||
|
||||
#include "fastdeploy/vision/visualize/visualize.h"
|
||||
#include "opencv2/imgproc/imgproc.hpp"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
cv::Mat Visualize::VisOcr(const cv::Mat &im, const OCRResult &ocr_result) {
|
||||
auto vis_im = im.clone();
|
||||
|
||||
for (int n = 0; n < ocr_result.boxes.size(); n++) {
|
||||
//遍历每一个盒子
|
||||
cv::Point rook_points[4];
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
//对每一个盒子 array<float,8>
|
||||
rook_points[m] = cv::Point(int(ocr_result.boxes[n][m * 2]),
|
||||
int(ocr_result.boxes[n][m * 2 + 1]));
|
||||
}
|
||||
|
||||
const cv::Point *ppt[1] = {rook_points};
|
||||
int npt[] = {4};
|
||||
cv::polylines(vis_im, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0);
|
||||
}
|
||||
|
||||
return vis_im;
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
||||
#endif
|
@@ -35,6 +35,7 @@ class FASTDEPLOY_DECL Visualize {
|
||||
const SegmentationResult& result);
|
||||
static cv::Mat VisMattingAlpha(const cv::Mat& im, const MattingResult& result,
|
||||
bool remove_small_connected_area = false);
|
||||
static cv::Mat VisOcr(const cv::Mat& srcimg, const OCRResult& ocr_result);
|
||||
};
|
||||
|
||||
} // namespace vision
|
||||
|
@@ -48,6 +48,14 @@ void BindVisualize(pybind11::module& m) {
|
||||
vision::Mat(vis_im).ShareWithTensor(&out);
|
||||
return TensorToPyArray(out);
|
||||
})
|
||||
.def_static("vis_ppocr",
|
||||
[](pybind11::array& im_data, vision::OCRResult& result) {
|
||||
auto im = PyArrayToCvMat(im_data);
|
||||
auto vis_im = vision::Visualize::VisOcr(im, result);
|
||||
FDTensor out;
|
||||
vision::Mat(vis_im).ShareWithTensor(&out);
|
||||
return TensorToPyArray(out);
|
||||
})
|
||||
.def_static("vis_matting_alpha",
|
||||
[](pybind11::array& im_data, vision::MattingResult& result,
|
||||
bool remove_small_connected_area) {
|
||||
|
14
examples/vision/ocr/PPOCRSystemv2/cpp/CMakeLists.txt
Normal file
14
examples/vision/ocr/PPOCRSystemv2/cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
PROJECT(infer_demo C CXX)
|
||||
CMAKE_MINIMUM_REQUIRED (VERSION 3.12)
|
||||
|
||||
# 指定下载解压后的fastdeploy库路径
|
||||
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
|
||||
|
||||
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
|
||||
|
||||
# 添加FastDeploy依赖头文件
|
||||
include_directories(${FASTDEPLOY_INCS})
|
||||
|
||||
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
|
||||
# 添加FastDeploy库依赖
|
||||
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})
|
139
examples/vision/ocr/PPOCRSystemv2/cpp/README.md
Normal file
139
examples/vision/ocr/PPOCRSystemv2/cpp/README.md
Normal file
@@ -0,0 +1,139 @@
|
||||
# PPOCRSystemv2 C++部署示例
|
||||
|
||||
本目录下提供`infer.cc`快速完成PPOCRSystemv2在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。
|
||||
|
||||
在部署前,需确认以下两个步骤
|
||||
|
||||
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/the%20software%20and%20hardware%20requirements.md)
|
||||
- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/quick_start)
|
||||
|
||||
以Linux上CPU推理为例,在本目录执行如下命令即可完成编译测试
|
||||
|
||||
```
|
||||
mkdir build
|
||||
cd build
|
||||
wget https://https://bj.bcebos.com/paddlehub/fastdeploy/cpp/fastdeploy-linux-x64-gpu-0.2.0.tgz
|
||||
tar xvf fastdeploy-linux-x64-0.2.0.tgz
|
||||
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-0.2.0
|
||||
make -j
|
||||
|
||||
|
||||
# 下载模型,图片和label文件
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar
|
||||
tar xvf ch_PP-OCRv2_det_infer.tar
|
||||
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
|
||||
tar xvf ch_ppocr_mobile_v2.0_cls_infer.tar
|
||||
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar
|
||||
tar xvf ch_PP-OCRv2_rec_infer.tar
|
||||
|
||||
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/doc/imgs/12.jpg
|
||||
|
||||
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/ppocr/utils/ppocr_keys_v1.txt
|
||||
|
||||
|
||||
# CPU推理
|
||||
./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 0
|
||||
# GPU推理
|
||||
./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 1
|
||||
# GPU上TensorRT推理
|
||||
./infer_demo ./ch_PP-OCRv2_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 2
|
||||
# OCR还支持det/cls/rec三个模型的组合使用,例如当我们不想使用cls模型的时候,只需要给cls模型路径的位置,传入一个空的字符串, 例子如下
|
||||
./infer_demo ./ch_PP-OCRv2_det_infer "" ./ch_PP-OCRv2_rec_infer ./ppocr_keys_v1.txt ./12.jpg 0
|
||||
```
|
||||
|
||||
运行完成可视化结果如下图所示
|
||||
|
||||
<img width="640" src="https://user-images.githubusercontent.com/109218879/185826024-f7593a0c-1bd2-4a60-b76c-15588484fa08.jpg">
|
||||
|
||||
|
||||
## PPOCRSystemv2 C++接口
|
||||
|
||||
### PPOCRSystemv2类
|
||||
|
||||
```
|
||||
fastdeploy::application::ocrsystem::PPOCRSystemv2(fastdeploy::vision::ocr::DBDetector* ocr_det = nullptr,
|
||||
fastdeploy::vision::ocr::Classifier* ocr_cls = nullptr,
|
||||
fastdeploy::vision::ocr::Recognizer* ocr_rec = nullptr);
|
||||
```
|
||||
|
||||
PPOCRSystemv2 的初始化,由检测,分类和识别模型串联构成
|
||||
|
||||
**参数**
|
||||
|
||||
> * **DBDetector**(model): OCR中的检测模型
|
||||
> * **Classifier**(model): OCR中的分类模型
|
||||
> * **Recognizer**(model): OCR中的识别模型
|
||||
|
||||
#### Predict函数
|
||||
|
||||
> ```
|
||||
> std::vector<std::vector<fastdeploy::vision::OCRResult>> ocr_results =
|
||||
> PPOCRSystemv2.Predict(std::vector<cv::Mat> cv_all_imgs);
|
||||
>
|
||||
> ```
|
||||
>
|
||||
> 模型预测接口,输入一个可装入多张图片的图片列表,后可输出检测结果。
|
||||
>
|
||||
> **参数**
|
||||
>
|
||||
> > * **cv_all_imgs**: 输入图像,注意需为HWC,BGR格式
|
||||
> > * **ocr_results**: OCR结果,包括由检测模型输出的检测框位置,分类模型输出的方向分类,以及识别模型输出的识别结果, OCRResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/)
|
||||
|
||||
|
||||
## DBDetector C++接口
|
||||
|
||||
### DBDetector类
|
||||
|
||||
```
|
||||
fastdeploy::vision::ocr::DBDetector(const std::string& model_file, const std::string& params_file = "",
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const Frontend& model_format = Frontend::PADDLE);
|
||||
```
|
||||
|
||||
DBDetector模型加载和初始化,其中模型为paddle模型格式。
|
||||
|
||||
**参数**
|
||||
|
||||
> * **model_file**(str): 模型文件路径
|
||||
> * **params_file**(str): 参数文件路径,当模型格式为ONNX时,此参数传入空字符串即可
|
||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||
> * **model_format**(Frontend): 模型格式,默认为Paddle格式
|
||||
|
||||
### Classifier类与DBDetector类相同
|
||||
|
||||
### Recognizer类
|
||||
```
|
||||
Recognizer(const std::string& model_file,
|
||||
const std::string& params_file = "",
|
||||
const std::string& label_path = "",
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const Frontend& model_format = Frontend::PADDLE);
|
||||
```
|
||||
Recognizer类初始化时,需要在label_path参数中,输入识别模型所需的label文件,其他参数均与DBDetector类相同
|
||||
|
||||
**参数**
|
||||
> * **label_path**(str): 识别模型的label文件路径
|
||||
|
||||
|
||||
### 类成员变量
|
||||
#### DBDetector预处理参数
|
||||
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
|
||||
|
||||
> > * **max_side_len**(int): 检测算法前向时图片长边的最大尺寸,当长边超出这个值时会将长边resize到这个大小,短边等比例缩放,默认为960
|
||||
> > * **det_db_thresh**(double): DB模型输出预测图的二值化阈值,默认为0.3
|
||||
> > * **det_db_box_thresh**(double): DB模型输出框的阈值,低于此值的预测框会被丢弃,默认为0.6
|
||||
> > * **det_db_unclip_ratio**(double): DB模型输出框扩大的比例,默认为1.5
|
||||
> > * **det_db_score_mode**(string):DB后处理中计算文本框平均得分的方式,默认为slow,即求polygon区域的平均分数的方式
|
||||
> > * **use_dilation**(bool):是否对检测输出的feature map做膨胀处理,默认为Fasle
|
||||
|
||||
#### Classifier预处理参数
|
||||
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
|
||||
|
||||
> > * **cls_thresh**(double): 当分类模型输出的得分超过此阈值,输入的图片将被翻转,默认为0.9
|
||||
|
||||
|
||||
- [模型介绍](../../)
|
||||
- [Python部署](../python)
|
||||
- [视觉模型预测结果](../../../../../docs/api/vision_results/)
|
295
examples/vision/ocr/PPOCRSystemv2/cpp/infer.cc
Normal file
295
examples/vision/ocr/PPOCRSystemv2/cpp/infer.cc
Normal file
@@ -0,0 +1,295 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision.h"
|
||||
#ifdef WIN32
|
||||
const char sep = '\\';
|
||||
#else
|
||||
const char sep = '/';
|
||||
#endif
|
||||
|
||||
void CpuInfer(const std::string& det_model_dir,
|
||||
const std::string& cls_model_dir,
|
||||
const std::string& rec_model_dir,
|
||||
const std::string& rec_label_file,
|
||||
const std::string& image_file) {
|
||||
auto det_model_file = det_model_dir + sep + "inference.pdmodel";
|
||||
auto det_params_file = det_model_dir + sep + "inference.pdiparams";
|
||||
|
||||
auto cls_model_file = cls_model_dir + sep + "inference.pdmodel";
|
||||
auto cls_params_file = cls_model_dir + sep + "inference.pdiparams";
|
||||
|
||||
auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
|
||||
auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";
|
||||
auto rec_label = rec_label_file;
|
||||
|
||||
fastdeploy::vision::ocr::DBDetector det_model;
|
||||
fastdeploy::vision::ocr::Classifier cls_model;
|
||||
fastdeploy::vision::ocr::Recognizer rec_model;
|
||||
|
||||
if (!det_model_dir.empty()) {
|
||||
auto det_option = fastdeploy::RuntimeOption();
|
||||
det_option.UseCpu();
|
||||
det_model = fastdeploy::vision::ocr::DBDetector(
|
||||
det_model_file, det_params_file, det_option);
|
||||
|
||||
if (!det_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize det_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!cls_model_dir.empty()) {
|
||||
auto cls_option = fastdeploy::RuntimeOption();
|
||||
cls_option.UseCpu();
|
||||
cls_model = fastdeploy::vision::ocr::Classifier(
|
||||
cls_model_file, cls_params_file, cls_option);
|
||||
|
||||
if (!cls_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize cls_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!rec_model_dir.empty()) {
|
||||
auto rec_option = fastdeploy::RuntimeOption();
|
||||
rec_option.UseCpu();
|
||||
rec_option.UsePaddleBackend(); // OCRv2的rec模型暂不支持ORT后端
|
||||
|
||||
rec_model = fastdeploy::vision::ocr::Recognizer(
|
||||
rec_model_file, rec_params_file, rec_label, rec_option);
|
||||
|
||||
if (!rec_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize rec_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto ocrv2_app = fastdeploy::application::ocrsystem::PPOCRSystemv2(
|
||||
&det_model, &cls_model, &rec_model);
|
||||
|
||||
auto im = cv::imread(image_file);
|
||||
auto im_bak = im.clone();
|
||||
|
||||
fastdeploy::vision::OCRResult res;
|
||||
//开始预测
|
||||
if (!ocrv2_app.Predict(&im, &res)) {
|
||||
std::cerr << "Failed to predict." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
//输出预测信息
|
||||
std::cout << res.Str() << std::endl;
|
||||
|
||||
//可视化
|
||||
auto vis_img = fastdeploy::vision::Visualize::VisOcr(im_bak, res);
|
||||
|
||||
cv::imwrite("vis_result.jpg", vis_img);
|
||||
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
|
||||
}
|
||||
|
||||
void GpuInfer(const std::string& det_model_dir,
|
||||
const std::string& cls_model_dir,
|
||||
const std::string& rec_model_dir,
|
||||
const std::string& rec_label_file,
|
||||
const std::string& image_file) {
|
||||
auto det_model_file = det_model_dir + sep + "inference.pdmodel";
|
||||
auto det_params_file = det_model_dir + sep + "inference.pdiparams";
|
||||
|
||||
auto cls_model_file = cls_model_dir + sep + "inference.pdmodel";
|
||||
auto cls_params_file = cls_model_dir + sep + "inference.pdiparams";
|
||||
|
||||
auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
|
||||
auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";
|
||||
auto rec_label = rec_label_file;
|
||||
|
||||
fastdeploy::vision::ocr::DBDetector det_model;
|
||||
fastdeploy::vision::ocr::Classifier cls_model;
|
||||
fastdeploy::vision::ocr::Recognizer rec_model;
|
||||
|
||||
//准备模型
|
||||
if (!det_model_dir.empty()) {
|
||||
auto det_option = fastdeploy::RuntimeOption();
|
||||
det_option.UseGpu();
|
||||
det_model = fastdeploy::vision::ocr::DBDetector(
|
||||
det_model_file, det_params_file, det_option);
|
||||
|
||||
if (!det_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize det_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!cls_model_dir.empty()) {
|
||||
auto cls_option = fastdeploy::RuntimeOption();
|
||||
cls_option.UseGpu();
|
||||
cls_model = fastdeploy::vision::ocr::Classifier(
|
||||
cls_model_file, cls_params_file, cls_option);
|
||||
|
||||
if (!cls_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize cls_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!rec_model_dir.empty()) {
|
||||
auto rec_option = fastdeploy::RuntimeOption();
|
||||
rec_option.UseGpu();
|
||||
rec_option
|
||||
.UsePaddleBackend(); // OCRv2的rec模型暂不支持ORT后端与PaddleInference
|
||||
// v2.3.2
|
||||
rec_model = fastdeploy::vision::ocr::Recognizer(
|
||||
rec_model_file, rec_params_file, rec_label, rec_option);
|
||||
|
||||
if (!rec_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize rec_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto ocrv2_app = fastdeploy::application::ocrsystem::PPOCRSystemv2(
|
||||
&det_model, &cls_model, &rec_model);
|
||||
|
||||
auto im = cv::imread(image_file);
|
||||
auto im_bak = im.clone();
|
||||
|
||||
fastdeploy::vision::OCRResult res;
|
||||
//开始预测
|
||||
if (!ocrv2_app.Predict(&im, &res)) {
|
||||
std::cerr << "Failed to predict." << std::endl;
|
||||
return;
|
||||
}
|
||||
//输出预测信息
|
||||
std::cout << res.Str() << std::endl;
|
||||
|
||||
//可视化
|
||||
auto vis_img = fastdeploy::vision::Visualize::VisOcr(im_bak, res);
|
||||
|
||||
cv::imwrite("vis_result.jpg", vis_img);
|
||||
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
|
||||
}
|
||||
|
||||
void TrtInfer(const std::string& det_model_dir,
|
||||
const std::string& cls_model_dir,
|
||||
const std::string& rec_model_dir,
|
||||
const std::string& rec_label_file,
|
||||
const std::string& image_file) {
|
||||
auto det_model_file = det_model_dir + sep + "inference.pdmodel";
|
||||
auto det_params_file = det_model_dir + sep + "inference.pdiparams";
|
||||
|
||||
auto cls_model_file = cls_model_dir + sep + "inference.pdmodel";
|
||||
auto cls_params_file = cls_model_dir + sep + "inference.pdiparams";
|
||||
|
||||
auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
|
||||
auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";
|
||||
auto rec_label = rec_label_file;
|
||||
|
||||
fastdeploy::vision::ocr::DBDetector det_model;
|
||||
fastdeploy::vision::ocr::Classifier cls_model;
|
||||
fastdeploy::vision::ocr::Recognizer rec_model;
|
||||
|
||||
//准备模型
|
||||
if (!det_model_dir.empty()) {
|
||||
auto det_option = fastdeploy::RuntimeOption();
|
||||
det_option.UseGpu();
|
||||
det_option.UseTrtBackend();
|
||||
det_option.SetTrtInputShape("x", {1, 3, 50, 50}, {1, 3, 640, 640},
|
||||
{1, 3, 960, 960});
|
||||
|
||||
det_model = fastdeploy::vision::ocr::DBDetector(
|
||||
det_model_file, det_params_file, det_option);
|
||||
|
||||
if (!det_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize det_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!cls_model_dir.empty()) {
|
||||
auto cls_option = fastdeploy::RuntimeOption();
|
||||
cls_option.UseGpu();
|
||||
cls_option.UseTrtBackend();
|
||||
cls_option.SetTrtInputShape("x", {1, 3, 48, 192});
|
||||
|
||||
cls_model = fastdeploy::vision::ocr::Classifier(
|
||||
cls_model_file, cls_params_file, cls_option);
|
||||
|
||||
if (!cls_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize cls_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!rec_model_dir.empty()) {
|
||||
auto rec_option = fastdeploy::RuntimeOption();
|
||||
rec_option.UseGpu();
|
||||
rec_option.UseTrtBackend();
|
||||
rec_option.SetTrtInputShape("x", {1, 3, 48, 10}, {1, 3, 48, 320},
|
||||
{1, 3, 48, 2000});
|
||||
|
||||
rec_model = fastdeploy::vision::ocr::Recognizer(
|
||||
rec_model_file, rec_params_file, rec_label, rec_option);
|
||||
|
||||
if (!rec_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize rec_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto ocrv2_app = fastdeploy::application::ocrsystem::PPOCRSystemv2(
|
||||
&det_model, &cls_model, &rec_model);
|
||||
|
||||
auto im = cv::imread(image_file);
|
||||
auto im_bak = im.clone();
|
||||
|
||||
fastdeploy::vision::OCRResult res;
|
||||
//开始预测
|
||||
if (!ocrv2_app.Predict(&im, &res)) {
|
||||
std::cerr << "Failed to predict." << std::endl;
|
||||
return;
|
||||
}
|
||||
//输出预测信息
|
||||
std::cout << res.Str() << std::endl;
|
||||
|
||||
//可视化
|
||||
auto vis_img = fastdeploy::vision::Visualize::VisOcr(im_bak, res);
|
||||
|
||||
cv::imwrite("vis_result.jpg", vis_img);
|
||||
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc < 7) {
|
||||
std::cout << "Usage: infer_demo path/to/det_model path/to/cls_model "
|
||||
"path/to/rec_model path/to/rec_label_file path/to/image "
|
||||
"run_option, "
|
||||
"e.g ./infer_demo ./ch_PP-OCRv2_det_infer "
|
||||
"./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv2_rec_infer "
|
||||
"./ppocr_keys_v1.txt ./12.jpg 0"
|
||||
<< std::endl;
|
||||
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
|
||||
"with gpu; 2: run with gpu and use tensorrt backend."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (std::atoi(argv[6]) == 0) {
|
||||
CpuInfer(argv[1], argv[2], argv[3], argv[4], argv[5]);
|
||||
} else if (std::atoi(argv[6]) == 1) {
|
||||
GpuInfer(argv[1], argv[2], argv[3], argv[4], argv[5]);
|
||||
} else if (std::atoi(argv[6]) == 2) {
|
||||
TrtInfer(argv[1], argv[2], argv[3], argv[4], argv[5]);
|
||||
}
|
||||
return 0;
|
||||
}
|
131
examples/vision/ocr/PPOCRSystemv2/python/README.md
Normal file
131
examples/vision/ocr/PPOCRSystemv2/python/README.md
Normal file
@@ -0,0 +1,131 @@
|
||||
# PPOCRSystemv2 Python部署示例
|
||||
|
||||
在部署前,需确认以下两个步骤
|
||||
|
||||
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/the%20software%20and%20hardware%20requirements.md)
|
||||
- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/quick_start)
|
||||
|
||||
本目录下提供`infer.py`快速完成PPOCRSystemv2在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成
|
||||
|
||||
```
|
||||
|
||||
# 下载模型,图片和label文件
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar
|
||||
tar xvf ch_PP-OCRv2_det_infer.tar
|
||||
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
|
||||
tar xvf ch_ppocr_mobile_v2.0_cls_infer.tar
|
||||
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar
|
||||
tar xvf ch_PP-OCRv2_rec_infer.tar
|
||||
|
||||
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/doc/imgs/12.jpg
|
||||
|
||||
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/ppocr/utils/ppocr_keys_v1.txt
|
||||
|
||||
|
||||
#下载部署示例代码
|
||||
git clone https://github.com/PaddlePaddle/FastDeploy.git
|
||||
cd examples/vison/ocr/PPOCRSystemv2/python/
|
||||
|
||||
# CPU推理
|
||||
python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device cpu
|
||||
# GPU推理
|
||||
python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu
|
||||
# GPU上使用TensorRT推理
|
||||
python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu --det_use_trt True --cls_use_trt True --rec_use_trt True
|
||||
# OCR还支持det/cls/rec三个模型的组合使用,例如当我们不想使用cls模型的时候,只需要给--cls_model传入一个空的字符串, 例子如下:
|
||||
python infer.py --det_model ch_PP-OCRv2_det_infer --cls_model "" --rec_model ch_PP-OCRv2_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device cpu
|
||||
```
|
||||
|
||||
运行完成可视化结果如下图所示
|
||||
<img width="640" src="https://user-images.githubusercontent.com/109218879/185826024-f7593a0c-1bd2-4a60-b76c-15588484fa08.jpg">
|
||||
|
||||
## PPOCRSystemv2 Python接口
|
||||
|
||||
```
|
||||
fastdeploy.vision.ocr.PPOCRSystemv2(ocr_det = det_model._model, ocr_cls = cls_model._model, ocr_rec = rec_model._model)
|
||||
```
|
||||
|
||||
PPOCRSystemv2的初始化,输入的参数是检测模型,分类模型和识别模型
|
||||
|
||||
**参数**
|
||||
|
||||
> * **ocr_det**(model): OCR中的检测模型
|
||||
> * **ocr_cls**(model): OCR中的分类模型
|
||||
> * **ocr_rec**(model): OCR中的识别模型
|
||||
|
||||
### predict函数
|
||||
|
||||
> ```
|
||||
> result = PPOCRSystemv2.predict(img_list)
|
||||
> ```
|
||||
>
|
||||
> 模型预测接口,输入的是一个可包含多个图像的list
|
||||
>
|
||||
> **参数**
|
||||
>
|
||||
> > * **img_list**(list[np.ndarray]): 输入数据的list,每张图片注意需为HWC,BGR格式
|
||||
> > * **result**(float): OCR结果,包括由检测模型输出的检测框位置,分类模型输出的方向分类,以及识别模型输出的识别结果,
|
||||
|
||||
> **返回**
|
||||
>
|
||||
> > 返回`fastdeploy.vision.OCRResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
|
||||
|
||||
|
||||
|
||||
## DBDetector Python接口
|
||||
|
||||
### DBDetector类
|
||||
|
||||
```
|
||||
fastdeploy.vision.ocr.DBDetector(model_file, params_file, runtime_option=None, model_format=Frontend.PADDLE)
|
||||
```
|
||||
|
||||
DBDetector模型加载和初始化,其中模型为paddle模型格式。
|
||||
|
||||
**参数**
|
||||
|
||||
> * **model_file**(str): 模型文件路径
|
||||
> * **params_file**(str): 参数文件路径,当模型格式为ONNX时,此参数传入空字符串即可
|
||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||
> * **model_format**(Frontend): 模型格式,默认为PADDLE格式
|
||||
|
||||
### Classifier类与DBDetector类相同
|
||||
|
||||
### Recognizer类
|
||||
```
|
||||
fastdeploy.vision.ocr.Recognizer(rec_model_file,rec_params_file,rec_label_file,
|
||||
runtime_option=rec_runtime_option,model_format=Frontend.PADDLE)
|
||||
```
|
||||
Recognizer类初始化时,需要在rec_label_file参数中,输入识别模型所需的label文件路径,其他参数均与DBDetector类相同
|
||||
|
||||
**参数**
|
||||
> * **label_path**(str): 识别模型的label文件路径
|
||||
|
||||
|
||||
|
||||
### 类成员变量
|
||||
|
||||
#### DBDetector预处理参数
|
||||
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
|
||||
|
||||
> > * **max_side_len**(int): 检测算法前向时图片长边的最大尺寸,当长边超出这个值时会将长边resize到这个大小,短边等比例缩放,默认为960
|
||||
> > * **det_db_thresh**(double): DB模型输出预测图的二值化阈值,默认为0.3
|
||||
> > * **det_db_box_thresh**(double): DB模型输出框的阈值,低于此值的预测框会被丢弃,默认为0.6
|
||||
> > * **det_db_unclip_ratio**(double): DB模型输出框扩大的比例,默认为1.5
|
||||
> > * **det_db_score_mode**(string):DB后处理中计算文本框平均得分的方式,默认为slow,即求polygon区域的平均分数的方式
|
||||
> > * **use_dilation**(bool):是否对检测输出的feature map做膨胀处理,默认为Fasle
|
||||
|
||||
#### Classifier预处理参数
|
||||
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
|
||||
|
||||
> > * **cls_thresh**(double): 当分类模型输出的得分超过此阈值,输入的图片将被翻转,默认为0.9
|
||||
|
||||
|
||||
|
||||
## 其它文档
|
||||
|
||||
- [YOLOv5 模型介绍](..)
|
||||
- [YOLOv5 C++部署](../cpp)
|
||||
- [模型预测结果说明](../../../../../docs/api/vision_results/)
|
146
examples/vision/ocr/PPOCRSystemv2/python/infer.py
Normal file
146
examples/vision/ocr/PPOCRSystemv2/python/infer.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import fastdeploy as fd
|
||||
import cv2
|
||||
import os
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
import argparse
|
||||
import ast
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--det_model", required=True, help="Path of Detection model of PPOCR.")
|
||||
parser.add_argument(
|
||||
"--cls_model",
|
||||
required=True,
|
||||
help="Path of Classification model of PPOCR.")
|
||||
parser.add_argument(
|
||||
"--rec_model",
|
||||
required=True,
|
||||
help="Path of Recognization model of PPOCR.")
|
||||
parser.add_argument(
|
||||
"--rec_label_file",
|
||||
required=True,
|
||||
help="Path of Recognization model of PPOCR.")
|
||||
|
||||
parser.add_argument(
|
||||
"--image", type=str, required=True, help="Path of test image file.")
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default='cpu',
|
||||
help="Type of inference device, support 'cpu' or 'gpu'.")
|
||||
parser.add_argument(
|
||||
"--det_use_trt",
|
||||
type=ast.literal_eval,
|
||||
default=False,
|
||||
help="Wether to use tensorrt.")
|
||||
parser.add_argument(
|
||||
"--cls_use_trt",
|
||||
type=ast.literal_eval,
|
||||
default=False,
|
||||
help="Wether to use tensorrt.")
|
||||
parser.add_argument(
|
||||
"--rec_use_trt",
|
||||
type=ast.literal_eval,
|
||||
default=False,
|
||||
help="Wether to use tensorrt.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_det_option(args):
|
||||
option = fd.RuntimeOption()
|
||||
|
||||
if args.device.lower() == "gpu":
|
||||
option.use_gpu()
|
||||
|
||||
if args.det_use_trt:
|
||||
option.use_trt_backend()
|
||||
#det_max_side_len 默认为960,当用户更改DET模型的max_side_len参数时,请将此参数同时更改
|
||||
det_max_side_len = 960
|
||||
option.set_trt_input_shape("x", [1, 3, 50, 50], [1, 3, 640, 640],
|
||||
[1, 3, det_max_side_len, det_max_side_len])
|
||||
|
||||
return option
|
||||
|
||||
|
||||
def build_cls_option(args):
|
||||
option = fd.RuntimeOption()
|
||||
option.use_paddle_backend()
|
||||
|
||||
if args.device.lower() == "gpu":
|
||||
option.use_gpu()
|
||||
|
||||
if args.cls_use_trt:
|
||||
option.use_trt_backend()
|
||||
option.set_trt_input_shape("x", [1, 3, 32, 100])
|
||||
|
||||
return option
|
||||
|
||||
|
||||
def build_rec_option(args):
|
||||
option = fd.RuntimeOption()
|
||||
option.use_paddle_backend()
|
||||
|
||||
if args.device.lower() == "gpu":
|
||||
option.use_gpu()
|
||||
|
||||
if args.rec_use_trt:
|
||||
option.use_trt_backend()
|
||||
option.set_trt_input_shape("x", [1, 3, 48, 10], [1, 3, 48, 320],
|
||||
[1, 3, 48, 2000])
|
||||
return option
|
||||
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
#Det模型
|
||||
det_model_file = os.path.join(args.det_model, "inference.pdmodel")
|
||||
det_params_file = os.path.join(args.det_model, "inference.pdiparams")
|
||||
#Cls模型
|
||||
cls_model_file = os.path.join(args.cls_model, "inference.pdmodel")
|
||||
cls_params_file = os.path.join(args.cls_model, "inference.pdiparams")
|
||||
#Rec模型
|
||||
rec_model_file = os.path.join(args.rec_model, "inference.pdmodel")
|
||||
rec_params_file = os.path.join(args.rec_model, "inference.pdiparams")
|
||||
rec_label_file = args.rec_label_file
|
||||
|
||||
#默认
|
||||
det_model = fd.vision.ocr.DBDetector("")
|
||||
cls_model = fd.vision.ocr.Classifier()
|
||||
rec_model = fd.vision.ocr.Recognizer()
|
||||
|
||||
#模型初始化
|
||||
if (len(args.det_model) != 0):
|
||||
det_runtime_option = build_det_option(args)
|
||||
det_model = fd.vision.ocr.DBDetector(
|
||||
det_model_file, det_params_file, runtime_option=det_runtime_option)
|
||||
|
||||
if (len(args.cls_model) != 0):
|
||||
cls_runtime_option = build_cls_option(args)
|
||||
cls_model = fd.vision.ocr.Classifier(
|
||||
cls_model_file, cls_params_file, runtime_option=cls_runtime_option)
|
||||
|
||||
if (len(args.rec_model) != 0):
|
||||
rec_runtime_option = build_rec_option(args)
|
||||
rec_model = fd.vision.ocr.Recognizer(
|
||||
rec_model_file,
|
||||
rec_params_file,
|
||||
rec_label_file,
|
||||
runtime_option=rec_runtime_option)
|
||||
|
||||
ppocrsysv2 = fd.vision.ocr.PPOCRSystemv2(
|
||||
ocr_det=det_model._model,
|
||||
ocr_cls=cls_model._model,
|
||||
ocr_rec=rec_model._model)
|
||||
|
||||
# 预测图片准备
|
||||
im = cv2.imread(args.image)
|
||||
|
||||
#预测并打印结果
|
||||
result = ppocrsysv2.predict(im)
|
||||
print(result)
|
||||
|
||||
# 可视化结果
|
||||
vis_im = fd.vision.vis_ppocr(im, result)
|
||||
cv2.imwrite("visualized_result.jpg", vis_im)
|
||||
print("Visualized result save in ./visualized_result.jpg")
|
14
examples/vision/ocr/PPOCRSystemv3/cpp/CMakeLists.txt
Normal file
14
examples/vision/ocr/PPOCRSystemv3/cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
PROJECT(infer_demo C CXX)
|
||||
CMAKE_MINIMUM_REQUIRED (VERSION 3.12)
|
||||
|
||||
# 指定下载解压后的fastdeploy库路径
|
||||
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
|
||||
|
||||
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
|
||||
|
||||
# 添加FastDeploy依赖头文件
|
||||
include_directories(${FASTDEPLOY_INCS})
|
||||
|
||||
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
|
||||
# 添加FastDeploy库依赖
|
||||
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})
|
139
examples/vision/ocr/PPOCRSystemv3/cpp/README.md
Normal file
139
examples/vision/ocr/PPOCRSystemv3/cpp/README.md
Normal file
@@ -0,0 +1,139 @@
|
||||
# PPOCRSystemv3 C++部署示例
|
||||
|
||||
本目录下提供`infer.cc`快速完成PPOCRSystemv3在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。
|
||||
|
||||
在部署前,需确认以下两个步骤
|
||||
|
||||
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/the%20software%20and%20hardware%20requirements.md)
|
||||
- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/quick_start)
|
||||
|
||||
以Linux上CPU推理为例,在本目录执行如下命令即可完成编译测试
|
||||
|
||||
```
|
||||
mkdir build
|
||||
cd build
|
||||
wget https://https://bj.bcebos.com/paddlehub/fastdeploy/cpp/fastdeploy-linux-x64-gpu-0.2.0.tgz
|
||||
tar xvf fastdeploy-linux-x64-0.2.0.tgz
|
||||
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-0.2.0
|
||||
make -j
|
||||
|
||||
|
||||
# 下载模型,图片和label文件
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
|
||||
tar xvf ch_PP-OCRv3_det_infer.tar
|
||||
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
|
||||
tar xvf ch_ppocr_mobile_v2.0_cls_infer.tar
|
||||
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar
|
||||
tar xvf ch_PP-OCRv3_rec_infer.tar
|
||||
|
||||
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/doc/imgs/12.jpg
|
||||
|
||||
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/ppocr/utils/ppocr_keys_v1.txt
|
||||
|
||||
|
||||
# CPU推理
|
||||
./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 0
|
||||
# GPU推理
|
||||
./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 1
|
||||
# GPU上TensorRT推理
|
||||
./infer_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 2
|
||||
# OCR还支持det/cls/rec三个模型的组合使用,例如当我们不想使用cls模型的时候,只需要给cls模型路径的位置,传入一个空的字符串, 例子如下
|
||||
./infer_demo ./ch_PP-OCRv3_det_infer "" ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 0
|
||||
```
|
||||
|
||||
运行完成可视化结果如下图所示
|
||||
|
||||
<img width="640" src="https://user-images.githubusercontent.com/109218879/185826024-f7593a0c-1bd2-4a60-b76c-15588484fa08.jpg">
|
||||
|
||||
|
||||
## PPOCRSystemv3 C++接口
|
||||
|
||||
### PPOCRSystemv3类
|
||||
|
||||
```
|
||||
fastdeploy::application::ocrsystem::PPOCRSystemv3(fastdeploy::vision::ocr::DBDetector* ocr_det = nullptr,
|
||||
fastdeploy::vision::ocr::Classifier* ocr_cls = nullptr,
|
||||
fastdeploy::vision::ocr::Recognizer* ocr_rec = nullptr);
|
||||
```
|
||||
|
||||
PPOCRSystemv3 的初始化,由检测,分类和识别模型串联构成
|
||||
|
||||
**参数**
|
||||
|
||||
> * **DBDetector**(model): OCR中的检测模型
|
||||
> * **Classifier**(model): OCR中的分类模型
|
||||
> * **Recognizer**(model): OCR中的识别模型
|
||||
|
||||
#### Predict函数
|
||||
|
||||
> ```
|
||||
> std::vector<std::vector<fastdeploy::vision::OCRResult>> ocr_results =
|
||||
> PPOCRSystemv3.Predict(std::vector<cv::Mat> cv_all_imgs);
|
||||
>
|
||||
> ```
|
||||
>
|
||||
> 模型预测接口,输入一个可装入多张图片的图片列表,后可输出检测结果。
|
||||
>
|
||||
> **参数**
|
||||
>
|
||||
> > * **cv_all_imgs**: 输入图像,注意需为HWC,BGR格式
|
||||
> > * **ocr_results**: OCR结果,包括由检测模型输出的检测框位置,分类模型输出的方向分类,以及识别模型输出的识别结果, OCRResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/)
|
||||
|
||||
|
||||
## DBDetector C++接口
|
||||
|
||||
### DBDetector类
|
||||
|
||||
```
|
||||
fastdeploy::vision::ocr::DBDetector(const std::string& model_file, const std::string& params_file = "",
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const Frontend& model_format = Frontend::PADDLE);
|
||||
```
|
||||
|
||||
DBDetector模型加载和初始化,其中模型为paddle模型格式。
|
||||
|
||||
**参数**
|
||||
|
||||
> * **model_file**(str): 模型文件路径
|
||||
> * **params_file**(str): 参数文件路径,当模型格式为ONNX时,此参数传入空字符串即可
|
||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||
> * **model_format**(Frontend): 模型格式,默认为Paddle格式
|
||||
|
||||
### Classifier类与DBDetector类相同
|
||||
|
||||
### Recognizer类
|
||||
```
|
||||
Recognizer(const std::string& model_file,
|
||||
const std::string& params_file = "",
|
||||
const std::string& label_path = "",
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const Frontend& model_format = Frontend::PADDLE);
|
||||
```
|
||||
Recognizer类初始化时,需要在label_path参数中,输入识别模型所需的label文件,其他参数均与DBDetector类相同
|
||||
|
||||
**参数**
|
||||
> * **label_path**(str): 识别模型的label文件路径
|
||||
|
||||
|
||||
### 类成员变量
|
||||
#### DBDetector预处理参数
|
||||
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
|
||||
|
||||
> > * **max_side_len**(int): 检测算法前向时图片长边的最大尺寸,当长边超出这个值时会将长边resize到这个大小,短边等比例缩放,默认为960
|
||||
> > * **det_db_thresh**(double): DB模型输出预测图的二值化阈值,默认为0.3
|
||||
> > * **det_db_box_thresh**(double): DB模型输出框的阈值,低于此值的预测框会被丢弃,默认为0.6
|
||||
> > * **det_db_unclip_ratio**(double): DB模型输出框扩大的比例,默认为1.5
|
||||
> > * **det_db_score_mode**(string):DB后处理中计算文本框平均得分的方式,默认为slow,即求polygon区域的平均分数的方式
|
||||
> > * **use_dilation**(bool):是否对检测输出的feature map做膨胀处理,默认为Fasle
|
||||
|
||||
#### Classifier预处理参数
|
||||
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
|
||||
|
||||
> > * **cls_thresh**(double): 当分类模型输出的得分超过此阈值,输入的图片将被翻转,默认为0.9
|
||||
|
||||
|
||||
- [模型介绍](../../)
|
||||
- [Python部署](../python)
|
||||
- [视觉模型预测结果](../../../../../docs/api/vision_results/)
|
290
examples/vision/ocr/PPOCRSystemv3/cpp/infer.cc
Normal file
290
examples/vision/ocr/PPOCRSystemv3/cpp/infer.cc
Normal file
@@ -0,0 +1,290 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision.h"
|
||||
#ifdef WIN32
|
||||
const char sep = '\\';
|
||||
#else
|
||||
const char sep = '/';
|
||||
#endif
|
||||
|
||||
void CpuInfer(const std::string& det_model_dir,
|
||||
const std::string& cls_model_dir,
|
||||
const std::string& rec_model_dir,
|
||||
const std::string& rec_label_file,
|
||||
const std::string& image_file) {
|
||||
auto det_model_file = det_model_dir + sep + "inference.pdmodel";
|
||||
auto det_params_file = det_model_dir + sep + "inference.pdiparams";
|
||||
|
||||
auto cls_model_file = cls_model_dir + sep + "inference.pdmodel";
|
||||
auto cls_params_file = cls_model_dir + sep + "inference.pdiparams";
|
||||
|
||||
auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
|
||||
auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";
|
||||
auto rec_label = rec_label_file;
|
||||
|
||||
fastdeploy::vision::ocr::DBDetector det_model;
|
||||
fastdeploy::vision::ocr::Classifier cls_model;
|
||||
fastdeploy::vision::ocr::Recognizer rec_model;
|
||||
|
||||
if (!det_model_dir.empty()) {
|
||||
auto det_option = fastdeploy::RuntimeOption();
|
||||
det_option.UseCpu();
|
||||
det_model = fastdeploy::vision::ocr::DBDetector(
|
||||
det_model_file, det_params_file, det_option);
|
||||
|
||||
if (!det_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize det_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!cls_model_dir.empty()) {
|
||||
auto cls_option = fastdeploy::RuntimeOption();
|
||||
cls_option.UseCpu();
|
||||
cls_model = fastdeploy::vision::ocr::Classifier(
|
||||
cls_model_file, cls_params_file, cls_option);
|
||||
|
||||
if (!cls_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize cls_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!rec_model_dir.empty()) {
|
||||
auto rec_option = fastdeploy::RuntimeOption();
|
||||
rec_option.UseCpu();
|
||||
rec_model = fastdeploy::vision::ocr::Recognizer(
|
||||
rec_model_file, rec_params_file, rec_label, rec_option);
|
||||
|
||||
if (!rec_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize rec_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto ocrv3_app = fastdeploy::application::ocrsystem::PPOCRSystemv3(
|
||||
&det_model, &cls_model, &rec_model);
|
||||
|
||||
auto im = cv::imread(image_file);
|
||||
auto im_bak = im.clone();
|
||||
|
||||
fastdeploy::vision::OCRResult res;
|
||||
//开始预测
|
||||
if (!ocrv3_app.Predict(&im, &res)) {
|
||||
std::cerr << "Failed to predict." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
//输出预测信息
|
||||
std::cout << res.Str() << std::endl;
|
||||
|
||||
//可视化
|
||||
auto vis_img = fastdeploy::vision::Visualize::VisOcr(im_bak, res);
|
||||
|
||||
cv::imwrite("vis_result.jpg", vis_img);
|
||||
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
|
||||
}
|
||||
|
||||
void GpuInfer(const std::string& det_model_dir,
|
||||
const std::string& cls_model_dir,
|
||||
const std::string& rec_model_dir,
|
||||
const std::string& rec_label_file,
|
||||
const std::string& image_file) {
|
||||
auto det_model_file = det_model_dir + sep + "inference.pdmodel";
|
||||
auto det_params_file = det_model_dir + sep + "inference.pdiparams";
|
||||
|
||||
auto cls_model_file = cls_model_dir + sep + "inference.pdmodel";
|
||||
auto cls_params_file = cls_model_dir + sep + "inference.pdiparams";
|
||||
|
||||
auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
|
||||
auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";
|
||||
auto rec_label = rec_label_file;
|
||||
|
||||
fastdeploy::vision::ocr::DBDetector det_model;
|
||||
fastdeploy::vision::ocr::Classifier cls_model;
|
||||
fastdeploy::vision::ocr::Recognizer rec_model;
|
||||
|
||||
//准备模型
|
||||
if (!det_model_dir.empty()) {
|
||||
auto det_option = fastdeploy::RuntimeOption();
|
||||
det_option.UseGpu();
|
||||
det_model = fastdeploy::vision::ocr::DBDetector(
|
||||
det_model_file, det_params_file, det_option);
|
||||
|
||||
if (!det_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize det_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!cls_model_dir.empty()) {
|
||||
auto cls_option = fastdeploy::RuntimeOption();
|
||||
cls_option.UseGpu();
|
||||
cls_model = fastdeploy::vision::ocr::Classifier(
|
||||
cls_model_file, cls_params_file, cls_option);
|
||||
|
||||
if (!cls_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize cls_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!rec_model_dir.empty()) {
|
||||
auto rec_option = fastdeploy::RuntimeOption();
|
||||
rec_option.UseGpu();
|
||||
rec_model = fastdeploy::vision::ocr::Recognizer(
|
||||
rec_model_file, rec_params_file, rec_label, rec_option);
|
||||
|
||||
if (!rec_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize rec_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto ocrv3_app = fastdeploy::application::ocrsystem::PPOCRSystemv3(
|
||||
&det_model, &cls_model, &rec_model);
|
||||
|
||||
auto im = cv::imread(image_file);
|
||||
auto im_bak = im.clone();
|
||||
|
||||
fastdeploy::vision::OCRResult res;
|
||||
//开始预测
|
||||
if (!ocrv3_app.Predict(&im, &res)) {
|
||||
std::cerr << "Failed to predict." << std::endl;
|
||||
return;
|
||||
}
|
||||
//输出预测信息
|
||||
std::cout << res.Str() << std::endl;
|
||||
|
||||
//可视化
|
||||
auto vis_img = fastdeploy::vision::Visualize::VisOcr(im_bak, res);
|
||||
|
||||
cv::imwrite("vis_result.jpg", vis_img);
|
||||
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
|
||||
}
|
||||
|
||||
void TrtInfer(const std::string& det_model_dir,
|
||||
const std::string& cls_model_dir,
|
||||
const std::string& rec_model_dir,
|
||||
const std::string& rec_label_file,
|
||||
const std::string& image_file) {
|
||||
auto det_model_file = det_model_dir + sep + "inference.pdmodel";
|
||||
auto det_params_file = det_model_dir + sep + "inference.pdiparams";
|
||||
|
||||
auto cls_model_file = cls_model_dir + sep + "inference.pdmodel";
|
||||
auto cls_params_file = cls_model_dir + sep + "inference.pdiparams";
|
||||
|
||||
auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
|
||||
auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";
|
||||
auto rec_label = rec_label_file;
|
||||
|
||||
fastdeploy::vision::ocr::DBDetector det_model;
|
||||
fastdeploy::vision::ocr::Classifier cls_model;
|
||||
fastdeploy::vision::ocr::Recognizer rec_model;
|
||||
|
||||
//准备模型
|
||||
if (!det_model_dir.empty()) {
|
||||
auto det_option = fastdeploy::RuntimeOption();
|
||||
det_option.UseGpu();
|
||||
det_option.UseTrtBackend();
|
||||
det_option.SetTrtInputShape("x", {1, 3, 50, 50}, {1, 3, 640, 640},
|
||||
{1, 3, 960, 960});
|
||||
|
||||
det_model = fastdeploy::vision::ocr::DBDetector(
|
||||
det_model_file, det_params_file, det_option);
|
||||
|
||||
if (!det_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize det_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!cls_model_dir.empty()) {
|
||||
auto cls_option = fastdeploy::RuntimeOption();
|
||||
cls_option.UseGpu();
|
||||
cls_option.UseTrtBackend();
|
||||
cls_option.SetTrtInputShape("x", {1, 3, 48, 192});
|
||||
|
||||
cls_model = fastdeploy::vision::ocr::Classifier(
|
||||
cls_model_file, cls_params_file, cls_option);
|
||||
|
||||
if (!cls_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize cls_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!rec_model_dir.empty()) {
|
||||
auto rec_option = fastdeploy::RuntimeOption();
|
||||
rec_option.UseGpu();
|
||||
rec_option.UseTrtBackend();
|
||||
rec_option.SetTrtInputShape("x", {1, 3, 48, 10}, {1, 3, 48, 320},
|
||||
{1, 3, 48, 2000});
|
||||
|
||||
rec_model = fastdeploy::vision::ocr::Recognizer(
|
||||
rec_model_file, rec_params_file, rec_label, rec_option);
|
||||
|
||||
if (!rec_model.Initialized()) {
|
||||
std::cerr << "Failed to initialize rec_model." << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto ocrv3_app = fastdeploy::application::ocrsystem::PPOCRSystemv3(
|
||||
&det_model, &cls_model, &rec_model);
|
||||
|
||||
auto im = cv::imread(image_file);
|
||||
auto im_bak = im.clone();
|
||||
|
||||
fastdeploy::vision::OCRResult res;
|
||||
//开始预测
|
||||
if (!ocrv3_app.Predict(&im, &res)) {
|
||||
std::cerr << "Failed to predict." << std::endl;
|
||||
return;
|
||||
}
|
||||
//输出预测信息
|
||||
std::cout << res.Str() << std::endl;
|
||||
|
||||
//可视化
|
||||
auto vis_img = fastdeploy::vision::Visualize::VisOcr(im_bak, res);
|
||||
|
||||
cv::imwrite("vis_result.jpg", vis_img);
|
||||
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc < 7) {
|
||||
std::cout << "Usage: infer_demo path/to/det_model path/to/cls_model "
|
||||
"path/to/rec_model path/to/rec_label_file path/to/image "
|
||||
"run_option, "
|
||||
"e.g ./infer_demo ./ch_PP-OCRv3_det_infer "
|
||||
"./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer "
|
||||
"./ppocr_keys_v1.txt ./12.jpg 0"
|
||||
<< std::endl;
|
||||
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
|
||||
"with gpu; 2: run with gpu and use tensorrt backend."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (std::atoi(argv[6]) == 0) {
|
||||
CpuInfer(argv[1], argv[2], argv[3], argv[4], argv[5]);
|
||||
} else if (std::atoi(argv[6]) == 1) {
|
||||
GpuInfer(argv[1], argv[2], argv[3], argv[4], argv[5]);
|
||||
} else if (std::atoi(argv[6]) == 2) {
|
||||
TrtInfer(argv[1], argv[2], argv[3], argv[4], argv[5]);
|
||||
}
|
||||
return 0;
|
||||
}
|
131
examples/vision/ocr/PPOCRSystemv3/python/README.md
Normal file
131
examples/vision/ocr/PPOCRSystemv3/python/README.md
Normal file
@@ -0,0 +1,131 @@
|
||||
# PPOCRSystemv3 Python部署示例
|
||||
|
||||
在部署前,需确认以下两个步骤
|
||||
|
||||
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/the%20software%20and%20hardware%20requirements.md)
|
||||
- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/quick_start)
|
||||
|
||||
本目录下提供`infer.py`快速完成PPOCRSystemv3在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成
|
||||
|
||||
```
|
||||
|
||||
# 下载模型,图片和label文件
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
|
||||
tar xvf ch_PP-OCRv3_det_infer.tar
|
||||
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
|
||||
tar xvf ch_ppocr_mobile_v2.0_cls_infer.tar
|
||||
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar
|
||||
tar xvf ch_PP-OCRv3_rec_infer.tar
|
||||
|
||||
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/doc/imgs/12.jpg
|
||||
|
||||
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.5/ppocr/utils/ppocr_keys_v1.txt
|
||||
|
||||
|
||||
#下载部署示例代码
|
||||
git clone https://github.com/PaddlePaddle/FastDeploy.git
|
||||
cd examples/vison/ocr/PPOCRSystemv3/python/
|
||||
|
||||
# CPU推理
|
||||
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device cpu
|
||||
# GPU推理
|
||||
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu
|
||||
# GPU上使用TensorRT推理
|
||||
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device gpu --det_use_trt True --cls_use_trt True --rec_use_trt True
|
||||
# OCR还支持det/cls/rec三个模型的组合使用,例如当我们不想使用cls模型的时候,只需要给--cls_model传入一个空的字符串, 例子如下:
|
||||
python infer.py --det_model ch_PP-OCRv3_det_infer --cls_model "" --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image 12.jpg --device cpu
|
||||
```
|
||||
|
||||
运行完成可视化结果如下图所示
|
||||
<img width="640" src="https://user-images.githubusercontent.com/109218879/185826024-f7593a0c-1bd2-4a60-b76c-15588484fa08.jpg">
|
||||
|
||||
## PPOCRSystemv3 Python接口
|
||||
|
||||
```
|
||||
fastdeploy.vision.ocr.PPOCRSystemv3(ocr_det = det_model._model, ocr_cls = cls_model._model, ocr_rec = rec_model._model)
|
||||
```
|
||||
|
||||
PPOCRSystemv3的初始化,输入的参数是检测模型,分类模型和识别模型
|
||||
|
||||
**参数**
|
||||
|
||||
> * **ocr_det**(model): OCR中的检测模型
|
||||
> * **ocr_cls**(model): OCR中的分类模型
|
||||
> * **ocr_rec**(model): OCR中的识别模型
|
||||
|
||||
### predict函数
|
||||
|
||||
> ```
|
||||
> result = PPOCRSystemv3.predict(img_list)
|
||||
> ```
|
||||
>
|
||||
> 模型预测接口,输入的是一个可包含多个图像的list
|
||||
>
|
||||
> **参数**
|
||||
>
|
||||
> > * **img_list**(list[np.ndarray]): 输入数据的list,每张图片注意需为HWC,BGR格式
|
||||
> > * **result**(float): OCR结果,包括由检测模型输出的检测框位置,分类模型输出的方向分类,以及识别模型输出的识别结果,
|
||||
|
||||
> **返回**
|
||||
>
|
||||
> > 返回`fastdeploy.vision.OCRResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
|
||||
|
||||
|
||||
|
||||
## DBDetector Python接口
|
||||
|
||||
### DBDetector类
|
||||
|
||||
```
|
||||
fastdeploy.vision.ocr.DBDetector(model_file, params_file, runtime_option=None, model_format=Frontend.PADDLE)
|
||||
```
|
||||
|
||||
DBDetector模型加载和初始化,其中模型为paddle模型格式。
|
||||
|
||||
**参数**
|
||||
|
||||
> * **model_file**(str): 模型文件路径
|
||||
> * **params_file**(str): 参数文件路径,当模型格式为ONNX时,此参数传入空字符串即可
|
||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||
> * **model_format**(Frontend): 模型格式,默认为PADDLE格式
|
||||
|
||||
### Classifier类与DBDetector类相同
|
||||
|
||||
### Recognizer类
|
||||
```
|
||||
fastdeploy.vision.ocr.Recognizer(rec_model_file,rec_params_file,rec_label_file,
|
||||
runtime_option=rec_runtime_option,model_format=Frontend.PADDLE)
|
||||
```
|
||||
Recognizer类初始化时,需要在rec_label_file参数中,输入识别模型所需的label文件路径,其他参数均与DBDetector类相同
|
||||
|
||||
**参数**
|
||||
> * **label_path**(str): 识别模型的label文件路径
|
||||
|
||||
|
||||
|
||||
### 类成员变量
|
||||
|
||||
#### DBDetector预处理参数
|
||||
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
|
||||
|
||||
> > * **max_side_len**(int): 检测算法前向时图片长边的最大尺寸,当长边超出这个值时会将长边resize到这个大小,短边等比例缩放,默认为960
|
||||
> > * **det_db_thresh**(double): DB模型输出预测图的二值化阈值,默认为0.3
|
||||
> > * **det_db_box_thresh**(double): DB模型输出框的阈值,低于此值的预测框会被丢弃,默认为0.6
|
||||
> > * **det_db_unclip_ratio**(double): DB模型输出框扩大的比例,默认为1.5
|
||||
> > * **det_db_score_mode**(string):DB后处理中计算文本框平均得分的方式,默认为slow,即求polygon区域的平均分数的方式
|
||||
> > * **use_dilation**(bool):是否对检测输出的feature map做膨胀处理,默认为Fasle
|
||||
|
||||
#### Classifier预处理参数
|
||||
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
|
||||
|
||||
> > * **cls_thresh**(double): 当分类模型输出的得分超过此阈值,输入的图片将被翻转,默认为0.9
|
||||
|
||||
|
||||
|
||||
## 其它文档
|
||||
|
||||
- [YOLOv5 模型介绍](..)
|
||||
- [YOLOv5 C++部署](../cpp)
|
||||
- [模型预测结果说明](../../../../../docs/api/vision_results/)
|
145
examples/vision/ocr/PPOCRSystemv3/python/infer.py
Normal file
145
examples/vision/ocr/PPOCRSystemv3/python/infer.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import fastdeploy as fd
|
||||
import cv2
|
||||
import os
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
import argparse
|
||||
import ast
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--det_model", required=True, help="Path of Detection model of PPOCR.")
|
||||
parser.add_argument(
|
||||
"--cls_model",
|
||||
required=True,
|
||||
help="Path of Classification model of PPOCR.")
|
||||
parser.add_argument(
|
||||
"--rec_model",
|
||||
required=True,
|
||||
help="Path of Recognization model of PPOCR.")
|
||||
parser.add_argument(
|
||||
"--rec_label_file",
|
||||
required=True,
|
||||
help="Path of Recognization model of PPOCR.")
|
||||
|
||||
parser.add_argument(
|
||||
"--image", type=str, required=True, help="Path of test image file.")
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default='cpu',
|
||||
help="Type of inference device, support 'cpu' or 'gpu'.")
|
||||
parser.add_argument(
|
||||
"--det_use_trt",
|
||||
type=ast.literal_eval,
|
||||
default=False,
|
||||
help="Wether to use tensorrt.")
|
||||
parser.add_argument(
|
||||
"--cls_use_trt",
|
||||
type=ast.literal_eval,
|
||||
default=False,
|
||||
help="Wether to use tensorrt.")
|
||||
parser.add_argument(
|
||||
"--rec_use_trt",
|
||||
type=ast.literal_eval,
|
||||
default=False,
|
||||
help="Wether to use tensorrt.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_det_option(args):
|
||||
option = fd.RuntimeOption()
|
||||
|
||||
if args.device.lower() == "gpu":
|
||||
option.use_gpu()
|
||||
|
||||
if args.det_use_trt:
|
||||
option.use_trt_backend()
|
||||
#det_max_side_len 默认为960,当用户更改DET模型的max_side_len参数时,请将此参数同时更改
|
||||
det_max_side_len = 960
|
||||
option.set_trt_input_shape("x", [1, 3, 50, 50], [1, 3, 640, 640],
|
||||
[1, 3, det_max_side_len, det_max_side_len])
|
||||
|
||||
return option
|
||||
|
||||
|
||||
def build_cls_option(args):
|
||||
option = fd.RuntimeOption()
|
||||
option.use_paddle_backend()
|
||||
|
||||
if args.device.lower() == "gpu":
|
||||
option.use_gpu()
|
||||
|
||||
if args.cls_use_trt:
|
||||
option.use_trt_backend()
|
||||
option.set_trt_input_shape("x", [1, 3, 32, 100])
|
||||
|
||||
return option
|
||||
|
||||
|
||||
def build_rec_option(args):
|
||||
option = fd.RuntimeOption()
|
||||
|
||||
if args.device.lower() == "gpu":
|
||||
option.use_gpu()
|
||||
|
||||
if args.rec_use_trt:
|
||||
option.use_trt_backend()
|
||||
option.set_trt_input_shape("x", [1, 3, 48, 10], [1, 3, 48, 320],
|
||||
[1, 3, 48, 2000])
|
||||
return option
|
||||
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
#Det模型
|
||||
det_model_file = os.path.join(args.det_model, "inference.pdmodel")
|
||||
det_params_file = os.path.join(args.det_model, "inference.pdiparams")
|
||||
#Cls模型
|
||||
cls_model_file = os.path.join(args.cls_model, "inference.pdmodel")
|
||||
cls_params_file = os.path.join(args.cls_model, "inference.pdiparams")
|
||||
#Rec模型
|
||||
rec_model_file = os.path.join(args.rec_model, "inference.pdmodel")
|
||||
rec_params_file = os.path.join(args.rec_model, "inference.pdiparams")
|
||||
rec_label_file = args.rec_label_file
|
||||
|
||||
#默认
|
||||
det_model = fd.vision.ocr.DBDetector("")
|
||||
cls_model = fd.vision.ocr.Classifier()
|
||||
rec_model = fd.vision.ocr.Recognizer()
|
||||
|
||||
#模型初始化
|
||||
if (len(args.det_model) != 0):
|
||||
det_runtime_option = build_det_option(args)
|
||||
det_model = fd.vision.ocr.DBDetector(
|
||||
det_model_file, det_params_file, runtime_option=det_runtime_option)
|
||||
|
||||
if (len(args.cls_model) != 0):
|
||||
cls_runtime_option = build_cls_option(args)
|
||||
cls_model = fd.vision.ocr.Classifier(
|
||||
cls_model_file, cls_params_file, runtime_option=cls_runtime_option)
|
||||
|
||||
if (len(args.rec_model) != 0):
|
||||
rec_runtime_option = build_rec_option(args)
|
||||
rec_model = fd.vision.ocr.Recognizer(
|
||||
rec_model_file,
|
||||
rec_params_file,
|
||||
rec_label_file,
|
||||
runtime_option=rec_runtime_option)
|
||||
|
||||
ppocrsysv3 = fd.vision.ocr.PPOCRSystemv3(
|
||||
ocr_det=det_model._model,
|
||||
ocr_cls=cls_model._model,
|
||||
ocr_rec=rec_model._model)
|
||||
|
||||
# 预测图片准备
|
||||
im = cv2.imread(args.image)
|
||||
|
||||
#预测并打印结果
|
||||
result = ppocrsysv3.predict(im)
|
||||
print(result)
|
||||
|
||||
# 可视化结果
|
||||
vis_im = fd.vision.vis_ppocr(im, result)
|
||||
cv2.imwrite("visualized_result.jpg", vis_im)
|
||||
print("Visualized result save in ./visualized_result.jpg")
|
17
examples/vision/ocr/README.md
Normal file
17
examples/vision/ocr/README.md
Normal file
@@ -0,0 +1,17 @@
|
||||
# PaddleOCR 模型部署
|
||||
|
||||
## 模型版本说明
|
||||
|
||||
- [PaddleOCR Release/2.5](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.5)
|
||||
|
||||
目前FastDeploy支持如下模型的部署
|
||||
|
||||
- [PaddleOCRv3系列模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/models_list.md)
|
||||
|
||||
## 准备PaddleOCRv3部署模型
|
||||
用户在[PP-OCR系列模型列表](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/models_list.md)下载相应的的OCRv3系列推理模型即可.
|
||||
|
||||
## 详细部署文档
|
||||
|
||||
- [Python部署](python)
|
||||
- [C++部署](cpp)
|
@@ -20,6 +20,6 @@ from . import segmentation
|
||||
from . import matting
|
||||
from . import facedet
|
||||
from . import faceid
|
||||
|
||||
from . import ocr
|
||||
from . import evaluation
|
||||
from .visualize import *
|
||||
|
20
fastdeploy/vision/ocr/__init__.py
Normal file
20
fastdeploy/vision/ocr/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
|
||||
from .ppocr import PPOCRSystemv3
|
||||
from .ppocr import PPOCRSystemv2
|
||||
from .ppocr import DBDetector
|
||||
from .ppocr import Classifier
|
||||
from .ppocr import Recognizer
|
234
fastdeploy/vision/ocr/ppocr/__init__.py
Normal file
234
fastdeploy/vision/ocr/ppocr/__init__.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
# #
|
||||
# # Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# # you may not use this file except in compliance with the License.
|
||||
# # You may obtain a copy of the License at
|
||||
# #
|
||||
# # http://www.apache.org/licenses/LICENSE-2.0
|
||||
# #
|
||||
# # Unless required by applicable law or agreed to in writing, software
|
||||
# # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# # See the License for the specific language governing permissions and
|
||||
# # limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
import logging
|
||||
from .... import FastDeployModel, Frontend
|
||||
from .... import c_lib_wrap as C
|
||||
|
||||
|
||||
class DBDetector(FastDeployModel):
|
||||
def __init__(self,
|
||||
model_file="",
|
||||
params_file="",
|
||||
runtime_option=None,
|
||||
model_format=Frontend.PADDLE):
|
||||
# 调用基函数进行backend_option的初始化
|
||||
# 初始化后的option保存在self._runtime_option
|
||||
super(DBDetector, self).__init__(runtime_option)
|
||||
|
||||
if (len(model_file) == 0):
|
||||
self._model = C.vision.ocr.DBDetector()
|
||||
else:
|
||||
self._model = C.vision.ocr.DBDetector(
|
||||
model_file, params_file, self._runtime_option, model_format)
|
||||
# 通过self.initialized判断整个模型的初始化是否成功
|
||||
assert self.initialized, "DBDetector initialize failed."
|
||||
|
||||
# 一些跟DBDetector模型有关的属性封装
|
||||
@property
|
||||
def max_side_len(self):
|
||||
return self._model.max_side_len
|
||||
|
||||
@property
|
||||
def det_db_thresh(self):
|
||||
return self._model.det_db_thresh
|
||||
|
||||
@property
|
||||
def det_db_box_thresh(self):
|
||||
return self._model.det_db_box_thresh
|
||||
|
||||
@property
|
||||
def det_db_unclip_ratio(self):
|
||||
return self._model.det_db_unclip_ratio
|
||||
|
||||
@property
|
||||
def det_db_score_mode(self):
|
||||
return self._model.det_db_score_mode
|
||||
|
||||
@property
|
||||
def use_dilation(self):
|
||||
return self._model.use_dilation
|
||||
|
||||
@property
|
||||
def is_scale(self):
|
||||
return self._model.max_wh
|
||||
|
||||
@max_side_len.setter
|
||||
def max_side_len(self, value):
|
||||
assert isinstance(
|
||||
value, int), "The value to set `max_side_len` must be type of int."
|
||||
self._model.max_side_len = value
|
||||
|
||||
@det_db_thresh.setter
|
||||
def det_db_thresh(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
float), "The value to set `det_db_thresh` must be type of float."
|
||||
self._model.det_db_thresh = value
|
||||
|
||||
@det_db_box_thresh.setter
|
||||
def det_db_box_thresh(self, value):
|
||||
assert isinstance(
|
||||
value, float
|
||||
), "The value to set `det_db_box_thresh` must be type of float."
|
||||
self._model.det_db_box_thresh = value
|
||||
|
||||
@det_db_unclip_ratio.setter
|
||||
def det_db_unclip_ratio(self, value):
|
||||
assert isinstance(
|
||||
value, float
|
||||
), "The value to set `det_db_unclip_ratio` must be type of float."
|
||||
self._model.det_db_unclip_ratio = value
|
||||
|
||||
@det_db_score_mode.setter
|
||||
def det_db_score_mode(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
str), "The value to set `det_db_score_mode` must be type of str."
|
||||
self._model.det_db_score_mode = value
|
||||
|
||||
@use_dilation.setter
|
||||
def use_dilation(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
bool), "The value to set `use_dilation` must be type of bool."
|
||||
self._model.use_dilation = value
|
||||
|
||||
@is_scale.setter
|
||||
def is_scale(self, value):
|
||||
assert isinstance(
|
||||
value, bool), "The value to set `is_scale` must be type of bool."
|
||||
self._model.is_scale = value
|
||||
|
||||
|
||||
class Classifier(FastDeployModel):
|
||||
def __init__(self,
|
||||
model_file="",
|
||||
params_file="",
|
||||
runtime_option=None,
|
||||
model_format=Frontend.PADDLE):
|
||||
# 调用基函数进行backend_option的初始化
|
||||
# 初始化后的option保存在self._runtime_option
|
||||
super(Classifier, self).__init__(runtime_option)
|
||||
|
||||
if (len(model_file) == 0):
|
||||
self._model = C.vision.ocr.Classifier()
|
||||
else:
|
||||
self._model = C.vision.ocr.Classifier(
|
||||
model_file, params_file, self._runtime_option, model_format)
|
||||
# 通过self.initialized判断整个模型的初始化是否成功
|
||||
assert self.initialized, "Classifier initialize failed."
|
||||
|
||||
@property
|
||||
def cls_thresh(self):
|
||||
return self._model.cls_thresh
|
||||
|
||||
@property
|
||||
def cls_image_shape(self):
|
||||
return self._model.cls_image_shape
|
||||
|
||||
@property
|
||||
def cls_batch_num(self):
|
||||
return self._model.cls_batch_num
|
||||
|
||||
@cls_thresh.setter
|
||||
def cls_thresh(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
float), "The value to set `cls_thresh` must be type of float."
|
||||
self._model.cls_thresh = value
|
||||
|
||||
@cls_image_shape.setter
|
||||
def cls_image_shape(self, value):
|
||||
assert isinstance(
|
||||
value, list), "The value to set `cls_thresh` must be type of list."
|
||||
self._model.cls_image_shape = value
|
||||
|
||||
@cls_batch_num.setter
|
||||
def cls_batch_num(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
int), "The value to set `cls_batch_num` must be type of int."
|
||||
self._model.cls_batch_num = value
|
||||
|
||||
|
||||
class Recognizer(FastDeployModel):
|
||||
def __init__(self,
|
||||
model_file="",
|
||||
params_file="",
|
||||
label_path="",
|
||||
runtime_option=None,
|
||||
model_format=Frontend.PADDLE):
|
||||
# 调用基函数进行backend_option的初始化
|
||||
# 初始化后的option保存在self._runtime_option
|
||||
super(Recognizer, self).__init__(runtime_option)
|
||||
|
||||
if (len(model_file) == 0):
|
||||
self._model = C.vision.ocr.Recognizer()
|
||||
else:
|
||||
self._model = C.vision.ocr.Recognizer(
|
||||
model_file, params_file, label_path, self._runtime_option,
|
||||
model_format)
|
||||
# 通过self.initialized判断整个模型的初始化是否成功
|
||||
assert self.initialized, "Recognizer initialize failed."
|
||||
|
||||
@property
|
||||
def rec_img_h(self):
|
||||
return self._model.rec_img_h
|
||||
|
||||
@property
|
||||
def rec_img_w(self):
|
||||
return self._model.rec_img_w
|
||||
|
||||
@property
|
||||
def rec_batch_num(self):
|
||||
return self._model.rec_batch_num
|
||||
|
||||
@rec_img_h.setter
|
||||
def rec_img_h(self, value):
|
||||
assert isinstance(
|
||||
value, int), "The value to set `rec_img_h` must be type of int."
|
||||
self._model.rec_img_h = value
|
||||
|
||||
@rec_img_w.setter
|
||||
def rec_img_w(self, value):
|
||||
assert isinstance(
|
||||
value, int), "The value to set `rec_img_w` must be type of int."
|
||||
self._model.rec_img_w = value
|
||||
|
||||
@rec_batch_num.setter
|
||||
def rec_batch_num(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
int), "The value to set `rec_batch_num` must be type of int."
|
||||
self._model.rec_batch_num = value
|
||||
|
||||
|
||||
class PPOCRSystemv3(FastDeployModel):
|
||||
def __init__(self, ocr_det=None, ocr_cls=None, ocr_rec=None):
|
||||
|
||||
self._model = C.vision.ocr.PPOCRSystemv3(ocr_det, ocr_cls, ocr_rec)
|
||||
|
||||
def predict(self, input_image):
|
||||
return self._model.predict(input_image)
|
||||
|
||||
|
||||
class PPOCRSystemv2(FastDeployModel):
|
||||
def __init__(self, ocr_det=None, ocr_cls=None, ocr_rec=None):
|
||||
|
||||
self._model = C.vision.ocr.PPOCRSystemv2(ocr_det, ocr_cls, ocr_rec)
|
||||
|
||||
def predict(self, input_image):
|
||||
return self._model.predict(input_image)
|
@@ -40,3 +40,7 @@ def vis_matting_alpha(im_data,
|
||||
remove_small_connected_area=False):
|
||||
return C.vision.Visualize.vis_matting_alpha(im_data, matting_result,
|
||||
remove_small_connected_area)
|
||||
|
||||
|
||||
def vis_ppocr(im_data, det_result):
|
||||
return C.vision.Visualize.vis_ppocr(im_data, det_result)
|
||||
|
Reference in New Issue
Block a user