[Model] change ocr pre and post (#568)

* change ocr pre and post

* add pybind

* change ocr

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix copy bug

* fix code style

* fix bug

* add new function

* fix windows ci bug
This commit is contained in:
Thomas Young
2022-11-18 13:17:42 +08:00
committed by GitHub
parent 1609ce1bab
commit 143506b654
31 changed files with 1402 additions and 569 deletions

View File

@@ -22,6 +22,7 @@
#include <sstream>
#include <string>
#include <vector>
#include <numeric>
#if defined(_WIN32)
#ifdef FASTDEPLOY_LIB

View File

@@ -48,6 +48,7 @@
#include "fastdeploy/vision/matting/ppmatting/ppmatting.h"
#include "fastdeploy/vision/ocr/ppocr/classifier.h"
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
#include "fastdeploy/vision/ocr/ppocr/ppocr_v2.h"
#include "fastdeploy/vision/ocr/ppocr/ppocr_v3.h"
#include "fastdeploy/vision/ocr/ppocr/recognizer.h"

93
fastdeploy/vision/ocr/ppocr/classifier.cc Normal file → Executable file
View File

@@ -41,16 +41,7 @@ Classifier::Classifier(const std::string& model_file,
initialized = Initialize();
}
// Init
bool Classifier::Initialize() {
// pre&post process parameters
cls_thresh = 0.9;
cls_image_shape = {3, 48, 192};
cls_batch_num = 1;
mean = {0.5f, 0.5f, 0.5f};
scale = {0.5f, 0.5f, 0.5f};
is_scale = true;
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
@@ -59,85 +50,23 @@ bool Classifier::Initialize() {
return true;
}
void OcrClassifierResizeImage(Mat* mat,
const std::vector<int>& rec_image_shape) {
int imgC = rec_image_shape[0];
int imgH = rec_image_shape[1];
int imgW = rec_image_shape[2];
float ratio = float(mat->Width()) / float(mat->Height());
int resize_w;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
Resize::Run(mat, resize_w, imgH);
std::vector<float> value = {0, 0, 0};
if (resize_w < imgW) {
Pad::Run(mat, 0, 0, 0, imgW - resize_w, value);
bool Classifier::BatchPredict(const std::vector<cv::Mat>& images,
std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores) {
std::vector<FDMat> fd_images = WrapMat(images);
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
FDERROR << "Failed to preprocess the input image." << std::endl;
return false;
}
}
bool Classifier::Preprocess(Mat* mat, FDTensor* output) {
// 1. cls resizes
// 2. normalize
// 3. batch_permute
OcrClassifierResizeImage(mat, cls_image_shape);
Normalize::Run(mat, mean, scale, true);
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1);
return true;
}
bool Classifier::Postprocess(FDTensor& infer_result,
std::tuple<int, float>* cls_result) {
std::vector<int64_t> output_shape = infer_result.shape;
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
float* out_data = static_cast<float*>(infer_result.Data());
int label = std::distance(
&out_data[0], std::max_element(&out_data[0], &out_data[output_shape[1]]));
float score =
float(*std::max_element(&out_data[0], &out_data[output_shape[1]]));
std::get<0>(*cls_result) = label;
std::get<1>(*cls_result) = score;
return true;
}
bool Classifier::Predict(cv::Mat* img, std::tuple<int, float>* cls_result) {
Mat mat(*img);
std::vector<FDTensor> input_tensors(1);
if (!Preprocess(&mat, &input_tensors[0])) {
FDERROR << "Failed to preprocess input image." << std::endl;
reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
FDERROR << "Failed to inference by runtime." << std::endl;
return false;
}
input_tensors[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
if (!postprocessor_.Run(reused_output_tensors_, cls_labels, cls_scores)) {
FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
return false;
}
if (!Postprocess(output_tensors[0], cls_result)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}
return true;
}

27
fastdeploy/vision/ocr/ppocr/classifier.h Normal file → Executable file
View File

@@ -17,6 +17,8 @@
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
#include "fastdeploy/vision/ocr/ppocr/cls_postprocessor.h"
#include "fastdeploy/vision/ocr/ppocr/cls_preprocessor.h"
namespace fastdeploy {
namespace vision {
@@ -41,29 +43,22 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel {
const ModelFormat& model_format = ModelFormat::PADDLE);
/// Get model's name
std::string ModelName() const { return "ppocr/ocr_cls"; }
/** \brief Predict the input image and get OCR classification model result.
/** \brief BatchPredict the input image and get OCR classification model cls_result.
*
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
* \param[in] result The output of OCR classification model result will be writen to this structure.
* \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
* \param[in] cls_results The output of OCR classification model cls_result will be writen to this structure.
* \return true if the prediction is successed, otherwise false.
*/
virtual bool Predict(cv::Mat* img, std::tuple<int, float>* result);
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
std::vector<int32_t>* cls_labels,
std::vector<float>* cls_scores);
// Pre & Post parameters
float cls_thresh;
std::vector<int> cls_image_shape;
int cls_batch_num;
std::vector<float> mean;
std::vector<float> scale;
bool is_scale;
ClassifierPreprocessor preprocessor_;
ClassifierPostprocessor postprocessor_;
private:
bool Initialize();
/// Preprocess the input data, and set the preprocessed results to `outputs`
bool Preprocess(Mat* img, FDTensor* output);
/// Postprocess the inferenced results, and set the final result to `result`
bool Postprocess(FDTensor& infer_result, std::tuple<int, float>* result);
};
} // namespace ocr

View File

@@ -0,0 +1,65 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/cls_postprocessor.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
// The postprocessor needs no setup beyond its default member values, so it is
// flagged as ready for use directly at construction.
ClassifierPostprocessor::ClassifierPostprocessor() : initialized_(true) {}
// Picks the best-scoring class for a single sample.
//
// out_data  : pointer to `length` contiguous scores belonging to one sample.
// length    : number of scores available at `out_data`.
// cls_label : receives the index of the largest score (the first one on ties).
// cls_score : receives the value of the largest score.
//
// Returns false, leaving the outputs untouched, when any pointer is null or
// when `length` is zero (there is no element to select in that case).
bool SingleBatchPostprocessor(const float* out_data, const size_t& length,
                              int* cls_label, float* cls_score) {
  if (out_data == nullptr || cls_label == nullptr || cls_score == nullptr ||
      length == 0) {
    return false;
  }
  // A single scan yields both the position and the value of the maximum.
  const float* const best = std::max_element(out_data, out_data + length);
  *cls_label = static_cast<int>(std::distance(out_data, best));
  *cls_score = *best;
  return true;
}
bool ClassifierPostprocessor::Run(const std::vector<FDTensor>& tensors,
std::vector<int32_t>* cls_labels,
std::vector<float>* cls_scores) {
if (!initialized_) {
FDERROR << "Postprocessor is not initialized." << std::endl;
return false;
}
// Classifier have only 1 output tensor.
const FDTensor& tensor = tensors[0];
// For Classifier, the output tensor shape = [batch,2]
size_t batch = tensor.shape[0];
size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies<int>());
cls_labels->resize(batch);
cls_scores->resize(batch);
const float* tensor_data = reinterpret_cast<const float*>(tensor.Data());
for (int i_batch = 0; i_batch < batch; ++i_batch) {
if(!SingleBatchPostprocessor(tensor_data, length, &cls_labels->at(i_batch),&cls_scores->at(i_batch))) return false;
tensor_data = tensor_data + length;
}
return true;
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,51 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief Postprocessor object for the Classifier series of models.
*
* Turns the raw output tensor of the OCR classification model into one label
* and one score per image of the batch.
*/
class FASTDEPLOY_DECL ClassifierPostprocessor {
public:
/** \brief Create a postprocessor instance for the Classifier series of models.
*
* The instance is marked as initialized by the constructor.
*/
ClassifierPostprocessor();
/** \brief Process the result of the runtime and fill the label/score lists.
*
* Only the first tensor is read; its first dimension is used as the batch size.
*
* \param[in] tensors The inference result from runtime
* \param[out] cls_labels Resized to the batch size; receives the index of the highest score of every image
* \param[out] cls_scores Resized to the batch size; receives the highest score of every image
* \return true if the postprocess succeeded, otherwise false
*/
bool Run(const std::vector<FDTensor>& tensors,
std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores);
// Score threshold. Run() does not read it; presumably the calling pipeline
// uses it to decide whether a predicted label is trusted (TODO confirm).
float cls_thresh_ = 0.9;
private:
// Checked at the start of Run(); set to true by the constructor.
bool initialized_ = false;
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,88 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/cls_preprocessor.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
#include "fastdeploy/function/concat.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
// All preprocessing parameters come from their in-class defaults, so the
// instance is flagged as ready for use directly at construction.
ClassifierPreprocessor::ClassifierPreprocessor() : initialized_(true) {}
// Resizes `mat` to the classifier input height while keeping its aspect
// ratio. The resulting width is capped at the target width, and a narrower
// result is padded with zeros up to the target width.
//
// cls_image_shape is indexed as {channels, height, width}; only height and
// width are needed here.
void OcrClassifierResizeImage(FDMat* mat,
                              const std::vector<int>& cls_image_shape) {
  const int target_h = cls_image_shape[1];
  const int target_w = cls_image_shape[2];

  const float aspect = float(mat->Width()) / float(mat->Height());
  // Width the image would get at the target height, rounded up.
  const float scaled_w = ceilf(target_h * aspect);
  const int resize_w = (scaled_w > target_w) ? target_w : int(scaled_w);

  Resize::Run(mat, resize_w, target_h);

  if (resize_w < target_w) {
    // Only the fourth pad amount is non-zero (presumably the right edge;
    // confirm against Pad::Run's parameter order).
    std::vector<float> pad_value = {0, 0, 0};
    Pad::Run(mat, 0, 0, 0, target_w - resize_w, pad_value);
  }
}
// Prepares a batch of images for the classifier runtime.
//
// Every image is resized/padded to cls_image_shape_, then normalized and
// permuted with mean_, scale_ and is_scale_. The per-image results are stacked
// along a new leading dimension into a single tensor.
//
// images  : images to preprocess; they are modified in place.
// outputs : resized to exactly one tensor that receives the batch.
//
// Returns false when the preprocessor is not initialized, when a pointer is
// null, when `images` is empty, when cls_image_shape_ has fewer than three
// values, or when an input image has a zero height or width.
bool ClassifierPreprocessor::Run(std::vector<FDMat>* images,
                                 std::vector<FDTensor>* outputs) {
  if (!initialized_) {
    FDERROR << "The preprocessor is not initialized." << std::endl;
    return false;
  }
  if (images == nullptr || outputs == nullptr) {
    FDERROR << "The input images and the output tensors should not be null."
            << std::endl;
    return false;
  }
  if (images->size() == 0) {
    FDERROR << "The size of input images should be greater than 0." << std::endl;
    return false;
  }
  // cls_image_shape_ is a public, user-settable member that the resize helper
  // indexes at [1] and [2]; reject shapes that are too short instead of
  // reading out of bounds.
  if (cls_image_shape_.size() < 3) {
    FDERROR << "cls_image_shape_ should hold 3 values (channel, height, "
               "width), but it holds "
            << cls_image_shape_.size() << "." << std::endl;
    return false;
  }

  for (size_t i = 0; i < images->size(); ++i) {
    FDMat* mat = &(images->at(i));
    // The resize helper divides width by height, so empty images must be
    // rejected up front.
    if (mat->Height() <= 0 || mat->Width() <= 0) {
      FDERROR << "The input image at index " << i << " is empty." << std::endl;
      return false;
    }
    OcrClassifierResizeImage(mat, cls_image_shape_);
    NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
  }
  // Only have 1 output Tensor.
  outputs->resize(1);
  // Concat all the preprocessed data to a batch tensor
  std::vector<FDTensor> tensors(images->size());
  for (size_t i = 0; i < images->size(); ++i) {
    (*images)[i].ShareWithTensor(&(tensors[i]));
    tensors[i].ExpandDim(0);
  }
  if (tensors.size() == 1) {
    // A single image already is the batch; no concatenation is needed.
    (*outputs)[0] = std::move(tensors[0]);
  } else {
    function::Concat(tensors, &((*outputs)[0]), 0);
  }
  return true;
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,51 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief Preprocessor object for the Classifier series of models.
*
* Resizes, pads, normalizes and permutes a list of images and stacks them into
* one batch tensor for the runtime.
*/
class FASTDEPLOY_DECL ClassifierPreprocessor {
public:
/** \brief Create a preprocessor instance for the Classifier series of models.
*
* The instance is marked as initialized by the constructor.
*/
ClassifierPreprocessor();
/** \brief Process the input images and prepare the input tensor for the runtime.
*
* \param[in,out] images The input image data list, all the elements are returned by cv::imread(); the images are modified in place
* \param[out] outputs The output tensors which will be fed to the runtime; resized to a single batch tensor
* \return true if the preprocess succeeded, otherwise false
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
// Mean, scale and scaling switch forwarded to NormalizeAndPermute for every image.
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
bool is_scale_ = true;
// Target input shape, indexed as {channels, height, width} by the resize step.
std::vector<int> cls_image_shape_ = {3, 48, 192};
private:
// Checked at the start of Run(); set to true by the constructor.
bool initialized_ = false;
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

156
fastdeploy/vision/ocr/ppocr/dbdetector.cc Normal file → Executable file
View File

@@ -43,158 +43,50 @@ DBDetector::DBDetector(const std::string& model_file,
// Init
bool DBDetector::Initialize() {
// pre&post process parameters
max_side_len = 960;
det_db_thresh = 0.3;
det_db_box_thresh = 0.6;
det_db_unclip_ratio = 1.5;
det_db_score_mode = "slow";
use_dilation = false;
mean = {0.485f, 0.456f, 0.406f};
scale = {0.229f, 0.224f, 0.225f};
is_scale = true;
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
void OcrDetectorResizeImage(Mat* img, int max_size_len, float* ratio_h,
float* ratio_w) {
int w = img->Width();
int h = img->Height();
float ratio = 1.f;
int max_wh = w >= h ? w : h;
if (max_wh > max_size_len) {
if (h > w) {
ratio = float(max_size_len) / float(h);
} else {
ratio = float(max_size_len) / float(w);
}
}
int resize_h = int(float(h) * ratio);
int resize_w = int(float(w) * ratio);
resize_h = std::max(int(std::round(float(resize_h) / 32) * 32), 32);
resize_w = std::max(int(std::round(float(resize_w) / 32) * 32), 32);
Resize::Run(img, resize_w, resize_h);
*ratio_h = float(resize_h) / float(h);
*ratio_w = float(resize_w) / float(w);
}
bool DBDetector::Preprocess(
Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info) {
// Resize
OcrDetectorResizeImage(mat, max_side_len, &ratio_h, &ratio_w);
// Normalize
Normalize::Run(mat, mean, scale, true);
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};
//-CHW
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1);
return true;
}
bool DBDetector::Postprocess(
FDTensor& infer_result, std::vector<std::array<int, 8>>* boxes_result,
const std::map<std::string, std::array<float, 2>>& im_info) {
std::vector<int64_t> output_shape = infer_result.shape;
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
int n2 = output_shape[2];
int n3 = output_shape[3];
int n = n2 * n3;
float* out_data = static_cast<float*>(infer_result.Data());
// prepare bitmap
std::vector<float> pred(n, 0.0);
std::vector<unsigned char> cbuf(n, ' ');
for (int i = 0; i < n; i++) {
pred[i] = float(out_data[i]);
cbuf[i] = (unsigned char)((out_data[i]) * 255);
}
cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char*)cbuf.data());
cv::Mat pred_map(n2, n3, CV_32F, (float*)pred.data());
const double threshold = det_db_thresh * 255;
const double maxvalue = 255;
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
if (use_dilation) {
cv::Mat dila_ele =
cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
cv::dilate(bit_map, bit_map, dila_ele);
}
std::vector<std::vector<std::vector<int>>> boxes;
boxes =
post_processor_.BoxesFromBitmap(pred_map, bit_map, det_db_box_thresh,
det_db_unclip_ratio, det_db_score_mode);
boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, im_info);
// boxes to boxes_result
for (int i = 0; i < boxes.size(); i++) {
std::array<int, 8> new_box;
int k = 0;
for (auto& vec : boxes[i]) {
for (auto& e : vec) {
new_box[k++] = e;
}
}
boxes_result->push_back(new_box);
}
return true;
}
bool DBDetector::Predict(cv::Mat* img,
std::vector<std::array<int, 8>>* boxes_result) {
Mat mat(*img);
if (!Predict(*img, boxes_result)) {
return false;
}
return true;
}
std::vector<FDTensor> input_tensors(1);
// Single-image convenience wrapper: runs BatchPredict on a one-element batch
// and hands the boxes of that only image to the caller.
bool DBDetector::Predict(const cv::Mat& img,
                         std::vector<std::array<int, 8>>* boxes_result) {
  std::vector<std::vector<std::array<int, 8>>> batch_boxes;
  const bool succeeded = BatchPredict({img}, &batch_boxes);
  if (succeeded) {
    // The batch holds exactly one image, so its boxes are at index 0.
    *boxes_result = std::move(batch_boxes[0]);
  }
  return succeeded;
}
std::map<std::string, std::array<float, 2>> im_info;
// Record the shape of image and the shape of preprocessed image
im_info["input_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};
im_info["output_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
bool DBDetector::BatchPredict(const std::vector<cv::Mat>& images,
std::vector<std::vector<std::array<int, 8>>>* det_results) {
std::vector<FDMat> fd_images = WrapMat(images);
std::vector<std::array<int, 4>> batch_det_img_info;
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &batch_det_img_info)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
input_tensors[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
FDERROR << "Failed to inference by runtime." << std::endl;
return false;
}
if (!Postprocess(output_tensors[0], boxes_result, im_info)) {
FDERROR << "Failed to post process." << std::endl;
if (!postprocessor_.Run(reused_output_tensors_, det_results, batch_det_img_info)) {
FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
return false;
}
return true;
}

48
fastdeploy/vision/ocr/ppocr/dbdetector.h Normal file → Executable file
View File

@@ -17,6 +17,8 @@
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
#include "fastdeploy/vision/ocr/ppocr/det_postprocessor.h"
#include "fastdeploy/vision/ocr/ppocr/det_preprocessor.h"
namespace fastdeploy {
namespace vision {
@@ -44,40 +46,34 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel {
std::string ModelName() const { return "ppocr/ocr_det"; }
/** \brief Predict the input image and get OCR detection model result.
*
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
* \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
* \param[in] boxes_result The output of OCR detection model result will be writen to this structure.
* \return true if the prediction is successed, otherwise false.
*/
virtual bool Predict(cv::Mat* im,
virtual bool Predict(cv::Mat* img,
std::vector<std::array<int, 8>>* boxes_result);
/** \brief Predict the input image and get OCR detection model result.
*
* \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
* \param[in] boxes_result The output of OCR detection model result will be writen to this structure.
* \return true if the prediction is successed, otherwise false.
*/
virtual bool Predict(const cv::Mat& img,
std::vector<std::array<int, 8>>* boxes_result);
/** \brief BatchPredict the input image and get OCR detection model result.
*
* \param[in] images The list input of image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
* \param[in] det_results The output of OCR detection model result will be writen to this structure.
* \return true if the prediction is successed, otherwise false.
*/
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
std::vector<std::vector<std::array<int, 8>>>* det_results);
// Pre & Post process parameters
int max_side_len;
float ratio_h{};
float ratio_w{};
double det_db_thresh;
double det_db_box_thresh;
double det_db_unclip_ratio;
std::string det_db_score_mode;
bool use_dilation;
std::vector<float> mean;
std::vector<float> scale;
bool is_scale;
DBDetectorPreprocessor preprocessor_;
DBDetectorPostprocessor postprocessor_;
private:
bool Initialize();
/// Preprocess the input data, and set the preprocessed results to `outputs`
bool Preprocess(Mat* mat, FDTensor* outputs,
std::map<std::string, std::array<float, 2>>* im_info);
/*! @brief Postprocess the inferenced results, and set the final result to `boxes_result`
*/
bool Postprocess(FDTensor& infer_result,
std::vector<std::array<int, 8>>* boxes_result,
const std::map<std::string, std::array<float, 2>>& im_info);
PostProcessor post_processor_;
};
} // namespace ocr

View File

@@ -0,0 +1,110 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/det_postprocessor.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
// All thresholds come from their in-class defaults, so the instance is
// flagged as ready for use directly at construction.
DBDetectorPostprocessor::DBDetectorPostprocessor() : initialized_(true) {}
bool DBDetectorPostprocessor::SingleBatchPostprocessor(
const float* out_data,
int n2,
int n3,
const std::array<int,4>& det_img_info,
std::vector<std::array<int, 8>>* boxes_result
) {
int n = n2 * n3;
// prepare bitmap
std::vector<float> pred(n, 0.0);
std::vector<unsigned char> cbuf(n, ' ');
for (int i = 0; i < n; i++) {
pred[i] = float(out_data[i]);
cbuf[i] = (unsigned char)((out_data[i]) * 255);
}
cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char*)cbuf.data());
cv::Mat pred_map(n2, n3, CV_32F, (float*)pred.data());
const double threshold = det_db_thresh_ * 255;
const double maxvalue = 255;
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
if (use_dilation_) {
cv::Mat dila_ele =
cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
cv::dilate(bit_map, bit_map, dila_ele);
}
std::vector<std::vector<std::vector<int>>> boxes;
boxes =
post_processor_.BoxesFromBitmap(pred_map, bit_map, det_db_box_thresh_,
det_db_unclip_ratio_, det_db_score_mode_);
boxes = post_processor_.FilterTagDetRes(boxes, det_img_info);
// boxes to boxes_result
for (int i = 0; i < boxes.size(); i++) {
std::array<int, 8> new_box;
int k = 0;
for (auto& vec : boxes[i]) {
for (auto& e : vec) {
new_box[k++] = e;
}
}
boxes_result->push_back(new_box);
}
return true;
}
bool DBDetectorPostprocessor::Run(const std::vector<FDTensor>& tensors,
std::vector<std::vector<std::array<int, 8>>>* results,
const std::vector<std::array<int,4>>& batch_det_img_info) {
if (!initialized_) {
FDERROR << "Postprocessor is not initialized." << std::endl;
return false;
}
// DBDetector have only 1 output tensor.
const FDTensor& tensor = tensors[0];
// For DBDetector, the output tensor shape = [batch, 1, ?, ?]
size_t batch = tensor.shape[0];
size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies<int>());
const float* tensor_data = reinterpret_cast<const float*>(tensor.Data());
results->resize(batch);
for (int i_batch = 0; i_batch < batch; ++i_batch) {
if(!SingleBatchPostprocessor(tensor_data,
tensor.shape[2],
tensor.shape[3],
batch_det_img_info[i_batch],
&results->at(i_batch)
))return false;
tensor_data = tensor_data + length;
}
return true;
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,62 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief Postprocessor object for the DBDetector series of models.
*
* Turns the raw prediction maps of the OCR detection model into a list of text
* boxes (8 integers each) for every image of the batch.
*/
class FASTDEPLOY_DECL DBDetectorPostprocessor {
public:
/** \brief Create a postprocessor instance for the DBDetector series of models.
*
* The instance is marked as initialized by the constructor.
*/
DBDetectorPostprocessor();
/** \brief Process the result of the runtime and fill the results structure.
*
* Only the first tensor is read; its first dimension is used as the batch size.
*
* \param[in] tensors The inference result from runtime
* \param[out] results The output result of the detector; resized to the batch size, one list of boxes per image
* \param[in] batch_det_img_info The per-image records produced by the detector preprocessor
* \return true if the postprocess succeeded, otherwise false
*/
bool Run(const std::vector<FDTensor>& tensors,
std::vector<std::vector<std::array<int, 8>>>* results,
const std::vector<std::array<int, 4>>& batch_det_img_info);
// Binarization threshold; multiplied by 255 and applied to the 8-bit copy of the prediction map.
double det_db_thresh_ = 0.3;
// The next three values are forwarded to PostProcessor::BoxesFromBitmap.
double det_db_box_thresh_ = 0.6;
double det_db_unclip_ratio_ = 1.5;
std::string det_db_score_mode_ = "slow";
// When true, the binarized map is dilated with a 2x2 rectangular element before box extraction.
bool use_dilation_ = false;
private:
// Checked at the start of Run(); set to true by the constructor.
bool initialized_ = false;
// Provides BoxesFromBitmap and FilterTagDetRes.
PostProcessor post_processor_;
// Extracts the boxes of one image from its n2 x n3 prediction map and appends
// them to boxes_result.
bool SingleBatchPostprocessor(const float* out_data,
int n2,
int n3,
const std::array<int, 4>& det_img_info,
std::vector<std::array<int, 8>>* boxes_result);
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,113 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/det_preprocessor.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
#include "fastdeploy/function/concat.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
// All preprocessing parameters come from their in-class defaults, so the
// instance is flagged as ready for use directly at construction.
DBDetectorPreprocessor::DBDetectorPreprocessor() : initialized_(true) {}
// Computes the size an image will be resized to before detection.
//
// When the longer side exceeds max_size_len, both sides are scaled down by
// the same factor so that the longer side matches it. Each side is then
// rounded to the nearest multiple of 32, with 32 as the minimum.
//
// Returns {original width, original height, resized width, resized height}.
std::array<int, 4> OcrDetectorGetInfo(FDMat* img, int max_size_len) {
  int src_w = img->Width();
  int src_h = img->Height();

  const int longest_side = std::max(src_w, src_h);
  float scale = 1.f;
  if (longest_side > max_size_len) {
    scale = float(max_size_len) / float(longest_side);
  }

  // Rounds a length to the nearest multiple of 32 and never goes below 32.
  const auto snap_to_32 = [](int length) {
    return std::max(int(std::round(float(length) / 32) * 32), 32);
  };
  int dst_h = snap_to_32(int(float(src_h) * scale));
  int dst_w = snap_to_32(int(float(src_w) * scale));

  return {src_w, src_h, dst_w, dst_h};
}
// Resizes `img` to resize_w x resize_h and then zero-pads it up to
// max_resize_w x max_resize_h, so that every image of a batch ends up with
// the same size. Always returns true.
bool OcrDetectorResizeImage(FDMat* img,
                            int resize_w,
                            int resize_h,
                            int max_resize_w,
                            int max_resize_h) {
  Resize::Run(img, resize_w, resize_h);

  // Amount of padding still missing in each direction after the resize.
  const int missing_rows = max_resize_h - resize_h;
  const int missing_cols = max_resize_w - resize_w;
  std::vector<float> fill_value = {0, 0, 0};
  Pad::Run(img, 0, missing_rows, 0, missing_cols, fill_value);
  return true;
}
bool DBDetectorPreprocessor::Run(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs,
std::vector<std::array<int, 4>>* batch_det_img_info_ptr) {
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
return false;
}
if (images->size() == 0) {
FDERROR << "The size of input images should be greater than 0." << std::endl;
return false;
}
int max_resize_w = 0;
int max_resize_h = 0;
std::vector<std::array<int, 4>>& batch_det_img_info = *batch_det_img_info_ptr;
batch_det_img_info.clear();
batch_det_img_info.resize(images->size());
for (size_t i = 0; i < images->size(); ++i) {
FDMat* mat = &(images->at(i));
batch_det_img_info[i] = OcrDetectorGetInfo(mat,max_side_len_);
max_resize_w = std::max(max_resize_w,batch_det_img_info[i][2]);
max_resize_h = std::max(max_resize_h,batch_det_img_info[i][3]);
}
for (size_t i = 0; i < images->size(); ++i) {
FDMat* mat = &(images->at(i));
OcrDetectorResizeImage(mat, batch_det_img_info[i][2],batch_det_img_info[i][3],max_resize_w,max_resize_h);
NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
/*
Normalize::Run(mat, mean_, scale_, is_scale_);
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
*/
}
// Only have 1 output Tensor.
outputs->resize(1);
// Concat all the preprocessed data to a batch tensor
std::vector<FDTensor> tensors(images->size());
for (size_t i = 0; i < images->size(); ++i) {
(*images)[i].ShareWithTensor(&(tensors[i]));
tensors[i].ExpandDim(0);
}
if (tensors.size() == 1) {
(*outputs)[0] = std::move(tensors[0]);
} else {
function::Concat(tensors, &((*outputs)[0]), 0);
}
return true;
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,54 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief Preprocessor object for the DBDetector series of models.
*
* Resizes every image, pads all images of a batch to a common size, normalizes
* and permutes them, and stacks them into one batch tensor for the runtime.
*/
class FASTDEPLOY_DECL DBDetectorPreprocessor {
public:
/** \brief Create a preprocessor instance for the DBDetector series of models.
*
* The instance is marked as initialized by the constructor.
*/
DBDetectorPreprocessor();
/** \brief Process the input images and prepare the input tensor for the runtime.
*
* \param[in,out] images The input image data list, all the elements are returned by cv::imread(); the images are modified in place
* \param[out] outputs The output tensors which will be fed to the runtime; resized to a single batch tensor
* \param[out] batch_det_img_info_ptr Receives one {width, height, resized width, resized height} record per image
* \return true if the preprocess succeeded, otherwise false
*/
bool Run(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs,
std::vector<std::array<int, 4>>* batch_det_img_info_ptr);
// Limit for the longer image side; larger images are scaled down to it before
// both sides are rounded to a multiple of 32.
int max_side_len_ = 960;
// Mean, scale and scaling switch forwarded to NormalizeAndPermute for every image.
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
std::vector<float> scale_ = {0.229f, 0.224f, 0.225f};
bool is_scale_ = true;
private:
// Checked at the start of Run(); set to true by the constructor.
bool initialized_ = false;
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

166
fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc Normal file → Executable file
View File

@@ -16,45 +16,171 @@
namespace fastdeploy {
// Registers the Python bindings of the PP-OCR components on module `m`:
// the `sort_boxes` helper, the DBDetector / Classifier / Recognizer models and
// their standalone pre/post-processor classes.
// NOTE(review): pybind11::eval() defaults to expression mode, so evaluating a
// `raise ...` statement may surface as a SyntaxError instead of the intended
// Exception -- confirm, and consider throwing a C++ exception instead.
void BindPPOCRModel(pybind11::module& m) {
// Sorts the boxes in place on the C++ side and returns the sorted list.
m.def("sort_boxes", [](std::vector<std::array<int, 8>>& boxes) {
vision::ocr::SortBoxes(&boxes);
return boxes;
});
// DBDetector
pybind11::class_<vision::ocr::DBDetector, FastDeployModel>(m, "DBDetector")
.def(pybind11::init<std::string, std::string, RuntimeOption,
ModelFormat>())
.def(pybind11::init<>())
.def_readwrite("preprocessor", &vision::ocr::DBDetector::preprocessor_)
.def_readwrite("postprocessor", &vision::ocr::DBDetector::postprocessor_);
// NOTE(review): the statement above already ends with ';', so the chained
// calls below have no receiver. They look like leftovers from the
// pre-refactor binding (the same settings are exposed on
// DBDetectorPreprocessor / DBDetectorPostprocessor further down) -- confirm
// and remove.
.def_readwrite("max_side_len", &vision::ocr::DBDetector::max_side_len)
.def_readwrite("det_db_thresh", &vision::ocr::DBDetector::det_db_thresh)
.def_readwrite("det_db_box_thresh",
&vision::ocr::DBDetector::det_db_box_thresh)
.def_readwrite("det_db_unclip_ratio",
&vision::ocr::DBDetector::det_db_unclip_ratio)
.def_readwrite("det_db_score_mode",
&vision::ocr::DBDetector::det_db_score_mode)
.def_readwrite("use_dilation", &vision::ocr::DBDetector::use_dilation)
.def_readwrite("mean", &vision::ocr::DBDetector::mean)
.def_readwrite("scale", &vision::ocr::DBDetector::scale)
.def_readwrite("is_scale", &vision::ocr::DBDetector::is_scale);
// Standalone detector preprocessor: numpy images in, (tensors, image info) out.
pybind11::class_<vision::ocr::DBDetectorPreprocessor>(m, "DBDetectorPreprocessor")
.def(pybind11::init<>())
.def_readwrite("max_side_len", &vision::ocr::DBDetectorPreprocessor::max_side_len_)
.def_readwrite("mean", &vision::ocr::DBDetectorPreprocessor::mean_)
.def_readwrite("scale", &vision::ocr::DBDetectorPreprocessor::scale_)
.def_readwrite("is_scale", &vision::ocr::DBDetectorPreprocessor::is_scale_)
.def("run", [](vision::ocr::DBDetectorPreprocessor& self, std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
std::vector<FDTensor> outputs;
std::vector<std::array<int, 4>> batch_det_img_info;
// NOTE(review): the return value of Run() is ignored here, unlike the
// Classifier/Recognizer preprocessor bindings below which raise on failure.
self.Run(&images, &outputs, &batch_det_img_info);
// Presumably detaches each tensor from the local image buffers so the data
// stays valid after `images` goes out of scope -- TODO confirm.
for(size_t i = 0; i< outputs.size(); ++i){
outputs[i].StopSharing();
}
return make_pair(outputs, batch_det_img_info);
});
// Standalone detector postprocessor; `run` accepts either FDTensor or numpy inputs.
pybind11::class_<vision::ocr::DBDetectorPostprocessor>(m, "DBDetectorPostprocessor")
.def(pybind11::init<>())
.def_readwrite("det_db_thresh", &vision::ocr::DBDetectorPostprocessor::det_db_thresh_)
.def_readwrite("det_db_box_thresh", &vision::ocr::DBDetectorPostprocessor::det_db_box_thresh_)
.def_readwrite("det_db_unclip_ratio", &vision::ocr::DBDetectorPostprocessor::det_db_unclip_ratio_)
.def_readwrite("det_db_score_mode", &vision::ocr::DBDetectorPostprocessor::det_db_score_mode_)
.def_readwrite("use_dilation", &vision::ocr::DBDetectorPostprocessor::use_dilation_)
.def("run", [](vision::ocr::DBDetectorPostprocessor& self,
std::vector<FDTensor>& inputs,
const std::vector<std::array<int, 4>>& batch_det_img_info) {
std::vector<std::vector<std::array<int, 8>>> results;
if (!self.Run(inputs, &results, batch_det_img_info)) {
// NOTE(review): the message says "preprocess" although this is the postprocessor.
pybind11::eval("raise Exception('Failed to preprocess the input data in DBDetectorPostprocessor.')");
}
return results;
})
.def("run", [](vision::ocr::DBDetectorPostprocessor& self,
std::vector<pybind11::array>& input_array,
const std::vector<std::array<int, 4>>& batch_det_img_info) {
std::vector<std::vector<std::array<int, 8>>> results;
std::vector<FDTensor> inputs;
// Wrap the numpy arrays as tensors (share_buffer=true is requested).
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
if (!self.Run(inputs, &results, batch_det_img_info)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in DBDetectorPostprocessor.')");
}
return results;
});
// Classifier
pybind11::class_<vision::ocr::Classifier, FastDeployModel>(m, "Classifier")
.def(pybind11::init<std::string, std::string, RuntimeOption,
ModelFormat>())
.def(pybind11::init<>())
.def_readwrite("preprocessor", &vision::ocr::Classifier::preprocessor_)
.def_readwrite("postprocessor", &vision::ocr::Classifier::postprocessor_);
// Standalone classifier preprocessor: numpy images in, tensors out.
pybind11::class_<vision::ocr::ClassifierPreprocessor>(m, "ClassifierPreprocessor")
.def(pybind11::init<>())
.def_readwrite("cls_image_shape", &vision::ocr::ClassifierPreprocessor::cls_image_shape_)
.def_readwrite("mean", &vision::ocr::ClassifierPreprocessor::mean_)
.def_readwrite("scale", &vision::ocr::ClassifierPreprocessor::scale_)
.def_readwrite("is_scale", &vision::ocr::ClassifierPreprocessor::is_scale_)
.def("run", [](vision::ocr::ClassifierPreprocessor& self, std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPreprocessor.')");
}
for(size_t i = 0; i< outputs.size(); ++i){
outputs[i].StopSharing();
}
return outputs;
});
// Standalone classifier postprocessor; returns (labels, scores).
pybind11::class_<vision::ocr::ClassifierPostprocessor>(m, "ClassifierPostprocessor")
.def(pybind11::init<>())
.def_readwrite("cls_thresh", &vision::ocr::ClassifierPostprocessor::cls_thresh_)
.def("run", [](vision::ocr::ClassifierPostprocessor& self,
std::vector<FDTensor>& inputs) {
std::vector<int> cls_labels;
std::vector<float> cls_scores;
if (!self.Run(inputs, &cls_labels, &cls_scores)) {
// NOTE(review): the message says "preprocess" although this is the postprocessor.
pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPostprocessor.')");
}
return make_pair(cls_labels,cls_scores);
})
.def("run", [](vision::ocr::ClassifierPostprocessor& self,
std::vector<pybind11::array>& input_array) {
std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
std::vector<int> cls_labels;
std::vector<float> cls_scores;
if (!self.Run(inputs, &cls_labels, &cls_scores)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPostprocessor.')");
}
return make_pair(cls_labels,cls_scores);
});
// NOTE(review): dangling chained calls after a terminated statement; they
// reference members directly on Classifier and look like pre-refactor
// leftovers -- confirm and remove.
.def_readwrite("cls_thresh", &vision::ocr::Classifier::cls_thresh)
.def_readwrite("cls_image_shape",
&vision::ocr::Classifier::cls_image_shape)
.def_readwrite("cls_batch_num", &vision::ocr::Classifier::cls_batch_num);
// Recognizer
pybind11::class_<vision::ocr::Recognizer, FastDeployModel>(m, "Recognizer")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>())
.def(pybind11::init<>())
.def_readwrite("preprocessor", &vision::ocr::Recognizer::preprocessor_)
.def_readwrite("postprocessor", &vision::ocr::Recognizer::postprocessor_);
// NOTE(review): dangling chained calls after a terminated statement; they
// reference members directly on Recognizer and look like pre-refactor
// leftovers -- confirm and remove.
.def_readwrite("rec_img_h", &vision::ocr::Recognizer::rec_img_h)
.def_readwrite("rec_img_w", &vision::ocr::Recognizer::rec_img_w)
.def_readwrite("rec_batch_num", &vision::ocr::Recognizer::rec_batch_num);
// Standalone recognizer preprocessor: numpy images in, tensors out.
pybind11::class_<vision::ocr::RecognizerPreprocessor>(m, "RecognizerPreprocessor")
.def(pybind11::init<>())
.def_readwrite("rec_image_shape", &vision::ocr::RecognizerPreprocessor::rec_image_shape_)
.def_readwrite("mean", &vision::ocr::RecognizerPreprocessor::mean_)
.def_readwrite("scale", &vision::ocr::RecognizerPreprocessor::scale_)
.def_readwrite("is_scale", &vision::ocr::RecognizerPreprocessor::is_scale_)
.def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPreprocessor.')");
}
for(size_t i = 0; i< outputs.size(); ++i){
outputs[i].StopSharing();
}
return outputs;
});
// Standalone recognizer postprocessor (constructed from a label file path);
// returns (texts, scores).
pybind11::class_<vision::ocr::RecognizerPostprocessor>(m, "RecognizerPostprocessor")
.def(pybind11::init<std::string>())
.def("run", [](vision::ocr::RecognizerPostprocessor& self,
std::vector<FDTensor>& inputs) {
std::vector<std::string> texts;
std::vector<float> rec_scores;
if (!self.Run(inputs, &texts, &rec_scores)) {
// NOTE(review): the message says "preprocess" although this is the postprocessor.
pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPostprocessor.')");
}
return make_pair(texts, rec_scores);
})
.def("run", [](vision::ocr::RecognizerPostprocessor& self,
std::vector<pybind11::array>& input_array) {
std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
std::vector<std::string> texts;
std::vector<float> rec_scores;
if (!self.Run(inputs, &texts, &rec_scores)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPostprocessor.')");
}
return make_pair(texts, rec_scores);
});
}
} // namespace fastdeploy

18
fastdeploy/vision/ocr/ppocr/ppocr_pybind.cc Normal file → Executable file
View File

@@ -31,6 +31,15 @@ void BindPPOCRv3(pybind11::module& m) {
vision::OCRResult res;
self.Predict(&mat, &res);
return res;
})
.def("batch_predict", [](pipeline::PPOCRv3& self, std::vector<pybind11::array>& data) {
std::vector<cv::Mat> images;
for (size_t i = 0; i < data.size(); ++i) {
images.push_back(PyArrayToCvMat(data[i]));
}
std::vector<vision::OCRResult> results;
self.BatchPredict(images, &results);
return results;
});
}
@@ -49,6 +58,15 @@ void BindPPOCRv2(pybind11::module& m) {
vision::OCRResult res;
self.Predict(&mat, &res);
return res;
})
.def("batch_predict", [](pipeline::PPOCRv2& self, std::vector<pybind11::array>& data) {
std::vector<cv::Mat> images;
for (size_t i = 0; i < data.size(); ++i) {
images.push_back(PyArrayToCvMat(data[i]));
}
std::vector<vision::OCRResult> results;
self.BatchPredict(images, &results);
return results;
});
}

130
fastdeploy/vision/ocr/ppocr/ppocr_v2.cc Normal file → Executable file
View File

@@ -22,13 +22,15 @@ PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model,
fastdeploy::vision::ocr::Classifier* cls_model,
fastdeploy::vision::ocr::Recognizer* rec_model)
: detector_(det_model), classifier_(cls_model), recognizer_(rec_model) {
recognizer_->rec_image_shape[1] = 32;
Initialized();
recognizer_->preprocessor_.rec_image_shape_[1] = 32;
}
// Builds a PP-OCRv2 pipeline from a detector and a recognizer only; no
// classifier is supplied to this overload, so classifier_ is not set here.
PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model,
fastdeploy::vision::ocr::Recognizer* rec_model)
: detector_(det_model), recognizer_(rec_model) {
// NOTE(review): this line and the one two below both set element [1] of the
// recognizer image shape to 32 through different member paths; this one looks
// like a pre-refactor leftover -- confirm and remove.
recognizer_->rec_image_shape[1] = 32;
// NOTE(review): the result of Initialized() is discarded.
Initialized();
// Override element [1] of the recognizer preprocessor's image shape with 32
// (presumably the input height expected by v2 recognition models -- confirm).
recognizer_->preprocessor_.rec_image_shape_[1] = 32;
}
bool PPOCRv2::Initialized() const {
@@ -47,76 +49,68 @@ bool PPOCRv2::Initialized() const {
return true;
}
// Runs the detection model on `img`, stores the boxes in `result->boxes`
// and sorts the result. Returns false when the detector reports a failure.
bool PPOCRv2::Detect(cv::Mat* img,
                     fastdeploy::vision::OCRResult* result) {
  if (detector_->Predict(img, &result->boxes)) {
    // Detection succeeded: put the boxes into their sorted order.
    vision::ocr::SortBoxes(result);
    return true;
  }
  FDERROR << "There's error while detecting image in PPOCR." << std::endl;
  return false;
}
// Runs the recognition model on `img` and appends the recognized text and
// its score to `result`. Returns false when the recognizer reports a failure.
bool PPOCRv2::Recognize(cv::Mat* img,
                        fastdeploy::vision::OCRResult* result) {
  std::tuple<std::string, float> text_and_score;
  if (recognizer_->Predict(img, &text_and_score)) {
    const std::string& text = std::get<0>(text_and_score);
    const float score = std::get<1>(text_and_score);
    result->text.push_back(text);
    result->rec_scores.push_back(score);
    return true;
  }
  FDERROR << "There's error while recognizing image in PPOCR." << std::endl;
  return false;
}
// Runs the classification model on `img` and appends the predicted label and
// its score to `result`. Returns false when the classifier reports a failure.
bool PPOCRv2::Classify(cv::Mat* img,
                       fastdeploy::vision::OCRResult* result) {
  std::tuple<int, float> label_and_score;
  if (classifier_->Predict(img, &label_and_score)) {
    const int label = std::get<0>(label_and_score);
    const float score = std::get<1>(label_and_score);
    result->cls_labels.push_back(label);
    result->cls_scores.push_back(score);
    return true;
  }
  FDERROR << "There's error while classifying image in PPOCR." << std::endl;
  return false;
}
/** \brief Run the whole OCR pipeline on a single image.
 *
 * Thin wrapper around BatchPredict() with a batch of one image. The previous
 * body first ran a per-crop pipeline whose output was then overwritten by the
 * batch result, and it ignored BatchPredict()'s failure; both are fixed here.
 *
 * \param[in] img The input image (HWC, BGR); must not be null.
 * \param[out] result Receives the OCR result; cleared first, so it is empty on failure.
 * \return true if the prediction succeeded, otherwise false.
 */
bool PPOCRv2::Predict(cv::Mat* img,
                      fastdeploy::vision::OCRResult* result) {
  if (img == nullptr || result == nullptr) {
    FDERROR << "The input image and the output result must not be null."
            << std::endl;
    return false;
  }
  result->Clear();

  std::vector<fastdeploy::vision::OCRResult> batch_result(1);
  // Propagate a pipeline failure instead of reporting success with an
  // empty/partial result.
  if (!BatchPredict({*img}, &batch_result) || batch_result.empty()) {
    return false;
  }
  *result = std::move(batch_result[0]);
  return true;
}
/** \brief Run detection, optional classification and recognition on a batch of images.
 *
 * For every image: detect and sort the text boxes, crop one sub-image per box
 * (or use the whole image when no box was found), optionally classify the
 * crops and rotate those flagged by the classifier, then recognize the text.
 *
 * \param[in] images The list of input images (HWC, BGR).
 * \param[out] batch_result One OCRResult per input image; cleared and resized here.
 * \return true if every stage succeeded, otherwise false.
 */
bool PPOCRv2::BatchPredict(const std::vector<cv::Mat>& images,
                           std::vector<fastdeploy::vision::OCRResult>* batch_result) {
  batch_result->clear();
  batch_result->resize(images.size());
  std::vector<std::vector<std::array<int, 8>>> batch_boxes(images.size());

  if (!detector_->BatchPredict(images, &batch_boxes)) {
    FDERROR << "There's error while detecting image in PPOCR." << std::endl;
    return false;
  }
  // Guard the indexing below in case the detector changed the vector's size.
  if (batch_boxes.size() != images.size()) {
    FDERROR << "The number of detection results does not match the number of "
               "input images."
            << std::endl;
    return false;
  }

  for (size_t i_batch = 0; i_batch < batch_boxes.size(); ++i_batch) {
    vision::ocr::SortBoxes(&(batch_boxes[i_batch]));
    (*batch_result)[i_batch].boxes = std::move(batch_boxes[i_batch]);
  }

  for (size_t i_batch = 0; i_batch < images.size(); ++i_batch) {
    fastdeploy::vision::OCRResult& ocr_result = (*batch_result)[i_batch];
    // Get cropped images by detection result
    const std::vector<std::array<int, 8>>& boxes = ocr_result.boxes;
    const cv::Mat& img = images[i_batch];
    std::vector<cv::Mat> image_list;
    if (boxes.empty()) {
      // No text box found: fall back to the whole image.
      image_list.emplace_back(img);
    } else {
      image_list.resize(boxes.size());
      for (size_t i_box = 0; i_box < boxes.size(); ++i_box) {
        image_list[i_box] = vision::ocr::GetRotateCropImage(img, boxes[i_box]);
      }
    }

    std::vector<int32_t>* cls_labels_ptr = &ocr_result.cls_labels;
    std::vector<float>* cls_scores_ptr = &ocr_result.cls_scores;
    std::vector<std::string>* text_ptr = &ocr_result.text;
    std::vector<float>* rec_scores_ptr = &ocr_result.rec_scores;

    // The classification model is optional: the detector+recognizer
    // constructor leaves classifier_ as nullptr, so it must not be
    // dereferenced unconditionally.
    if (nullptr != classifier_) {
      if (!classifier_->BatchPredict(image_list, cls_labels_ptr,
                                     cls_scores_ptr)) {
        FDERROR << "There's error while classifying image in PPOCR."
                << std::endl;
        return false;
      }
      for (size_t i_img = 0; i_img < image_list.size(); ++i_img) {
        if (cls_labels_ptr->at(i_img) % 2 == 1 &&
            cls_scores_ptr->at(i_img) >
                classifier_->postprocessor_.cls_thresh_) {
          // Rotate into a fresh buffer: in the no-box fallback the list entry
          // shares pixel data with the caller's const input image, which an
          // in-place rotation would silently modify.
          cv::Mat rotated;
          cv::rotate(image_list[i_img], rotated, cv::ROTATE_180);
          image_list[i_img] = rotated;
        }
      }
    }

    if (!recognizer_->BatchPredict(image_list, text_ptr, rec_scores_ptr)) {
      FDERROR << "There's error while recognizing image in PPOCR." << std::endl;
      return false;
    }
  }
  return true;
}
} // namespace pipeline
} // namespace fastdeploy

13
fastdeploy/vision/ocr/ppocr/ppocr_v2.h Normal file → Executable file
View File

@@ -59,6 +59,14 @@ class FASTDEPLOY_DECL PPOCRv2 : public FastDeployModel {
* \return true if the prediction successed, otherwise false.
*/
virtual bool Predict(cv::Mat* img, fastdeploy::vision::OCRResult* result);
/** \brief Predict a batch of input images and get the OCR results.
*
* \param[in] images The list of input image data, comes from cv::imread(), each is a 3-D array with layout HWC, BGR format.
* \param[out] batch_result The output list of OCR results will be written to this structure.
* \return true if the prediction succeeded, otherwise false.
*/
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
std::vector<fastdeploy::vision::OCRResult>* batch_result);
bool Initialized() const override;
protected:
@@ -66,11 +74,6 @@ class FASTDEPLOY_DECL PPOCRv2 : public FastDeployModel {
fastdeploy::vision::ocr::Classifier* classifier_ = nullptr;
fastdeploy::vision::ocr::Recognizer* recognizer_ = nullptr;
/// Launch the detection process in OCR.
virtual bool Detect(cv::Mat* img, fastdeploy::vision::OCRResult* result);
/// Launch the recognition process in OCR.
virtual bool Recognize(cv::Mat* img, fastdeploy::vision::OCRResult* result);
/// Launch the classification process in OCR.
virtual bool Classify(cv::Mat* img, fastdeploy::vision::OCRResult* result);
};
namespace application {

4
fastdeploy/vision/ocr/ppocr/ppocr_v3.h Normal file → Executable file
View File

@@ -36,7 +36,7 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 {
fastdeploy::vision::ocr::Recognizer* rec_model)
: PPOCRv2(det_model, cls_model, rec_model) {
// The only difference between v2 and v3
recognizer_->rec_image_shape[1] = 48;
recognizer_->preprocessor_.rec_image_shape_[1] = 48;
}
/** \brief Classification model is optional, so this function is set up the detection model path and recognition model path respectively.
*
@@ -47,7 +47,7 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 {
fastdeploy::vision::ocr::Recognizer* rec_model)
: PPOCRv2(det_model, rec_model) {
// The only difference between v2 and v3
recognizer_->rec_image_shape[1] = 48;
recognizer_->preprocessor_.rec_image_shape_[1] = 48;
}
};

View File

@@ -0,0 +1,112 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/rec_postprocessor.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
// Loads the label dictionary at `path`, one entry per line. The returned
// list starts with "#" (blank char for ctc) and ends with a single space.
// Asserts when the file cannot be opened.
std::vector<std::string> ReadDict(const std::string& path) {
  std::ifstream dict_file(path);
  FDASSERT(dict_file, "Cannot open file %s to read.", path.c_str());

  std::vector<std::string> labels;
  labels.emplace_back("#");  // blank char for ctc
  for (std::string entry; std::getline(dict_file, entry);) {
    labels.push_back(entry);
  }
  labels.emplace_back(" ");
  return labels;
}
// Default constructor: marks the postprocessor as initialized without
// loading any label dictionary.
RecognizerPostprocessor::RecognizerPostprocessor() : initialized_(true) {}
// Constructs the postprocessor and loads its label list from `label_path`
// through ReadDict().
RecognizerPostprocessor::RecognizerPostprocessor(const std::string& label_path)
    : initialized_(true), label_list_(ReadDict(label_path)) {}
/** \brief Decode the recognizer output of one batch item into text and score.
 *
 * For every step n in [0, output_shape[1]) the arg-max over the
 * output_shape[2] values is taken. Index 0 (the CTC blank) and an index equal
 * to the previous step's index are skipped; every other index appends
 * label_list_[index] to the text. The score is the mean of the kept maxima.
 *
 * \param[in] out_data Start of this item's data, output_shape[1] * output_shape[2] floats.
 * \param[in] output_shape Shape of the whole output tensor; only [1] and [2] are read.
 * \param[out] text Receives the decoded text (any previous content is discarded).
 * \param[out] rec_score Receives the score, 0 when nothing was decoded.
 * \return false if an arg-max index is outside label_list_, otherwise true.
 */
bool RecognizerPostprocessor::SingleBatchPostprocessor(const float* out_data,
                                const std::vector<int64_t>& output_shape,
                                std::string* text, float* rec_score) {
  std::string& str_res = *text;
  float& score = *rec_score;
  // Reset both outputs: callers may hand in reused storage, and the text is
  // built by appending.
  str_res.clear();
  score = 0.f;

  const int64_t num_steps = output_shape[1];
  const int64_t num_classes = output_shape[2];
  int last_index = 0;
  int count = 0;

  for (int64_t n = 0; n < num_steps; ++n) {
    const float* step_begin = out_data + n * num_classes;
    const float* step_end = step_begin + num_classes;
    // One scan gives both the arg-max position and the max value.
    const float* best = std::max_element(step_begin, step_end);
    const int argmax_idx = static_cast<int>(std::distance(step_begin, best));

    if (argmax_idx > 0 && !(n > 0 && argmax_idx == last_index)) {
      // Valid indices are [0, size); an index equal to size is already out
      // of range.
      if (static_cast<size_t>(argmax_idx) >= label_list_.size()) {
        FDERROR << "The output index: " << argmax_idx
                << " is larger than or equal to the size of label_list: "
                << label_list_.size() << ". Please check the label file!"
                << std::endl;
        return false;
      }
      score += *best;
      count += 1;
      str_res += label_list_[argmax_idx];
    }
    last_index = argmax_idx;
  }

  score /= (count + 1e-6);
  if (count == 0 || std::isnan(score)) {
    score = 0.f;
  }
  return true;
}
bool RecognizerPostprocessor::Run(const std::vector<FDTensor>& tensors,
std::vector<std::string>* texts, std::vector<float>* rec_scores) {
if (!initialized_) {
FDERROR << "Postprocessor is not initialized." << std::endl;
return false;
}
// Recognizer have only 1 output tensor.
const FDTensor& tensor = tensors[0];
// For Recognizer, the output tensor shape = [batch, ?, 6625]
size_t batch = tensor.shape[0];
size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies<int>());
texts->resize(batch);
rec_scores->resize(batch);
const float* tensor_data = reinterpret_cast<const float*>(tensor.Data());
for (int i_batch = 0; i_batch < batch; ++i_batch) {
if(!SingleBatchPostprocessor(tensor_data, tensor.shape, &texts->at(i_batch), &rec_scores->at(i_batch))) {
return false;
}
tensor_data = tensor_data + length;
}
return true;
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,55 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief Postprocessor object for the Recognizer series of models.
*
* Decodes the recognizer output tensor into text strings and scores.
*/
class FASTDEPLOY_DECL RecognizerPostprocessor {
public:
/** \brief Create a postprocessor instance without loading a label dictionary
*/
RecognizerPostprocessor();
/** \brief Create a postprocessor instance for the Recognizer series of models
*
* \param[in] label_path The path of label_dict
*/
explicit RecognizerPostprocessor(const std::string& label_path);
/** \brief Process the result of the runtime and fill the recognized texts and scores
*
* \param[in] tensors The inference result from runtime
* \param[out] texts The recognized text of each batch item
* \param[out] rec_scores The recognition score of each batch item
* \return true if the postprocess succeeded, otherwise false
*/
bool Run(const std::vector<FDTensor>& tensors,
std::vector<std::string>* texts, std::vector<float>* rec_scores);
private:
/// Decodes a single batch item starting at `out_data` into `text` and `rec_score`;
/// `output_shape` is the shape of the whole output tensor.
bool SingleBatchPostprocessor(const float* out_data,
const std::vector<int64_t>& output_shape,
std::string* text, float* rec_score);
/// Initialization flag of the postprocessor.
bool initialized_ = false;
/// Label strings used to turn output indices into text.
std::vector<std::string> label_list_;
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,99 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/rec_preprocessor.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
#include "fastdeploy/function/concat.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
// Default constructor: the preprocessor needs no external configuration, so
// it is marked as initialized right away.
RecognizerPreprocessor::RecognizerPreprocessor() : initialized_(true) {}
// Resizes `mat` to height rec_image_shape[1], keeping its aspect ratio but
// capping the width at int(height * max_wh_ratio), then pads it with zeros up
// to that capped width.
void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio,
                              const std::vector<int>& rec_image_shape) {
  const int target_h = rec_image_shape[1];
  const int padded_w = int(target_h * max_wh_ratio);

  const float aspect = float(mat->Width()) / float(mat->Height());
  const float scaled_w = ceilf(target_h * aspect);
  const int resize_w = (scaled_w > padded_w) ? padded_w : int(scaled_w);

  Resize::Run(mat, resize_w, target_h);

  // Fill the remaining columns with zeros.
  std::vector<float> pad_value = {0, 0, 0};
  Pad::Run(mat, 0, 0, 0, int(padded_w - mat->Width()), pad_value);
}
bool RecognizerPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
return false;
}
if (images->size() == 0) {
FDERROR << "The size of input images should be greater than 0." << std::endl;
return false;
}
int imgH = rec_image_shape_[1];
int imgW = rec_image_shape_[2];
float max_wh_ratio = imgW * 1.0 / imgH;
float ori_wh_ratio;
for (size_t i = 0; i < images->size(); ++i) {
FDMat* mat = &(images->at(i));
ori_wh_ratio = mat->Width() * 1.0 / mat->Height();
max_wh_ratio = std::max(max_wh_ratio, ori_wh_ratio);
}
for (size_t i = 0; i < images->size(); ++i) {
FDMat* mat = &(images->at(i));
OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_);
NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
/*
Normalize::Run(mat, mean_, scale_, is_scale_);
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
*/
}
// Only have 1 output Tensor.
outputs->resize(1);
// Concat all the preprocessed data to a batch tensor
std::vector<FDTensor> tensors(images->size());
for (size_t i = 0; i < images->size(); ++i) {
(*images)[i].ShareWithTensor(&(tensors[i]));
tensors[i].ExpandDim(0);
}
if (tensors.size() == 1) {
(*outputs)[0] = std::move(tensors[0]);
} else {
function::Concat(tensors, &((*outputs)[0]), 0);
}
return true;
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,52 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief Preprocessor object for the OCR Recognizer series of models.
*/
class FASTDEPLOY_DECL RecognizerPreprocessor {
public:
/** \brief Create a preprocessor instance for the Recognizer series of models
*
* Takes no arguments; the settings are the public members below.
*/
RecognizerPreprocessor();
/** \brief Process the input images and prepare the input tensor for the runtime
*
* \param[in,out] images The input image data list, all the elements are returned by cv::imread(); modified in place
* \param[out] outputs The output tensors which will be fed to the runtime
* \return true if the preprocess succeeded, otherwise false
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
/// Target input shape; Run() reads index 1 as the height and index 2 as the width.
std::vector<int> rec_image_shape_ = {3, 48, 320};
/// Per-channel mean passed to NormalizeAndPermute.
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
/// Per-channel scale passed to NormalizeAndPermute.
std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
/// Forwarded as the is_scale argument of NormalizeAndPermute.
bool is_scale_ = true;
private:
/// Initialization flag of the preprocessor.
bool initialized_ = false;
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

150
fastdeploy/vision/ocr/ppocr/recognizer.cc Normal file → Executable file
View File

@@ -20,29 +20,13 @@ namespace fastdeploy {
namespace vision {
namespace ocr {
// Reads the label file at `path` and returns its lines in order.
// If the file cannot be opened, a message is printed to stdout and the
// process exits with status 1.
std::vector<std::string> ReadDict(const std::string& path) {
  std::ifstream label_file(path);
  if (!label_file) {
    std::cout << "no such label file: " << path << ", exit the program..."
              << std::endl;
    exit(1);
  }

  std::vector<std::string> labels;
  for (std::string entry; getline(label_file, entry);) {
    labels.push_back(entry);
  }
  return labels;
}
// Default constructor with an empty body: no model files or runtime options
// are set up here.
Recognizer::Recognizer() {}
Recognizer::Recognizer(const std::string& model_file,
const std::string& params_file,
const std::string& label_path,
const RuntimeOption& custom_option,
const ModelFormat& model_format) {
const ModelFormat& model_format):postprocessor_(label_path) {
if (model_format == ModelFormat::ONNX) {
valid_cpu_backends = {Backend::ORT,
Backend::OPENVINO};
@@ -56,27 +40,11 @@ Recognizer::Recognizer(const std::string& model_file,
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
// init label_lsit
label_list = ReadDict(label_path);
label_list.insert(label_list.begin(), "#"); // blank char for ctc
label_list.push_back(" ");
}
// Init
bool Recognizer::Initialize() {
// pre&post process parameters
rec_batch_num = 1;
rec_img_h = 48;
rec_img_w = 320;
rec_image_shape = {3, rec_img_h, rec_img_w};
mean = {0.5f, 0.5f, 0.5f};
scale = {0.5f, 0.5f, 0.5f};
is_scale = true;
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
@@ -85,119 +53,23 @@ bool Recognizer::Initialize() {
return true;
}
// Resizes `mat` to height rec_image_shape[1], keeping its aspect ratio but
// capping the width at int(height * wh_ratio), then pads it with the value
// 127 up to that capped width.
void OcrRecognizerResizeImage(Mat* mat, const float& wh_ratio,
                              const std::vector<int>& rec_image_shape) {
  const int target_h = rec_image_shape[1];
  const int padded_w = int(target_h * wh_ratio);

  const float aspect = float(mat->Width()) / float(mat->Height());
  const float scaled_w = ceilf(target_h * aspect);
  const int resize_w = (scaled_w > padded_w) ? padded_w : int(scaled_w);

  Resize::Run(mat, resize_w, target_h);

  // Fill the remaining columns with mid-gray.
  std::vector<float> pad_value = {127, 127, 127};
  Pad::Run(mat, 0, 0, 0, int(padded_w - mat->Width()), pad_value);
}
// Prepares one image for the recognizer: resize and pad it to the larger of
// the configured and the image's own width/height ratio, normalize, convert
// HWC to CHW, cast to float, and expose the data through `output` with a
// leading batch dimension of 1. Always returns true.
bool Recognizer::Preprocess(Mat* mat, FDTensor* output,
                            const std::vector<int>& rec_image_shape) {
  const int target_h = rec_image_shape[1];
  const int target_w = rec_image_shape[2];

  const float configured_ratio = target_w * 1.0 / target_h;
  const float image_ratio = mat->Width() * 1.0 / mat->Height();
  const float wh_ratio = std::max(configured_ratio, image_ratio);

  OcrRecognizerResizeImage(mat, wh_ratio, rec_image_shape);

  Normalize::Run(mat, mean, scale, true);
  HWC2CHW::Run(mat);
  Cast::Run(mat, "float");

  mat->ShareWithTensor(output);
  // Prepend the batch dimension.
  output->shape.insert(output->shape.begin(), 1);
  return true;
}
bool Recognizer::Postprocess(FDTensor& infer_result,
std::tuple<std::string, float>* rec_result) {
std::vector<int64_t> output_shape = infer_result.shape;
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
float* out_data = static_cast<float*>(infer_result.Data());
std::string str_res;
int argmax_idx;
int last_index = 0;
float score = 0.f;
int count = 0;
float max_value = 0.0f;
for (int n = 0; n < output_shape[1]; n++) {
argmax_idx = int(
std::distance(&out_data[n * output_shape[2]],
std::max_element(&out_data[n * output_shape[2]],
&out_data[(n + 1) * output_shape[2]])));
max_value = float(*std::max_element(&out_data[n * output_shape[2]],
&out_data[(n + 1) * output_shape[2]]));
if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
score += max_value;
count += 1;
if(argmax_idx > label_list.size()){
FDERROR << "The output index: " << argmax_idx << " is larger than the size of label_list: "
<< label_list.size() << ". Please check the label file!" << std::endl;
bool Recognizer::BatchPredict(const std::vector<cv::Mat>& images,
std::vector<std::string>* texts, std::vector<float>* rec_scores) {
std::vector<FDMat> fd_images = WrapMat(images);
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
FDERROR << "Failed to preprocess the input image." << std::endl;
return false;
}
str_res += label_list[argmax_idx];
}
last_index = argmax_idx;
}
score /= (count + 1e-6);
if (count == 0 || std::isnan(score)) {
score = 0.f;
}
std::get<0>(*rec_result) = str_res;
std::get<1>(*rec_result) = score;
return true;
}
bool Recognizer::Predict(cv::Mat* img,
std::tuple<std::string, float>* rec_result) {
Mat mat(*img);
std::vector<FDTensor> input_tensors(1);
if (!Preprocess(&mat, &input_tensors[0], rec_image_shape)) {
FDERROR << "Failed to preprocess input image." << std::endl;
reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
FDERROR << "Failed to inference by runtime." << std::endl;
return false;
}
input_tensors[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
if (!postprocessor_.Run(reused_output_tensors_, texts, rec_scores)) {
FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
return false;
}
if (!Postprocess(output_tensors[0], rec_result)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}
return true;
}

31
fastdeploy/vision/ocr/ppocr/recognizer.h Normal file → Executable file
View File

@@ -17,6 +17,8 @@
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
#include "fastdeploy/vision/ocr/ppocr/rec_preprocessor.h"
#include "fastdeploy/vision/ocr/ppocr/rec_postprocessor.h"
namespace fastdeploy {
namespace vision {
@@ -43,35 +45,20 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
const ModelFormat& model_format = ModelFormat::PADDLE);
/// Get model's name
std::string ModelName() const { return "ppocr/ocr_rec"; }
/** \brief Predict the input image and get OCR recognition model result.
/** \brief BatchPredict the input image and get OCR recognition model result.
*
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
* \param[in] rec_result The output of OCR recognition model result will be writen to this structure.
* \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
* \param[in] rec_results The output of OCR recognition model result will be written to this structure.
* \return true if the prediction succeeded, otherwise false.
*/
virtual bool Predict(cv::Mat* img,
std::tuple<std::string, float>* rec_result);
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
std::vector<std::string>* texts, std::vector<float>* rec_scores);
// Pre & Post parameters
std::vector<std::string> label_list;
int rec_batch_num;
int rec_img_h;
int rec_img_w;
std::vector<int> rec_image_shape;
std::vector<float> mean;
std::vector<float> scale;
bool is_scale;
RecognizerPreprocessor preprocessor_;
RecognizerPostprocessor postprocessor_;
private:
bool Initialize();
/// Preprocess the input data, and set the preprocessed results to `outputs`
bool Preprocess(Mat* img, FDTensor* outputs,
const std::vector<int>& rec_image_shape);
/*! @brief Postprocess the inferenced results, and set the final result to `rec_result`
*/
bool Postprocess(FDTensor& infer_result,
std::tuple<std::string, float>* rec_result);
};
} // namespace ocr

10
fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.cc Normal file → Executable file
View File

@@ -318,10 +318,12 @@ std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
}
std::vector<std::vector<std::vector<int>>> PostProcessor::FilterTagDetRes(
std::vector<std::vector<std::vector<int>>> boxes, float ratio_h,
float ratio_w, const std::map<std::string, std::array<float, 2>> &im_info) {
int oriimg_h = im_info.at("input_shape")[0];
int oriimg_w = im_info.at("input_shape")[1];
std::vector<std::vector<std::vector<int>>> boxes,
const std::array<int,4>& det_img_info) {
int oriimg_w = det_img_info[0];
int oriimg_h = det_img_info[1];
float ratio_w = float(det_img_info[2])/float(oriimg_w);
float ratio_h = float(det_img_info[3])/float(oriimg_h);
std::vector<std::vector<std::vector<int>>> root_points;
for (int n = 0; n < boxes.size(); n++) {

View File

@@ -57,9 +57,8 @@ class PostProcessor {
const float &det_db_unclip_ratio, const std::string &det_db_score_mode);
std::vector<std::vector<std::vector<int>>> FilterTagDetRes(
std::vector<std::vector<std::vector<int>>> boxes, float ratio_h,
float ratio_w,
const std::map<std::string, std::array<float, 2>> &im_info);
std::vector<std::vector<std::vector<int>>> boxes,
const std::array<int, 4>& det_img_info);
private:
static bool XsortInt(std::vector<int> a, std::vector<int> b);

View File

@@ -28,10 +28,10 @@ namespace fastdeploy {
namespace vision {
namespace ocr {
cv::Mat GetRotateCropImage(const cv::Mat& srcimage,
FASTDEPLOY_DECL cv::Mat GetRotateCropImage(const cv::Mat& srcimage,
const std::array<int, 8>& box);
void SortBoxes(OCRResult* result);
FASTDEPLOY_DECL void SortBoxes(std::vector<std::array<int, 8>>* boxes);
} // namespace ocr
} // namespace vision

View File

@@ -29,17 +29,17 @@ bool CompareBox(const std::array<int, 8>& result1,
}
}
void SortBoxes(OCRResult* result) {
std::sort(result->boxes.begin(), result->boxes.end(), CompareBox);
void SortBoxes(std::vector<std::array<int, 8>>* boxes) {
std::sort(boxes->begin(), boxes->end(), CompareBox);
if (result->boxes.size() == 0) {
if (boxes->size() == 0) {
return;
}
for (int i = 0; i < result->boxes.size() - 1; i++) {
if (abs(result->boxes[i + 1][1] - result->boxes[i][1]) < 10 &&
(result->boxes[i + 1][0] < result->boxes[i][0])) {
std::swap(result->boxes[i], result->boxes[i + 1]);
for (int i = 0; i < boxes->size() - 1; i++) {
if (abs((*boxes)[i + 1][1] - (*boxes)[i][1]) < 10 &&
((*boxes)[i + 1][0] < (*boxes)[i][0])) {
std::swap((*boxes)[i], (*boxes)[i + 1]);
}
}
}

8
python/fastdeploy/vision/ocr/__init__.py Normal file → Executable file
View File

@@ -13,10 +13,4 @@
# limitations under the License.
from __future__ import absolute_import
from .ppocr import PPOCRv3
from .ppocr import PPOCRv2
from .ppocr import PPOCRSystemv3
from .ppocr import PPOCRSystemv2
from .ppocr import DBDetector
from .ppocr import Classifier
from .ppocr import Recognizer
from .ppocr import *

172
python/fastdeploy/vision/ocr/ppocr/__init__.py Normal file → Executable file
View File

@@ -41,40 +41,11 @@ class DBDetector(FastDeployModel):
assert self.initialized, "DBDetector initialize failed."
# Property wrappers for attributes related to the DBDetector model
@property
def max_side_len(self):
return self._model.max_side_len
'''
@property
def det_db_thresh(self):
return self._model.det_db_thresh
@property
def det_db_box_thresh(self):
return self._model.det_db_box_thresh
@property
def det_db_unclip_ratio(self):
return self._model.det_db_unclip_ratio
@property
def det_db_score_mode(self):
return self._model.det_db_score_mode
@property
def use_dilation(self):
return self._model.use_dilation
@property
def is_scale(self):
return self._model.max_wh
@max_side_len.setter
def max_side_len(self, value):
assert isinstance(
value, int), "The value to set `max_side_len` must be type of int."
self._model.max_side_len = value
@det_db_thresh.setter
def det_db_thresh(self, value):
assert isinstance(
@@ -82,6 +53,10 @@ class DBDetector(FastDeployModel):
float), "The value to set `det_db_thresh` must be type of float."
self._model.det_db_thresh = value
@property
def det_db_box_thresh(self):
return self._model.det_db_box_thresh
@det_db_box_thresh.setter
def det_db_box_thresh(self, value):
assert isinstance(
@@ -89,6 +64,10 @@ class DBDetector(FastDeployModel):
), "The value to set `det_db_box_thresh` must be type of float."
self._model.det_db_box_thresh = value
@property
def det_db_unclip_ratio(self):
return self._model.det_db_unclip_ratio
@det_db_unclip_ratio.setter
def det_db_unclip_ratio(self, value):
assert isinstance(
@@ -96,6 +75,10 @@ class DBDetector(FastDeployModel):
), "The value to set `det_db_unclip_ratio` must be type of float."
self._model.det_db_unclip_ratio = value
@property
def det_db_score_mode(self):
return self._model.det_db_score_mode
@det_db_score_mode.setter
def det_db_score_mode(self, value):
assert isinstance(
@@ -103,6 +86,10 @@ class DBDetector(FastDeployModel):
str), "The value to set `det_db_score_mode` must be type of str."
self._model.det_db_score_mode = value
@property
def use_dilation(self):
return self._model.use_dilation
@use_dilation.setter
def use_dilation(self, value):
assert isinstance(
@@ -110,11 +97,26 @@ class DBDetector(FastDeployModel):
bool), "The value to set `use_dilation` must be type of bool."
self._model.use_dilation = value
@property
def max_side_len(self):
return self._model.max_side_len
@max_side_len.setter
def max_side_len(self, value):
assert isinstance(
value, int), "The value to set `max_side_len` must be type of int."
self._model.max_side_len = value
@property
def is_scale(self):
return self._model.max_wh
@is_scale.setter
def is_scale(self, value):
assert isinstance(
value, bool), "The value to set `is_scale` must be type of bool."
self._model.is_scale = value
'''
class Classifier(FastDeployModel):
@@ -139,6 +141,7 @@ class Classifier(FastDeployModel):
model_file, params_file, self._runtime_option, model_format)
assert self.initialized, "Classifier initialize failed."
'''
@property
def cls_thresh(self):
return self._model.cls_thresh
@@ -170,6 +173,7 @@ class Classifier(FastDeployModel):
value,
int), "The value to set `cls_batch_num` must be type of int."
self._model.cls_batch_num = value
'''
class Recognizer(FastDeployModel):
@@ -197,6 +201,7 @@ class Recognizer(FastDeployModel):
model_format)
assert self.initialized, "Recognizer initialize failed."
'''
@property
def rec_img_h(self):
return self._model.rec_img_h
@@ -227,6 +232,7 @@ class Recognizer(FastDeployModel):
value,
int), "The value to set `rec_batch_num` must be type of int."
self._model.rec_batch_num = value
'''
class PPOCRv3(FastDeployModel):
@@ -253,6 +259,14 @@ class PPOCRv3(FastDeployModel):
"""
return self.system.predict(input_image)
def batch_predict(self, images):
    """Predict a batch of input images.

    The call is forwarded unchanged to ``self.system.batch_predict``.

    :param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
    :return: OCRBatchResult
    """
    return self.system.batch_predict(images)
class PPOCRSystemv3(PPOCRv3):
def __init__(self, det_model=None, cls_model=None, rec_model=None):
@@ -289,6 +303,14 @@ class PPOCRv2(FastDeployModel):
"""
return self.system.predict(input_image)
def batch_predict(self, images):
    """Predict a batch of input images.

    The call is forwarded unchanged to ``self.system.batch_predict``.

    :param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
    :return: OCRBatchResult
    """
    return self.system.batch_predict(images)
class PPOCRSystemv2(PPOCRv2):
def __init__(self, det_model=None, cls_model=None, rec_model=None):
@@ -299,3 +321,93 @@ class PPOCRSystemv2(PPOCRv2):
def predict(self, input_image):
return super(PPOCRSystemv2, self).predict(input_image)
class DBDetectorPreprocessor:
    """Thin Python wrapper that forwards to ``C.vision.ocr.DBDetectorPreprocessor``."""

    def __init__(self):
        """Create a preprocessor for DBDetectorModel
        """
        self._preprocessor = C.vision.ocr.DBDetectorPreprocessor()

    def run(self, input_ims):
        """Preprocess input images for DBDetectorModel

        :param input_ims: (list of numpy.ndarray) The input image
        :return: pair(list of FDTensor, list of std::array<int, 4>)
        """
        return self._preprocessor.run(input_ims)
class DBDetectorPostprocessor:
    """Thin Python wrapper that forwards to ``C.vision.ocr.DBDetectorPostprocessor``."""

    def __init__(self):
        """Create a postprocessor for DBDetectorModel
        """
        self._postprocessor = C.vision.ocr.DBDetectorPostprocessor()

    def run(self, runtime_results, batch_det_img_info):
        """Postprocess the runtime results for DBDetectorModel

        :param runtime_results: (list of FDTensor or list of pyArray) The output FDTensor results from runtime
        :param batch_det_img_info: (list of std::array<int, 4>) The output of det_preprocessor
        :return: list of Result (if runtime_results was predicted from batched samples, the length of this list equals the batch size)
        """
        return self._postprocessor.run(runtime_results, batch_det_img_info)
class RecognizerPreprocessor:
    """Thin Python wrapper that forwards to ``C.vision.ocr.RecognizerPreprocessor``."""

    def __init__(self):
        """Create a preprocessor for RecognizerModel
        """
        self._preprocessor = C.vision.ocr.RecognizerPreprocessor()

    def run(self, input_ims):
        """Preprocess input images for RecognizerModel

        :param input_ims: (list of numpy.ndarray) The input image
        :return: list of FDTensor
        """
        return self._preprocessor.run(input_ims)
class RecognizerPostprocessor:
    """Thin Python wrapper that forwards to ``C.vision.ocr.RecognizerPostprocessor``."""

    def __init__(self, label_path):
        """Create a postprocessor for RecognizerModel

        :param label_path: (str) Path of label file
        """
        self._postprocessor = C.vision.ocr.RecognizerPostprocessor(label_path)

    def run(self, runtime_results):
        """Postprocess the runtime results for RecognizerModel

        :param runtime_results: (list of FDTensor or list of pyArray) The output FDTensor results from runtime
        :return: list of Result (if runtime_results was predicted from batched samples, the length of this list equals the batch size)
        """
        return self._postprocessor.run(runtime_results)
class ClassifierPreprocessor:
    """Thin Python wrapper that forwards to ``C.vision.ocr.ClassifierPreprocessor``."""

    def __init__(self):
        """Create a preprocessor for ClassifierModel
        """
        self._preprocessor = C.vision.ocr.ClassifierPreprocessor()

    def run(self, input_ims):
        """Preprocess input images for ClassifierModel

        :param input_ims: (list of numpy.ndarray) The input image
        :return: list of FDTensor
        """
        return self._preprocessor.run(input_ims)
class ClassifierPostprocessor:
    """Thin Python wrapper that forwards to ``C.vision.ocr.ClassifierPostprocessor``."""

    def __init__(self):
        """Create a postprocessor for ClassifierModel
        """
        self._postprocessor = C.vision.ocr.ClassifierPostprocessor()

    def run(self, runtime_results):
        """Postprocess the runtime results for ClassifierModel

        :param runtime_results: (list of FDTensor or list of pyArray) The output FDTensor results from runtime
        :return: list of Result (if runtime_results was predicted from batched samples, the length of this list equals the batch size)
        """
        return self._postprocessor.run(runtime_results)
def sort_boxes(boxes):
    """Forward ``boxes`` to the C++ binding ``C.vision.ocr.sort_boxes`` and return its result.

    :param boxes: the boxes to sort; passed through to the binding unchanged
        (NOTE(review): presumably a list of 8-int boxes, matching the C++
        ``SortBoxes(std::vector<std::array<int, 8>>*)`` -- confirm against the pybind definition)
    :return: whatever ``C.vision.ocr.sort_boxes`` returns
    """
    return C.vision.ocr.sort_boxes(boxes)