[Model] change ocr pre and post (#568)

* change ocr pre and post

* add pybind

* change ocr

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix copy bug

* fix code style

* fix bug

* add new function

* fix windows ci bug
This commit is contained in:
Thomas Young
2022-11-18 13:17:42 +08:00
committed by GitHub
parent 1609ce1bab
commit 143506b654
31 changed files with 1402 additions and 569 deletions

View File

@@ -22,6 +22,7 @@
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include <numeric>
#if defined(_WIN32) #if defined(_WIN32)
#ifdef FASTDEPLOY_LIB #ifdef FASTDEPLOY_LIB

View File

@@ -48,6 +48,7 @@
#include "fastdeploy/vision/matting/ppmatting/ppmatting.h" #include "fastdeploy/vision/matting/ppmatting/ppmatting.h"
#include "fastdeploy/vision/ocr/ppocr/classifier.h" #include "fastdeploy/vision/ocr/ppocr/classifier.h"
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h" #include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
#include "fastdeploy/vision/ocr/ppocr/ppocr_v2.h" #include "fastdeploy/vision/ocr/ppocr/ppocr_v2.h"
#include "fastdeploy/vision/ocr/ppocr/ppocr_v3.h" #include "fastdeploy/vision/ocr/ppocr/ppocr_v3.h"
#include "fastdeploy/vision/ocr/ppocr/recognizer.h" #include "fastdeploy/vision/ocr/ppocr/recognizer.h"

93
fastdeploy/vision/ocr/ppocr/classifier.cc Normal file → Executable file
View File

@@ -41,16 +41,7 @@ Classifier::Classifier(const std::string& model_file,
initialized = Initialize(); initialized = Initialize();
} }
// Init
bool Classifier::Initialize() { bool Classifier::Initialize() {
// pre&post process parameters
cls_thresh = 0.9;
cls_image_shape = {3, 48, 192};
cls_batch_num = 1;
mean = {0.5f, 0.5f, 0.5f};
scale = {0.5f, 0.5f, 0.5f};
is_scale = true;
if (!InitRuntime()) { if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl; FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false; return false;
@@ -59,85 +50,23 @@ bool Classifier::Initialize() {
return true; return true;
} }
void OcrClassifierResizeImage(Mat* mat, bool Classifier::BatchPredict(const std::vector<cv::Mat>& images,
const std::vector<int>& rec_image_shape) { std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores) {
int imgC = rec_image_shape[0]; std::vector<FDMat> fd_images = WrapMat(images);
int imgH = rec_image_shape[1]; if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
int imgW = rec_image_shape[2]; FDERROR << "Failed to preprocess the input image." << std::endl;
return false;
float ratio = float(mat->Width()) / float(mat->Height());
int resize_w;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
Resize::Run(mat, resize_w, imgH);
std::vector<float> value = {0, 0, 0};
if (resize_w < imgW) {
Pad::Run(mat, 0, 0, 0, imgW - resize_w, value);
} }
} reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
bool Classifier::Preprocess(Mat* mat, FDTensor* output) { FDERROR << "Failed to inference by runtime." << std::endl;
// 1. cls resizes
// 2. normalize
// 3. batch_permute
OcrClassifierResizeImage(mat, cls_image_shape);
Normalize::Run(mat, mean, scale, true);
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1);
return true;
}
bool Classifier::Postprocess(FDTensor& infer_result,
std::tuple<int, float>* cls_result) {
std::vector<int64_t> output_shape = infer_result.shape;
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
float* out_data = static_cast<float*>(infer_result.Data());
int label = std::distance(
&out_data[0], std::max_element(&out_data[0], &out_data[output_shape[1]]));
float score =
float(*std::max_element(&out_data[0], &out_data[output_shape[1]]));
std::get<0>(*cls_result) = label;
std::get<1>(*cls_result) = score;
return true;
}
bool Classifier::Predict(cv::Mat* img, std::tuple<int, float>* cls_result) {
Mat mat(*img);
std::vector<FDTensor> input_tensors(1);
if (!Preprocess(&mat, &input_tensors[0])) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false; return false;
} }
input_tensors[0].name = InputInfoOfRuntime(0).name; if (!postprocessor_.Run(reused_output_tensors_, cls_labels, cls_scores)) {
std::vector<FDTensor> output_tensors; FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
return false; return false;
} }
if (!Postprocess(output_tensors[0], cls_result)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}
return true; return true;
} }

27
fastdeploy/vision/ocr/ppocr/classifier.h Normal file → Executable file
View File

@@ -17,6 +17,8 @@
#include "fastdeploy/vision/common/processors/transform.h" #include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h" #include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h" #include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
#include "fastdeploy/vision/ocr/ppocr/cls_postprocessor.h"
#include "fastdeploy/vision/ocr/ppocr/cls_preprocessor.h"
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
@@ -41,29 +43,22 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel {
const ModelFormat& model_format = ModelFormat::PADDLE); const ModelFormat& model_format = ModelFormat::PADDLE);
/// Get model's name /// Get model's name
std::string ModelName() const { return "ppocr/ocr_cls"; } std::string ModelName() const { return "ppocr/ocr_cls"; }
/** \brief Predict the input image and get OCR classification model result.
/** \brief BatchPredict the input image and get OCR classification model cls_result.
* *
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. * \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
* \param[in] result The output of OCR classification model result will be writen to this structure. * \param[in] cls_results The output of OCR classification model cls_result will be writen to this structure.
* \return true if the prediction is successed, otherwise false. * \return true if the prediction is successed, otherwise false.
*/ */
virtual bool Predict(cv::Mat* img, std::tuple<int, float>* result); virtual bool BatchPredict(const std::vector<cv::Mat>& images,
std::vector<int32_t>* cls_labels,
std::vector<float>* cls_scores);
// Pre & Post parameters ClassifierPreprocessor preprocessor_;
float cls_thresh; ClassifierPostprocessor postprocessor_;
std::vector<int> cls_image_shape;
int cls_batch_num;
std::vector<float> mean;
std::vector<float> scale;
bool is_scale;
private: private:
bool Initialize(); bool Initialize();
/// Preprocess the input data, and set the preprocessed results to `outputs`
bool Preprocess(Mat* img, FDTensor* output);
/// Postprocess the inferenced results, and set the final result to `result`
bool Postprocess(FDTensor& infer_result, std::tuple<int, float>* result);
}; };
} // namespace ocr } // namespace ocr

View File

@@ -0,0 +1,65 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/cls_postprocessor.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
ClassifierPostprocessor::ClassifierPostprocessor() {
initialized_ = true;
}
bool SingleBatchPostprocessor(const float* out_data, const size_t& length, int* cls_label, float* cls_score) {
*cls_label = std::distance(
&out_data[0], std::max_element(&out_data[0], &out_data[length]));
*cls_score =
float(*std::max_element(&out_data[0], &out_data[length]));
return true;
}
bool ClassifierPostprocessor::Run(const std::vector<FDTensor>& tensors,
std::vector<int32_t>* cls_labels,
std::vector<float>* cls_scores) {
if (!initialized_) {
FDERROR << "Postprocessor is not initialized." << std::endl;
return false;
}
// Classifier have only 1 output tensor.
const FDTensor& tensor = tensors[0];
// For Classifier, the output tensor shape = [batch,2]
size_t batch = tensor.shape[0];
size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies<int>());
cls_labels->resize(batch);
cls_scores->resize(batch);
const float* tensor_data = reinterpret_cast<const float*>(tensor.Data());
for (int i_batch = 0; i_batch < batch; ++i_batch) {
if(!SingleBatchPostprocessor(tensor_data, length, &cls_labels->at(i_batch),&cls_scores->at(i_batch))) return false;
tensor_data = tensor_data + length;
}
return true;
}
} // namespace classification
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,51 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief Postprocessor object for Classifier serials model.
*/
class FASTDEPLOY_DECL ClassifierPostprocessor {
public:
/** \brief Create a postprocessor instance for Classifier serials model
*
*/
ClassifierPostprocessor();
/** \brief Process the result of runtime and fill to ClassifyResult structure
*
* \param[in] tensors The inference result from runtime
* \param[in] cls_labels The output result of classification
* \param[in] cls_scores The output result of classification
* \return true if the postprocess successed, otherwise false
*/
bool Run(const std::vector<FDTensor>& tensors,
std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores);
float cls_thresh_ = 0.9;
private:
bool initialized_ = false;
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,88 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/cls_preprocessor.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
#include "fastdeploy/function/concat.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
ClassifierPreprocessor::ClassifierPreprocessor() {
initialized_ = true;
}
void OcrClassifierResizeImage(FDMat* mat,
const std::vector<int>& cls_image_shape) {
int imgC = cls_image_shape[0];
int imgH = cls_image_shape[1];
int imgW = cls_image_shape[2];
float ratio = float(mat->Width()) / float(mat->Height());
int resize_w;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
Resize::Run(mat, resize_w, imgH);
std::vector<float> value = {0, 0, 0};
if (resize_w < imgW) {
Pad::Run(mat, 0, 0, 0, imgW - resize_w, value);
}
}
bool ClassifierPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
return false;
}
if (images->size() == 0) {
FDERROR << "The size of input images should be greater than 0." << std::endl;
return false;
}
for (size_t i = 0; i < images->size(); ++i) {
FDMat* mat = &(images->at(i));
OcrClassifierResizeImage(mat, cls_image_shape_);
NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
/*
Normalize::Run(mat, mean_, scale_, is_scale_);
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
*/
}
// Only have 1 output Tensor.
outputs->resize(1);
// Concat all the preprocessed data to a batch tensor
std::vector<FDTensor> tensors(images->size());
for (size_t i = 0; i < images->size(); ++i) {
(*images)[i].ShareWithTensor(&(tensors[i]));
tensors[i].ExpandDim(0);
}
if (tensors.size() == 1) {
(*outputs)[0] = std::move(tensors[0]);
} else {
function::Concat(tensors, &((*outputs)[0]), 0);
}
return true;
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,51 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief Preprocessor object for Classifier serials model.
*/
class FASTDEPLOY_DECL ClassifierPreprocessor {
public:
/** \brief Create a preprocessor instance for Classifier serials model
*
*/
ClassifierPreprocessor();
/** \brief Process the input image and prepare input tensors for runtime
*
* \param[in] images The input image data list, all the elements are returned by cv::imread()
* \param[in] outputs The output tensors which will feed in runtime
* \return true if the preprocess successed, otherwise false
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
bool is_scale_ = true;
std::vector<int> cls_image_shape_ = {3, 48, 192};
private:
bool initialized_ = false;
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

156
fastdeploy/vision/ocr/ppocr/dbdetector.cc Normal file → Executable file
View File

@@ -43,158 +43,50 @@ DBDetector::DBDetector(const std::string& model_file,
// Init // Init
bool DBDetector::Initialize() { bool DBDetector::Initialize() {
// pre&post process parameters
max_side_len = 960;
det_db_thresh = 0.3;
det_db_box_thresh = 0.6;
det_db_unclip_ratio = 1.5;
det_db_score_mode = "slow";
use_dilation = false;
mean = {0.485f, 0.456f, 0.406f};
scale = {0.229f, 0.224f, 0.225f};
is_scale = true;
if (!InitRuntime()) { if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl; FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false; return false;
} }
return true;
}
void OcrDetectorResizeImage(Mat* img, int max_size_len, float* ratio_h,
float* ratio_w) {
int w = img->Width();
int h = img->Height();
float ratio = 1.f;
int max_wh = w >= h ? w : h;
if (max_wh > max_size_len) {
if (h > w) {
ratio = float(max_size_len) / float(h);
} else {
ratio = float(max_size_len) / float(w);
}
}
int resize_h = int(float(h) * ratio);
int resize_w = int(float(w) * ratio);
resize_h = std::max(int(std::round(float(resize_h) / 32) * 32), 32);
resize_w = std::max(int(std::round(float(resize_w) / 32) * 32), 32);
Resize::Run(img, resize_w, resize_h);
*ratio_h = float(resize_h) / float(h);
*ratio_w = float(resize_w) / float(w);
}
bool DBDetector::Preprocess(
Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info) {
// Resize
OcrDetectorResizeImage(mat, max_side_len, &ratio_h, &ratio_w);
// Normalize
Normalize::Run(mat, mean, scale, true);
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};
//-CHW
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1);
return true;
}
bool DBDetector::Postprocess(
FDTensor& infer_result, std::vector<std::array<int, 8>>* boxes_result,
const std::map<std::string, std::array<float, 2>>& im_info) {
std::vector<int64_t> output_shape = infer_result.shape;
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
int n2 = output_shape[2];
int n3 = output_shape[3];
int n = n2 * n3;
float* out_data = static_cast<float*>(infer_result.Data());
// prepare bitmap
std::vector<float> pred(n, 0.0);
std::vector<unsigned char> cbuf(n, ' ');
for (int i = 0; i < n; i++) {
pred[i] = float(out_data[i]);
cbuf[i] = (unsigned char)((out_data[i]) * 255);
}
cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char*)cbuf.data());
cv::Mat pred_map(n2, n3, CV_32F, (float*)pred.data());
const double threshold = det_db_thresh * 255;
const double maxvalue = 255;
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
if (use_dilation) {
cv::Mat dila_ele =
cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
cv::dilate(bit_map, bit_map, dila_ele);
}
std::vector<std::vector<std::vector<int>>> boxes;
boxes =
post_processor_.BoxesFromBitmap(pred_map, bit_map, det_db_box_thresh,
det_db_unclip_ratio, det_db_score_mode);
boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, im_info);
// boxes to boxes_result
for (int i = 0; i < boxes.size(); i++) {
std::array<int, 8> new_box;
int k = 0;
for (auto& vec : boxes[i]) {
for (auto& e : vec) {
new_box[k++] = e;
}
}
boxes_result->push_back(new_box);
}
return true; return true;
} }
bool DBDetector::Predict(cv::Mat* img, bool DBDetector::Predict(cv::Mat* img,
std::vector<std::array<int, 8>>* boxes_result) { std::vector<std::array<int, 8>>* boxes_result) {
Mat mat(*img); if (!Predict(*img, boxes_result)) {
return false;
}
return true;
}
std::vector<FDTensor> input_tensors(1); bool DBDetector::Predict(const cv::Mat& img,
std::vector<std::array<int, 8>>* boxes_result) {
std::vector<std::vector<std::array<int, 8>>> det_results;
if (!BatchPredict({img}, &det_results)) {
return false;
}
*boxes_result = std::move(det_results[0]);
return true;
}
std::map<std::string, std::array<float, 2>> im_info; bool DBDetector::BatchPredict(const std::vector<cv::Mat>& images,
std::vector<std::vector<std::array<int, 8>>>* det_results) {
// Record the shape of image and the shape of preprocessed image std::vector<FDMat> fd_images = WrapMat(images);
im_info["input_shape"] = {static_cast<float>(mat.Height()), std::vector<std::array<int, 4>> batch_det_img_info;
static_cast<float>(mat.Width())}; if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &batch_det_img_info)) {
im_info["output_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
FDERROR << "Failed to preprocess input image." << std::endl; FDERROR << "Failed to preprocess input image." << std::endl;
return false; return false;
} }
input_tensors[0].name = InputInfoOfRuntime(0).name; reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors; if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
if (!Infer(input_tensors, &output_tensors)) { FDERROR << "Failed to inference by runtime." << std::endl;
FDERROR << "Failed to inference." << std::endl;
return false; return false;
} }
if (!Postprocess(output_tensors[0], boxes_result, im_info)) { if (!postprocessor_.Run(reused_output_tensors_, det_results, batch_det_img_info)) {
FDERROR << "Failed to post process." << std::endl; FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
return false; return false;
} }
return true; return true;
} }

48
fastdeploy/vision/ocr/ppocr/dbdetector.h Normal file → Executable file
View File

@@ -17,6 +17,8 @@
#include "fastdeploy/vision/common/processors/transform.h" #include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h" #include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h" #include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
#include "fastdeploy/vision/ocr/ppocr/det_postprocessor.h"
#include "fastdeploy/vision/ocr/ppocr/det_preprocessor.h"
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
@@ -44,40 +46,34 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel {
std::string ModelName() const { return "ppocr/ocr_det"; } std::string ModelName() const { return "ppocr/ocr_det"; }
/** \brief Predict the input image and get OCR detection model result. /** \brief Predict the input image and get OCR detection model result.
* *
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
* \param[in] boxes_result The output of OCR detection model result will be writen to this structure. * \param[in] boxes_result The output of OCR detection model result will be writen to this structure.
* \return true if the prediction is successed, otherwise false. * \return true if the prediction is successed, otherwise false.
*/ */
virtual bool Predict(cv::Mat* im, virtual bool Predict(cv::Mat* img,
std::vector<std::array<int, 8>>* boxes_result); std::vector<std::array<int, 8>>* boxes_result);
/** \brief Predict the input image and get OCR detection model result.
*
* \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
* \param[in] boxes_result The output of OCR detection model result will be writen to this structure.
* \return true if the prediction is successed, otherwise false.
*/
virtual bool Predict(const cv::Mat& img,
std::vector<std::array<int, 8>>* boxes_result);
/** \brief BatchPredict the input image and get OCR detection model result.
*
* \param[in] images The list input of image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
* \param[in] det_results The output of OCR detection model result will be writen to this structure.
* \return true if the prediction is successed, otherwise false.
*/
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
std::vector<std::vector<std::array<int, 8>>>* det_results);
// Pre & Post process parameters DBDetectorPreprocessor preprocessor_;
int max_side_len; DBDetectorPostprocessor postprocessor_;
float ratio_h{};
float ratio_w{};
double det_db_thresh;
double det_db_box_thresh;
double det_db_unclip_ratio;
std::string det_db_score_mode;
bool use_dilation;
std::vector<float> mean;
std::vector<float> scale;
bool is_scale;
private: private:
bool Initialize(); bool Initialize();
/// Preprocess the input data, and set the preprocessed results to `outputs`
bool Preprocess(Mat* mat, FDTensor* outputs,
std::map<std::string, std::array<float, 2>>* im_info);
/*! @brief Postprocess the inferenced results, and set the final result to `boxes_result`
*/
bool Postprocess(FDTensor& infer_result,
std::vector<std::array<int, 8>>* boxes_result,
const std::map<std::string, std::array<float, 2>>& im_info);
PostProcessor post_processor_;
}; };
} // namespace ocr } // namespace ocr

View File

@@ -0,0 +1,110 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/det_postprocessor.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
DBDetectorPostprocessor::DBDetectorPostprocessor() {
initialized_ = true;
}
bool DBDetectorPostprocessor::SingleBatchPostprocessor(
const float* out_data,
int n2,
int n3,
const std::array<int,4>& det_img_info,
std::vector<std::array<int, 8>>* boxes_result
) {
int n = n2 * n3;
// prepare bitmap
std::vector<float> pred(n, 0.0);
std::vector<unsigned char> cbuf(n, ' ');
for (int i = 0; i < n; i++) {
pred[i] = float(out_data[i]);
cbuf[i] = (unsigned char)((out_data[i]) * 255);
}
cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char*)cbuf.data());
cv::Mat pred_map(n2, n3, CV_32F, (float*)pred.data());
const double threshold = det_db_thresh_ * 255;
const double maxvalue = 255;
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
if (use_dilation_) {
cv::Mat dila_ele =
cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
cv::dilate(bit_map, bit_map, dila_ele);
}
std::vector<std::vector<std::vector<int>>> boxes;
boxes =
post_processor_.BoxesFromBitmap(pred_map, bit_map, det_db_box_thresh_,
det_db_unclip_ratio_, det_db_score_mode_);
boxes = post_processor_.FilterTagDetRes(boxes, det_img_info);
// boxes to boxes_result
for (int i = 0; i < boxes.size(); i++) {
std::array<int, 8> new_box;
int k = 0;
for (auto& vec : boxes[i]) {
for (auto& e : vec) {
new_box[k++] = e;
}
}
boxes_result->push_back(new_box);
}
return true;
}
bool DBDetectorPostprocessor::Run(const std::vector<FDTensor>& tensors,
std::vector<std::vector<std::array<int, 8>>>* results,
const std::vector<std::array<int,4>>& batch_det_img_info) {
if (!initialized_) {
FDERROR << "Postprocessor is not initialized." << std::endl;
return false;
}
// DBDetector have only 1 output tensor.
const FDTensor& tensor = tensors[0];
// For DBDetector, the output tensor shape = [batch, 1, ?, ?]
size_t batch = tensor.shape[0];
size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies<int>());
const float* tensor_data = reinterpret_cast<const float*>(tensor.Data());
results->resize(batch);
for (int i_batch = 0; i_batch < batch; ++i_batch) {
if(!SingleBatchPostprocessor(tensor_data,
tensor.shape[2],
tensor.shape[3],
batch_det_img_info[i_batch],
&results->at(i_batch)
))return false;
tensor_data = tensor_data + length;
}
return true;
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,62 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief Postprocessor object for DBDetector serials model.
*/
class FASTDEPLOY_DECL DBDetectorPostprocessor {
public:
/** \brief Create a postprocessor instance for DBDetector serials model
*
*/
DBDetectorPostprocessor();
/** \brief Process the result of runtime and fill to results structure
*
* \param[in] tensors The inference result from runtime
* \param[in] results The output result of detector
* \param[in] batch_det_img_info The detector_preprocess result
* \return true if the postprocess successed, otherwise false
*/
bool Run(const std::vector<FDTensor>& tensors,
std::vector<std::vector<std::array<int, 8>>>* results,
const std::vector<std::array<int, 4>>& batch_det_img_info);
double det_db_thresh_ = 0.3;
double det_db_box_thresh_ = 0.6;
double det_db_unclip_ratio_ = 1.5;
std::string det_db_score_mode_ = "slow";
bool use_dilation_ = false;
private:
bool initialized_ = false;
PostProcessor post_processor_;
bool SingleBatchPostprocessor(const float* out_data,
int n2,
int n3,
const std::array<int, 4>& det_img_info,
std::vector<std::array<int, 8>>* boxes_result);
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,113 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/det_preprocessor.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
#include "fastdeploy/function/concat.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
DBDetectorPreprocessor::DBDetectorPreprocessor() {
initialized_ = true;
}
std::array<int, 4> OcrDetectorGetInfo(FDMat* img, int max_size_len) {
int w = img->Width();
int h = img->Height();
float ratio = 1.f;
int max_wh = w >= h ? w : h;
if (max_wh > max_size_len) {
if (h > w) {
ratio = float(max_size_len) / float(h);
} else {
ratio = float(max_size_len) / float(w);
}
}
int resize_h = int(float(h) * ratio);
int resize_w = int(float(w) * ratio);
resize_h = std::max(int(std::round(float(resize_h) / 32) * 32), 32);
resize_w = std::max(int(std::round(float(resize_w) / 32) * 32), 32);
return {w,h,resize_w,resize_h};
/*
*ratio_h = float(resize_h) / float(h);
*ratio_w = float(resize_w) / float(w);
*/
}
bool OcrDetectorResizeImage(FDMat* img,
int resize_w,
int resize_h,
int max_resize_w,
int max_resize_h) {
Resize::Run(img, resize_w, resize_h);
std::vector<float> value = {0, 0, 0};
Pad::Run(img, 0, max_resize_h-resize_h, 0, max_resize_w - resize_w, value);
return true;
}
bool DBDetectorPreprocessor::Run(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs,
std::vector<std::array<int, 4>>* batch_det_img_info_ptr) {
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
return false;
}
if (images->size() == 0) {
FDERROR << "The size of input images should be greater than 0." << std::endl;
return false;
}
int max_resize_w = 0;
int max_resize_h = 0;
std::vector<std::array<int, 4>>& batch_det_img_info = *batch_det_img_info_ptr;
batch_det_img_info.clear();
batch_det_img_info.resize(images->size());
for (size_t i = 0; i < images->size(); ++i) {
FDMat* mat = &(images->at(i));
batch_det_img_info[i] = OcrDetectorGetInfo(mat,max_side_len_);
max_resize_w = std::max(max_resize_w,batch_det_img_info[i][2]);
max_resize_h = std::max(max_resize_h,batch_det_img_info[i][3]);
}
for (size_t i = 0; i < images->size(); ++i) {
FDMat* mat = &(images->at(i));
OcrDetectorResizeImage(mat, batch_det_img_info[i][2],batch_det_img_info[i][3],max_resize_w,max_resize_h);
NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
/*
Normalize::Run(mat, mean_, scale_, is_scale_);
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
*/
}
// Only have 1 output Tensor.
outputs->resize(1);
// Concat all the preprocessed data to a batch tensor
std::vector<FDTensor> tensors(images->size());
for (size_t i = 0; i < images->size(); ++i) {
(*images)[i].ShareWithTensor(&(tensors[i]));
tensors[i].ExpandDim(0);
}
if (tensors.size() == 1) {
(*outputs)[0] = std::move(tensors[0]);
} else {
function::Concat(tensors, &((*outputs)[0]), 0);
}
return true;
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,54 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief Preprocessor object for DBDetector serials model.
*/
class FASTDEPLOY_DECL DBDetectorPreprocessor {
public:
/** \brief Create a preprocessor instance for DBDetector serials model
*
*/
DBDetectorPreprocessor();
/** \brief Process the input image and prepare input tensors for runtime
*
* \param[in] images The input image data list, all the elements are returned by cv::imread()
* \param[in] outputs The output tensors which will feed in runtime
* \param[in] batch_det_img_info_ptr The output of preprocess
* \return true if the preprocess successed, otherwise false
*/
bool Run(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs,
std::vector<std::array<int, 4>>* batch_det_img_info_ptr);
int max_side_len_ = 960;
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
std::vector<float> scale_ = {0.229f, 0.224f, 0.225f};
bool is_scale_ = true;
private:
bool initialized_ = false;
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

166
fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc Normal file → Executable file
View File

@@ -16,45 +16,171 @@
namespace fastdeploy { namespace fastdeploy {
void BindPPOCRModel(pybind11::module& m) { void BindPPOCRModel(pybind11::module& m) {
m.def("sort_boxes", [](std::vector<std::array<int, 8>>& boxes) {
vision::ocr::SortBoxes(&boxes);
return boxes;
});
// DBDetector // DBDetector
pybind11::class_<vision::ocr::DBDetector, FastDeployModel>(m, "DBDetector") pybind11::class_<vision::ocr::DBDetector, FastDeployModel>(m, "DBDetector")
.def(pybind11::init<std::string, std::string, RuntimeOption, .def(pybind11::init<std::string, std::string, RuntimeOption,
ModelFormat>()) ModelFormat>())
.def(pybind11::init<>()) .def(pybind11::init<>())
.def_readwrite("preprocessor", &vision::ocr::DBDetector::preprocessor_)
.def_readwrite("postprocessor", &vision::ocr::DBDetector::postprocessor_);
.def_readwrite("max_side_len", &vision::ocr::DBDetector::max_side_len) pybind11::class_<vision::ocr::DBDetectorPreprocessor>(m, "DBDetectorPreprocessor")
.def_readwrite("det_db_thresh", &vision::ocr::DBDetector::det_db_thresh) .def(pybind11::init<>())
.def_readwrite("det_db_box_thresh", .def_readwrite("max_side_len", &vision::ocr::DBDetectorPreprocessor::max_side_len_)
&vision::ocr::DBDetector::det_db_box_thresh) .def_readwrite("mean", &vision::ocr::DBDetectorPreprocessor::mean_)
.def_readwrite("det_db_unclip_ratio", .def_readwrite("scale", &vision::ocr::DBDetectorPreprocessor::scale_)
&vision::ocr::DBDetector::det_db_unclip_ratio) .def_readwrite("is_scale", &vision::ocr::DBDetectorPreprocessor::is_scale_)
.def_readwrite("det_db_score_mode", .def("run", [](vision::ocr::DBDetectorPreprocessor& self, std::vector<pybind11::array>& im_list) {
&vision::ocr::DBDetector::det_db_score_mode) std::vector<vision::FDMat> images;
.def_readwrite("use_dilation", &vision::ocr::DBDetector::use_dilation) for (size_t i = 0; i < im_list.size(); ++i) {
.def_readwrite("mean", &vision::ocr::DBDetector::mean) images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
.def_readwrite("scale", &vision::ocr::DBDetector::scale) }
.def_readwrite("is_scale", &vision::ocr::DBDetector::is_scale); std::vector<FDTensor> outputs;
std::vector<std::array<int, 4>> batch_det_img_info;
self.Run(&images, &outputs, &batch_det_img_info);
for(size_t i = 0; i< outputs.size(); ++i){
outputs[i].StopSharing();
}
return make_pair(outputs, batch_det_img_info);
});
pybind11::class_<vision::ocr::DBDetectorPostprocessor>(m, "DBDetectorPostprocessor")
.def(pybind11::init<>())
.def_readwrite("det_db_thresh", &vision::ocr::DBDetectorPostprocessor::det_db_thresh_)
.def_readwrite("det_db_box_thresh", &vision::ocr::DBDetectorPostprocessor::det_db_box_thresh_)
.def_readwrite("det_db_unclip_ratio", &vision::ocr::DBDetectorPostprocessor::det_db_unclip_ratio_)
.def_readwrite("det_db_score_mode", &vision::ocr::DBDetectorPostprocessor::det_db_score_mode_)
.def_readwrite("use_dilation", &vision::ocr::DBDetectorPostprocessor::use_dilation_)
.def("run", [](vision::ocr::DBDetectorPostprocessor& self,
std::vector<FDTensor>& inputs,
const std::vector<std::array<int, 4>>& batch_det_img_info) {
std::vector<std::vector<std::array<int, 8>>> results;
if (!self.Run(inputs, &results, batch_det_img_info)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in DBDetectorPostprocessor.')");
}
return results;
})
.def("run", [](vision::ocr::DBDetectorPostprocessor& self,
std::vector<pybind11::array>& input_array,
const std::vector<std::array<int, 4>>& batch_det_img_info) {
std::vector<std::vector<std::array<int, 8>>> results;
std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
if (!self.Run(inputs, &results, batch_det_img_info)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in DBDetectorPostprocessor.')");
}
return results;
});
// Classifier // Classifier
pybind11::class_<vision::ocr::Classifier, FastDeployModel>(m, "Classifier") pybind11::class_<vision::ocr::Classifier, FastDeployModel>(m, "Classifier")
.def(pybind11::init<std::string, std::string, RuntimeOption, .def(pybind11::init<std::string, std::string, RuntimeOption,
ModelFormat>()) ModelFormat>())
.def(pybind11::init<>()) .def(pybind11::init<>())
.def_readwrite("preprocessor", &vision::ocr::Classifier::preprocessor_)
.def_readwrite("postprocessor", &vision::ocr::Classifier::postprocessor_);
pybind11::class_<vision::ocr::ClassifierPreprocessor>(m, "ClassifierPreprocessor")
.def(pybind11::init<>())
.def_readwrite("cls_image_shape", &vision::ocr::ClassifierPreprocessor::cls_image_shape_)
.def_readwrite("mean", &vision::ocr::ClassifierPreprocessor::mean_)
.def_readwrite("scale", &vision::ocr::ClassifierPreprocessor::scale_)
.def_readwrite("is_scale", &vision::ocr::ClassifierPreprocessor::is_scale_)
.def("run", [](vision::ocr::ClassifierPreprocessor& self, std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPreprocessor.')");
}
for(size_t i = 0; i< outputs.size(); ++i){
outputs[i].StopSharing();
}
return outputs;
});
pybind11::class_<vision::ocr::ClassifierPostprocessor>(m, "ClassifierPostprocessor")
.def(pybind11::init<>())
.def_readwrite("cls_thresh", &vision::ocr::ClassifierPostprocessor::cls_thresh_)
.def("run", [](vision::ocr::ClassifierPostprocessor& self,
std::vector<FDTensor>& inputs) {
std::vector<int> cls_labels;
std::vector<float> cls_scores;
if (!self.Run(inputs, &cls_labels, &cls_scores)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPostprocessor.')");
}
return make_pair(cls_labels,cls_scores);
})
.def("run", [](vision::ocr::ClassifierPostprocessor& self,
std::vector<pybind11::array>& input_array) {
std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
std::vector<int> cls_labels;
std::vector<float> cls_scores;
if (!self.Run(inputs, &cls_labels, &cls_scores)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPostprocessor.')");
}
return make_pair(cls_labels,cls_scores);
});
.def_readwrite("cls_thresh", &vision::ocr::Classifier::cls_thresh)
.def_readwrite("cls_image_shape",
&vision::ocr::Classifier::cls_image_shape)
.def_readwrite("cls_batch_num", &vision::ocr::Classifier::cls_batch_num);
// Recognizer // Recognizer
pybind11::class_<vision::ocr::Recognizer, FastDeployModel>(m, "Recognizer") pybind11::class_<vision::ocr::Recognizer, FastDeployModel>(m, "Recognizer")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption, .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>()) ModelFormat>())
.def(pybind11::init<>()) .def(pybind11::init<>())
.def_readwrite("preprocessor", &vision::ocr::Recognizer::preprocessor_)
.def_readwrite("postprocessor", &vision::ocr::Recognizer::postprocessor_);
.def_readwrite("rec_img_h", &vision::ocr::Recognizer::rec_img_h) pybind11::class_<vision::ocr::RecognizerPreprocessor>(m, "RecognizerPreprocessor")
.def_readwrite("rec_img_w", &vision::ocr::Recognizer::rec_img_w) .def(pybind11::init<>())
.def_readwrite("rec_batch_num", &vision::ocr::Recognizer::rec_batch_num); .def_readwrite("rec_image_shape", &vision::ocr::RecognizerPreprocessor::rec_image_shape_)
.def_readwrite("mean", &vision::ocr::RecognizerPreprocessor::mean_)
.def_readwrite("scale", &vision::ocr::RecognizerPreprocessor::scale_)
.def_readwrite("is_scale", &vision::ocr::RecognizerPreprocessor::is_scale_)
.def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPreprocessor.')");
}
for(size_t i = 0; i< outputs.size(); ++i){
outputs[i].StopSharing();
}
return outputs;
});
pybind11::class_<vision::ocr::RecognizerPostprocessor>(m, "RecognizerPostprocessor")
.def(pybind11::init<std::string>())
.def("run", [](vision::ocr::RecognizerPostprocessor& self,
std::vector<FDTensor>& inputs) {
std::vector<std::string> texts;
std::vector<float> rec_scores;
if (!self.Run(inputs, &texts, &rec_scores)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPostprocessor.')");
}
return make_pair(texts, rec_scores);
})
.def("run", [](vision::ocr::RecognizerPostprocessor& self,
std::vector<pybind11::array>& input_array) {
std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
std::vector<std::string> texts;
std::vector<float> rec_scores;
if (!self.Run(inputs, &texts, &rec_scores)) {
pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPostprocessor.')");
}
return make_pair(texts, rec_scores);
});
} }
} // namespace fastdeploy } // namespace fastdeploy

18
fastdeploy/vision/ocr/ppocr/ppocr_pybind.cc Normal file → Executable file
View File

@@ -31,6 +31,15 @@ void BindPPOCRv3(pybind11::module& m) {
vision::OCRResult res; vision::OCRResult res;
self.Predict(&mat, &res); self.Predict(&mat, &res);
return res; return res;
})
.def("batch_predict", [](pipeline::PPOCRv3& self, std::vector<pybind11::array>& data) {
std::vector<cv::Mat> images;
for (size_t i = 0; i < data.size(); ++i) {
images.push_back(PyArrayToCvMat(data[i]));
}
std::vector<vision::OCRResult> results;
self.BatchPredict(images, &results);
return results;
}); });
} }
@@ -49,6 +58,15 @@ void BindPPOCRv2(pybind11::module& m) {
vision::OCRResult res; vision::OCRResult res;
self.Predict(&mat, &res); self.Predict(&mat, &res);
return res; return res;
})
.def("batch_predict", [](pipeline::PPOCRv2& self, std::vector<pybind11::array>& data) {
std::vector<cv::Mat> images;
for (size_t i = 0; i < data.size(); ++i) {
images.push_back(PyArrayToCvMat(data[i]));
}
std::vector<vision::OCRResult> results;
self.BatchPredict(images, &results);
return results;
}); });
} }

130
fastdeploy/vision/ocr/ppocr/ppocr_v2.cc Normal file → Executable file
View File

@@ -22,13 +22,15 @@ PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model,
fastdeploy::vision::ocr::Classifier* cls_model, fastdeploy::vision::ocr::Classifier* cls_model,
fastdeploy::vision::ocr::Recognizer* rec_model) fastdeploy::vision::ocr::Recognizer* rec_model)
: detector_(det_model), classifier_(cls_model), recognizer_(rec_model) { : detector_(det_model), classifier_(cls_model), recognizer_(rec_model) {
recognizer_->rec_image_shape[1] = 32; Initialized();
recognizer_->preprocessor_.rec_image_shape_[1] = 32;
} }
PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model, PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model,
fastdeploy::vision::ocr::Recognizer* rec_model) fastdeploy::vision::ocr::Recognizer* rec_model)
: detector_(det_model), recognizer_(rec_model) { : detector_(det_model), recognizer_(rec_model) {
recognizer_->rec_image_shape[1] = 32; Initialized();
recognizer_->preprocessor_.rec_image_shape_[1] = 32;
} }
bool PPOCRv2::Initialized() const { bool PPOCRv2::Initialized() const {
@@ -47,76 +49,68 @@ bool PPOCRv2::Initialized() const {
return true; return true;
} }
bool PPOCRv2::Detect(cv::Mat* img,
fastdeploy::vision::OCRResult* result) {
if (!detector_->Predict(img, &(result->boxes))) {
FDERROR << "There's error while detecting image in PPOCR." << std::endl;
return false;
}
vision::ocr::SortBoxes(result);
return true;
}
bool PPOCRv2::Recognize(cv::Mat* img,
fastdeploy::vision::OCRResult* result) {
std::tuple<std::string, float> rec_result;
if (!recognizer_->Predict(img, &rec_result)) {
FDERROR << "There's error while recognizing image in PPOCR." << std::endl;
return false;
}
result->text.push_back(std::get<0>(rec_result));
result->rec_scores.push_back(std::get<1>(rec_result));
return true;
}
bool PPOCRv2::Classify(cv::Mat* img,
fastdeploy::vision::OCRResult* result) {
std::tuple<int, float> cls_result;
if (!classifier_->Predict(img, &cls_result)) {
FDERROR << "There's error while classifying image in PPOCR." << std::endl;
return false;
}
result->cls_labels.push_back(std::get<0>(cls_result));
result->cls_scores.push_back(std::get<1>(cls_result));
return true;
}
bool PPOCRv2::Predict(cv::Mat* img, bool PPOCRv2::Predict(cv::Mat* img,
fastdeploy::vision::OCRResult* result) { fastdeploy::vision::OCRResult* result) {
result->Clear(); std::vector<fastdeploy::vision::OCRResult> batch_result(1);
if (nullptr != detector_ && !Detect(img, result)) { BatchPredict({*img},&batch_result);
FDERROR << "Failed to detect image." << std::endl; *result = std::move(batch_result[0]);
return false;
}
// Get croped images by detection result
std::vector<cv::Mat> image_list;
for (size_t i = 0; i < result->boxes.size(); ++i) {
auto crop_im = vision::ocr::GetRotateCropImage(*img, (result->boxes)[i]);
image_list.push_back(crop_im);
}
if (result->boxes.size() == 0) {
image_list.push_back(*img);
}
for (size_t i = 0; i < image_list.size(); ++i) {
if (nullptr != classifier_ && !Classify(&(image_list[i]), result)) {
FDERROR << "Failed to classify croped image of index " << i << "." << std::endl;
return false;
}
if (nullptr != classifier_ && result->cls_labels[i] % 2 == 1 && result->cls_scores[i] > classifier_->cls_thresh) {
cv::rotate(image_list[i], image_list[i], 1);
}
if (nullptr != recognizer_ && !Recognize(&(image_list[i]), result)) {
FDERROR << "Failed to recgnize croped image of index " << i << "." << std::endl;
return false;
}
}
return true; return true;
}; };
bool PPOCRv2::BatchPredict(const std::vector<cv::Mat>& images,
std::vector<fastdeploy::vision::OCRResult>* batch_result) {
batch_result->clear();
batch_result->resize(images.size());
std::vector<std::vector<std::array<int, 8>>> batch_boxes(images.size());
if (!detector_->BatchPredict(images, &batch_boxes)) {
FDERROR << "There's error while detecting image in PPOCR." << std::endl;
return false;
}
for(int i_batch = 0; i_batch < batch_boxes.size(); ++i_batch) {
vision::ocr::SortBoxes(&(batch_boxes[i_batch]));
(*batch_result)[i_batch].boxes = batch_boxes[i_batch];
}
for(int i_batch = 0; i_batch < images.size(); ++i_batch) {
fastdeploy::vision::OCRResult& ocr_result = (*batch_result)[i_batch];
// Get croped images by detection result
const std::vector<std::array<int, 8>>& boxes = ocr_result.boxes;
const cv::Mat& img = images[i_batch];
std::vector<cv::Mat> image_list;
if (boxes.size() == 0) {
image_list.emplace_back(img);
}else{
image_list.resize(boxes.size());
for (size_t i_box = 0; i_box < boxes.size(); ++i_box) {
image_list[i_box] = vision::ocr::GetRotateCropImage(img, boxes[i_box]);
}
}
std::vector<int32_t>* cls_labels_ptr = &ocr_result.cls_labels;
std::vector<float>* cls_scores_ptr = &ocr_result.cls_scores;
std::vector<std::string>* text_ptr = &ocr_result.text;
std::vector<float>* rec_scores_ptr = &ocr_result.rec_scores;
if (!classifier_->BatchPredict(image_list, cls_labels_ptr, cls_scores_ptr)) {
FDERROR << "There's error while recognizing image in PPOCR." << std::endl;
return false;
}else{
for (size_t i_img = 0; i_img < image_list.size(); ++i_img) {
if(cls_labels_ptr->at(i_img) % 2 == 1 && cls_scores_ptr->at(i_img) > classifier_->postprocessor_.cls_thresh_) {
cv::rotate(image_list[i_img], image_list[i_img], 1);
}
}
}
if (!recognizer_->BatchPredict(image_list, text_ptr, rec_scores_ptr)) {
FDERROR << "There's error while recognizing image in PPOCR." << std::endl;
return false;
}
}
return true;
}
} // namesapce pipeline } // namesapce pipeline
} // namespace fastdeploy } // namespace fastdeploy

13
fastdeploy/vision/ocr/ppocr/ppocr_v2.h Normal file → Executable file
View File

@@ -59,6 +59,14 @@ class FASTDEPLOY_DECL PPOCRv2 : public FastDeployModel {
* \return true if the prediction successed, otherwise false. * \return true if the prediction successed, otherwise false.
*/ */
virtual bool Predict(cv::Mat* img, fastdeploy::vision::OCRResult* result); virtual bool Predict(cv::Mat* img, fastdeploy::vision::OCRResult* result);
/** \brief BatchPredict the input image and get OCR result.
*
* \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
* \param[in] batch_result The output list of OCR result will be writen to this structure.
* \return true if the prediction successed, otherwise false.
*/
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
std::vector<fastdeploy::vision::OCRResult>* batch_result);
bool Initialized() const override; bool Initialized() const override;
protected: protected:
@@ -66,11 +74,6 @@ class FASTDEPLOY_DECL PPOCRv2 : public FastDeployModel {
fastdeploy::vision::ocr::Classifier* classifier_ = nullptr; fastdeploy::vision::ocr::Classifier* classifier_ = nullptr;
fastdeploy::vision::ocr::Recognizer* recognizer_ = nullptr; fastdeploy::vision::ocr::Recognizer* recognizer_ = nullptr;
/// Launch the detection process in OCR. /// Launch the detection process in OCR.
virtual bool Detect(cv::Mat* img, fastdeploy::vision::OCRResult* result);
/// Launch the recognition process in OCR.
virtual bool Recognize(cv::Mat* img, fastdeploy::vision::OCRResult* result);
/// Launch the classification process in OCR.
virtual bool Classify(cv::Mat* img, fastdeploy::vision::OCRResult* result);
}; };
namespace application { namespace application {

4
fastdeploy/vision/ocr/ppocr/ppocr_v3.h Normal file → Executable file
View File

@@ -36,7 +36,7 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 {
fastdeploy::vision::ocr::Recognizer* rec_model) fastdeploy::vision::ocr::Recognizer* rec_model)
: PPOCRv2(det_model, cls_model, rec_model) { : PPOCRv2(det_model, cls_model, rec_model) {
// The only difference between v2 and v3 // The only difference between v2 and v3
recognizer_->rec_image_shape[1] = 48; recognizer_->preprocessor_.rec_image_shape_[1] = 48;
} }
/** \brief Classification model is optional, so this function is set up the detection model path and recognition model path respectively. /** \brief Classification model is optional, so this function is set up the detection model path and recognition model path respectively.
* *
@@ -47,7 +47,7 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 {
fastdeploy::vision::ocr::Recognizer* rec_model) fastdeploy::vision::ocr::Recognizer* rec_model)
: PPOCRv2(det_model, rec_model) { : PPOCRv2(det_model, rec_model) {
// The only difference between v2 and v3 // The only difference between v2 and v3
recognizer_->rec_image_shape[1] = 48; recognizer_->preprocessor_.rec_image_shape_[1] = 48;
} }
}; };

View File

@@ -0,0 +1,112 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/rec_postprocessor.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
std::vector<std::string> ReadDict(const std::string& path) {
std::ifstream in(path);
FDASSERT(in, "Cannot open file %s to read.", path.c_str());
std::string line;
std::vector<std::string> m_vec;
while (getline(in, line)) {
m_vec.push_back(line);
}
m_vec.insert(m_vec.begin(), "#"); // blank char for ctc
m_vec.push_back(" ");
return m_vec;
}
RecognizerPostprocessor::RecognizerPostprocessor(){
initialized_ = true;
}
RecognizerPostprocessor::RecognizerPostprocessor(const std::string& label_path) {
// init label_lsit
label_list_ = ReadDict(label_path);
initialized_ = true;
}
bool RecognizerPostprocessor::SingleBatchPostprocessor(const float* out_data,
const std::vector<int64_t>& output_shape,
std::string* text, float* rec_score) {
std::string& str_res = *text;
float& score = *rec_score;
score = 0.f;
int argmax_idx;
int last_index = 0;
int count = 0;
float max_value = 0.0f;
for (int n = 0; n < output_shape[1]; n++) {
argmax_idx = int(
std::distance(&out_data[n * output_shape[2]],
std::max_element(&out_data[n * output_shape[2]],
&out_data[(n + 1) * output_shape[2]])));
max_value = float(*std::max_element(&out_data[n * output_shape[2]],
&out_data[(n + 1) * output_shape[2]]));
if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
score += max_value;
count += 1;
if(argmax_idx > label_list_.size()) {
FDERROR << "The output index: " << argmax_idx << " is larger than the size of label_list: "
<< label_list_.size() << ". Please check the label file!" << std::endl;
return false;
}
str_res += label_list_[argmax_idx];
}
last_index = argmax_idx;
}
score /= (count + 1e-6);
if (count == 0 || std::isnan(score)) {
score = 0.f;
}
return true;
}
bool RecognizerPostprocessor::Run(const std::vector<FDTensor>& tensors,
std::vector<std::string>* texts, std::vector<float>* rec_scores) {
if (!initialized_) {
FDERROR << "Postprocessor is not initialized." << std::endl;
return false;
}
// Recognizer have only 1 output tensor.
const FDTensor& tensor = tensors[0];
// For Recognizer, the output tensor shape = [batch, ?, 6625]
size_t batch = tensor.shape[0];
size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies<int>());
texts->resize(batch);
rec_scores->resize(batch);
const float* tensor_data = reinterpret_cast<const float*>(tensor.Data());
for (int i_batch = 0; i_batch < batch; ++i_batch) {
if(!SingleBatchPostprocessor(tensor_data, tensor.shape, &texts->at(i_batch), &rec_scores->at(i_batch))) {
return false;
}
tensor_data = tensor_data + length;
}
return true;
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,55 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief Postprocessor object for Recognizer serials model.
*/
class FASTDEPLOY_DECL RecognizerPostprocessor {
public:
RecognizerPostprocessor();
/** \brief Create a postprocessor instance for Recognizer serials model
*
* \param[in] label_path The path of label_dict
*/
explicit RecognizerPostprocessor(const std::string& label_path);
/** \brief Process the result of runtime and fill to ClassifyResult structure
*
* \param[in] tensors The inference result from runtime
* \param[in] texts The output result of recognizer
* \param[in] rec_scores The output result of recognizer
* \return true if the postprocess successed, otherwise false
*/
bool Run(const std::vector<FDTensor>& tensors,
std::vector<std::string>* texts, std::vector<float>* rec_scores);
private:
bool SingleBatchPostprocessor(const float* out_data,
const std::vector<int64_t>& output_shape,
std::string* text, float* rec_score);
bool initialized_ = false;
std::vector<std::string> label_list_;
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,99 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ocr/ppocr/rec_preprocessor.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
#include "fastdeploy/function/concat.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
RecognizerPreprocessor::RecognizerPreprocessor() {
initialized_ = true;
}
void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio,
const std::vector<int>& rec_image_shape) {
int imgC, imgH, imgW;
imgC = rec_image_shape[0];
imgH = rec_image_shape[1];
imgW = rec_image_shape[2];
imgW = int(imgH * max_wh_ratio);
float ratio = float(mat->Width()) / float(mat->Height());
int resize_w;
if (ceilf(imgH * ratio) > imgW) {
resize_w = imgW;
}else{
resize_w = int(ceilf(imgH * ratio));
}
Resize::Run(mat, resize_w, imgH);
std::vector<float> value = {0, 0, 0};
Pad::Run(mat, 0, 0, 0, int(imgW - mat->Width()), value);
}
bool RecognizerPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
return false;
}
if (images->size() == 0) {
FDERROR << "The size of input images should be greater than 0." << std::endl;
return false;
}
int imgH = rec_image_shape_[1];
int imgW = rec_image_shape_[2];
float max_wh_ratio = imgW * 1.0 / imgH;
float ori_wh_ratio;
for (size_t i = 0; i < images->size(); ++i) {
FDMat* mat = &(images->at(i));
ori_wh_ratio = mat->Width() * 1.0 / mat->Height();
max_wh_ratio = std::max(max_wh_ratio, ori_wh_ratio);
}
for (size_t i = 0; i < images->size(); ++i) {
FDMat* mat = &(images->at(i));
OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_);
NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
/*
Normalize::Run(mat, mean_, scale_, is_scale_);
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
*/
}
// Only have 1 output Tensor.
outputs->resize(1);
// Concat all the preprocessed data to a batch tensor
std::vector<FDTensor> tensors(images->size());
for (size_t i = 0; i < images->size(); ++i) {
(*images)[i].ShareWithTensor(&(tensors[i]));
tensors[i].ExpandDim(0);
}
if (tensors.size() == 1) {
(*outputs)[0] = std::move(tensors[0]);
} else {
function::Concat(tensors, &((*outputs)[0]), 0);
}
return true;
}
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,52 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace ocr {
/*! @brief Preprocessor object for PaddleClas serials model.
*/
class FASTDEPLOY_DECL RecognizerPreprocessor {
public:
/** \brief Create a preprocessor instance for PaddleClas serials model
*
* \param[in] config_file Path of configuration file for deployment, e.g resnet/infer_cfg.yml
*/
RecognizerPreprocessor();
/** \brief Process the input image and prepare input tensors for runtime
*
* \param[in] images The input image data list, all the elements are returned by cv::imread()
* \param[in] outputs The output tensors which will feed in runtime
* \return true if the preprocess successed, otherwise false
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
std::vector<int> rec_image_shape_ = {3, 48, 320};
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
bool is_scale_ = true;
private:
bool initialized_ = false;
};
} // namespace ocr
} // namespace vision
} // namespace fastdeploy

150
fastdeploy/vision/ocr/ppocr/recognizer.cc Normal file → Executable file
View File

@@ -20,29 +20,13 @@ namespace fastdeploy {
namespace vision { namespace vision {
namespace ocr { namespace ocr {
std::vector<std::string> ReadDict(const std::string& path) {
std::ifstream in(path);
std::string line;
std::vector<std::string> m_vec;
if (in) {
while (getline(in, line)) {
m_vec.push_back(line);
}
} else {
std::cout << "no such label file: " << path << ", exit the program..."
<< std::endl;
exit(1);
}
return m_vec;
}
Recognizer::Recognizer() {} Recognizer::Recognizer() {}
Recognizer::Recognizer(const std::string& model_file, Recognizer::Recognizer(const std::string& model_file,
const std::string& params_file, const std::string& params_file,
const std::string& label_path, const std::string& label_path,
const RuntimeOption& custom_option, const RuntimeOption& custom_option,
const ModelFormat& model_format) { const ModelFormat& model_format):postprocessor_(label_path) {
if (model_format == ModelFormat::ONNX) { if (model_format == ModelFormat::ONNX) {
valid_cpu_backends = {Backend::ORT, valid_cpu_backends = {Backend::ORT,
Backend::OPENVINO}; Backend::OPENVINO};
@@ -56,27 +40,11 @@ Recognizer::Recognizer(const std::string& model_file,
runtime_option.model_format = model_format; runtime_option.model_format = model_format;
runtime_option.model_file = model_file; runtime_option.model_file = model_file;
runtime_option.params_file = params_file; runtime_option.params_file = params_file;
initialized = Initialize(); initialized = Initialize();
// init label_lsit
label_list = ReadDict(label_path);
label_list.insert(label_list.begin(), "#"); // blank char for ctc
label_list.push_back(" ");
} }
// Init // Init
bool Recognizer::Initialize() { bool Recognizer::Initialize() {
// pre&post process parameters
rec_batch_num = 1;
rec_img_h = 48;
rec_img_w = 320;
rec_image_shape = {3, rec_img_h, rec_img_w};
mean = {0.5f, 0.5f, 0.5f};
scale = {0.5f, 0.5f, 0.5f};
is_scale = true;
if (!InitRuntime()) { if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl; FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false; return false;
@@ -85,119 +53,23 @@ bool Recognizer::Initialize() {
return true; return true;
} }
void OcrRecognizerResizeImage(Mat* mat, const float& wh_ratio, bool Recognizer::BatchPredict(const std::vector<cv::Mat>& images,
const std::vector<int>& rec_image_shape) { std::vector<std::string>* texts, std::vector<float>* rec_scores) {
int imgC, imgH, imgW; std::vector<FDMat> fd_images = WrapMat(images);
imgC = rec_image_shape[0]; if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
imgH = rec_image_shape[1]; FDERROR << "Failed to preprocess the input image." << std::endl;
imgW = rec_image_shape[2];
imgW = int(imgH * wh_ratio);
float ratio = float(mat->Width()) / float(mat->Height());
int resize_w;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
Resize::Run(mat, resize_w, imgH);
std::vector<float> value = {127, 127, 127};
Pad::Run(mat, 0, 0, 0, int(imgW - mat->Width()), value);
}
bool Recognizer::Preprocess(Mat* mat, FDTensor* output,
const std::vector<int>& rec_image_shape) {
int imgH = rec_image_shape[1];
int imgW = rec_image_shape[2];
float wh_ratio = imgW * 1.0 / imgH;
float ori_wh_ratio = mat->Width() * 1.0 / mat->Height();
wh_ratio = std::max(wh_ratio, ori_wh_ratio);
OcrRecognizerResizeImage(mat, wh_ratio, rec_image_shape);
Normalize::Run(mat, mean, scale, true);
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1);
return true;
}
bool Recognizer::Postprocess(FDTensor& infer_result,
std::tuple<std::string, float>* rec_result) {
std::vector<int64_t> output_shape = infer_result.shape;
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
float* out_data = static_cast<float*>(infer_result.Data());
std::string str_res;
int argmax_idx;
int last_index = 0;
float score = 0.f;
int count = 0;
float max_value = 0.0f;
for (int n = 0; n < output_shape[1]; n++) {
argmax_idx = int(
std::distance(&out_data[n * output_shape[2]],
std::max_element(&out_data[n * output_shape[2]],
&out_data[(n + 1) * output_shape[2]])));
max_value = float(*std::max_element(&out_data[n * output_shape[2]],
&out_data[(n + 1) * output_shape[2]]));
if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
score += max_value;
count += 1;
if(argmax_idx > label_list.size()){
FDERROR << "The output index: " << argmax_idx << " is larger than the size of label_list: "
<< label_list.size() << ". Please check the label file!" << std::endl;
return false; return false;
} }
str_res += label_list[argmax_idx]; reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
} if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
last_index = argmax_idx; FDERROR << "Failed to inference by runtime." << std::endl;
}
score /= (count + 1e-6);
if (count == 0 || std::isnan(score)) {
score = 0.f;
}
std::get<0>(*rec_result) = str_res;
std::get<1>(*rec_result) = score;
return true;
}
bool Recognizer::Predict(cv::Mat* img,
std::tuple<std::string, float>* rec_result) {
Mat mat(*img);
std::vector<FDTensor> input_tensors(1);
if (!Preprocess(&mat, &input_tensors[0], rec_image_shape)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false; return false;
} }
input_tensors[0].name = InputInfoOfRuntime(0).name; if (!postprocessor_.Run(reused_output_tensors_, texts, rec_scores)) {
std::vector<FDTensor> output_tensors; FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
return false; return false;
} }
if (!Postprocess(output_tensors[0], rec_result)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}
return true; return true;
} }

31
fastdeploy/vision/ocr/ppocr/recognizer.h Normal file → Executable file
View File

@@ -17,6 +17,8 @@
#include "fastdeploy/vision/common/processors/transform.h" #include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h" #include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h" #include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
#include "fastdeploy/vision/ocr/ppocr/rec_preprocessor.h"
#include "fastdeploy/vision/ocr/ppocr/rec_postprocessor.h"
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
@@ -43,35 +45,20 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
const ModelFormat& model_format = ModelFormat::PADDLE); const ModelFormat& model_format = ModelFormat::PADDLE);
/// Get model's name /// Get model's name
std::string ModelName() const { return "ppocr/ocr_rec"; } std::string ModelName() const { return "ppocr/ocr_rec"; }
/** \brief Predict the input image and get OCR recognition model result. /** \brief BatchPredict the input image and get OCR recognition model result.
* *
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. * \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
* \param[in] rec_result The output of OCR recognition model result will be writen to this structure. * \param[in] rec_results The output of OCR recognition model result will be writen to this structure.
* \return true if the prediction is successed, otherwise false. * \return true if the prediction is successed, otherwise false.
*/ */
virtual bool Predict(cv::Mat* img, virtual bool BatchPredict(const std::vector<cv::Mat>& images,
std::tuple<std::string, float>* rec_result); std::vector<std::string>* texts, std::vector<float>* rec_scores);
// Pre & Post parameters RecognizerPreprocessor preprocessor_;
std::vector<std::string> label_list; RecognizerPostprocessor postprocessor_;
int rec_batch_num;
int rec_img_h;
int rec_img_w;
std::vector<int> rec_image_shape;
std::vector<float> mean;
std::vector<float> scale;
bool is_scale;
private: private:
bool Initialize(); bool Initialize();
/// Preprocess the input data, and set the preprocessed results to `outputs`
bool Preprocess(Mat* img, FDTensor* outputs,
const std::vector<int>& rec_image_shape);
/*! @brief Postprocess the inferenced results, and set the final result to `rec_result`
*/
bool Postprocess(FDTensor& infer_result,
std::tuple<std::string, float>* rec_result);
}; };
} // namespace ocr } // namespace ocr

10
fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.cc Normal file → Executable file
View File

@@ -318,10 +318,12 @@ std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
} }
std::vector<std::vector<std::vector<int>>> PostProcessor::FilterTagDetRes( std::vector<std::vector<std::vector<int>>> PostProcessor::FilterTagDetRes(
std::vector<std::vector<std::vector<int>>> boxes, float ratio_h, std::vector<std::vector<std::vector<int>>> boxes,
float ratio_w, const std::map<std::string, std::array<float, 2>> &im_info) { const std::array<int,4>& det_img_info) {
int oriimg_h = im_info.at("input_shape")[0]; int oriimg_w = det_img_info[0];
int oriimg_w = im_info.at("input_shape")[1]; int oriimg_h = det_img_info[1];
float ratio_w = float(det_img_info[2])/float(oriimg_w);
float ratio_h = float(det_img_info[3])/float(oriimg_h);
std::vector<std::vector<std::vector<int>>> root_points; std::vector<std::vector<std::vector<int>>> root_points;
for (int n = 0; n < boxes.size(); n++) { for (int n = 0; n < boxes.size(); n++) {

View File

@@ -57,9 +57,8 @@ class PostProcessor {
const float &det_db_unclip_ratio, const std::string &det_db_score_mode); const float &det_db_unclip_ratio, const std::string &det_db_score_mode);
std::vector<std::vector<std::vector<int>>> FilterTagDetRes( std::vector<std::vector<std::vector<int>>> FilterTagDetRes(
std::vector<std::vector<std::vector<int>>> boxes, float ratio_h, std::vector<std::vector<std::vector<int>>> boxes,
float ratio_w, const std::array<int, 4>& det_img_info);
const std::map<std::string, std::array<float, 2>> &im_info);
private: private:
static bool XsortInt(std::vector<int> a, std::vector<int> b); static bool XsortInt(std::vector<int> a, std::vector<int> b);

View File

@@ -28,10 +28,10 @@ namespace fastdeploy {
namespace vision { namespace vision {
namespace ocr { namespace ocr {
cv::Mat GetRotateCropImage(const cv::Mat& srcimage, FASTDEPLOY_DECL cv::Mat GetRotateCropImage(const cv::Mat& srcimage,
const std::array<int, 8>& box); const std::array<int, 8>& box);
void SortBoxes(OCRResult* result); FASTDEPLOY_DECL void SortBoxes(std::vector<std::array<int, 8>>* boxes);
} // namespace ocr } // namespace ocr
} // namespace vision } // namespace vision

View File

@@ -29,17 +29,17 @@ bool CompareBox(const std::array<int, 8>& result1,
} }
} }
void SortBoxes(OCRResult* result) { void SortBoxes(std::vector<std::array<int, 8>>* boxes) {
std::sort(result->boxes.begin(), result->boxes.end(), CompareBox); std::sort(boxes->begin(), boxes->end(), CompareBox);
if (result->boxes.size() == 0) { if (boxes->size() == 0) {
return; return;
} }
for (int i = 0; i < result->boxes.size() - 1; i++) { for (int i = 0; i < boxes->size() - 1; i++) {
if (abs(result->boxes[i + 1][1] - result->boxes[i][1]) < 10 && if (abs((*boxes)[i + 1][1] - (*boxes)[i][1]) < 10 &&
(result->boxes[i + 1][0] < result->boxes[i][0])) { ((*boxes)[i + 1][0] < (*boxes)[i][0])) {
std::swap(result->boxes[i], result->boxes[i + 1]); std::swap((*boxes)[i], (*boxes)[i + 1]);
} }
} }
} }

8
python/fastdeploy/vision/ocr/__init__.py Normal file → Executable file
View File

@@ -13,10 +13,4 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from .ppocr import PPOCRv3 from .ppocr import *
from .ppocr import PPOCRv2
from .ppocr import PPOCRSystemv3
from .ppocr import PPOCRSystemv2
from .ppocr import DBDetector
from .ppocr import Classifier
from .ppocr import Recognizer

172
python/fastdeploy/vision/ocr/ppocr/__init__.py Normal file → Executable file
View File

@@ -41,40 +41,11 @@ class DBDetector(FastDeployModel):
assert self.initialized, "DBDetector initialize failed." assert self.initialized, "DBDetector initialize failed."
# 一些跟DBDetector模型有关的属性封装 # 一些跟DBDetector模型有关的属性封装
@property '''
def max_side_len(self):
return self._model.max_side_len
@property @property
def det_db_thresh(self): def det_db_thresh(self):
return self._model.det_db_thresh return self._model.det_db_thresh
@property
def det_db_box_thresh(self):
return self._model.det_db_box_thresh
@property
def det_db_unclip_ratio(self):
return self._model.det_db_unclip_ratio
@property
def det_db_score_mode(self):
return self._model.det_db_score_mode
@property
def use_dilation(self):
return self._model.use_dilation
@property
def is_scale(self):
return self._model.max_wh
@max_side_len.setter
def max_side_len(self, value):
assert isinstance(
value, int), "The value to set `max_side_len` must be type of int."
self._model.max_side_len = value
@det_db_thresh.setter @det_db_thresh.setter
def det_db_thresh(self, value): def det_db_thresh(self, value):
assert isinstance( assert isinstance(
@@ -82,6 +53,10 @@ class DBDetector(FastDeployModel):
float), "The value to set `det_db_thresh` must be type of float." float), "The value to set `det_db_thresh` must be type of float."
self._model.det_db_thresh = value self._model.det_db_thresh = value
@property
def det_db_box_thresh(self):
return self._model.det_db_box_thresh
@det_db_box_thresh.setter @det_db_box_thresh.setter
def det_db_box_thresh(self, value): def det_db_box_thresh(self, value):
assert isinstance( assert isinstance(
@@ -89,6 +64,10 @@ class DBDetector(FastDeployModel):
), "The value to set `det_db_box_thresh` must be type of float." ), "The value to set `det_db_box_thresh` must be type of float."
self._model.det_db_box_thresh = value self._model.det_db_box_thresh = value
@property
def det_db_unclip_ratio(self):
return self._model.det_db_unclip_ratio
@det_db_unclip_ratio.setter @det_db_unclip_ratio.setter
def det_db_unclip_ratio(self, value): def det_db_unclip_ratio(self, value):
assert isinstance( assert isinstance(
@@ -96,6 +75,10 @@ class DBDetector(FastDeployModel):
), "The value to set `det_db_unclip_ratio` must be type of float." ), "The value to set `det_db_unclip_ratio` must be type of float."
self._model.det_db_unclip_ratio = value self._model.det_db_unclip_ratio = value
@property
def det_db_score_mode(self):
return self._model.det_db_score_mode
@det_db_score_mode.setter @det_db_score_mode.setter
def det_db_score_mode(self, value): def det_db_score_mode(self, value):
assert isinstance( assert isinstance(
@@ -103,6 +86,10 @@ class DBDetector(FastDeployModel):
str), "The value to set `det_db_score_mode` must be type of str." str), "The value to set `det_db_score_mode` must be type of str."
self._model.det_db_score_mode = value self._model.det_db_score_mode = value
@property
def use_dilation(self):
return self._model.use_dilation
@use_dilation.setter @use_dilation.setter
def use_dilation(self, value): def use_dilation(self, value):
assert isinstance( assert isinstance(
@@ -110,11 +97,26 @@ class DBDetector(FastDeployModel):
bool), "The value to set `use_dilation` must be type of bool." bool), "The value to set `use_dilation` must be type of bool."
self._model.use_dilation = value self._model.use_dilation = value
@property
def max_side_len(self):
return self._model.max_side_len
@max_side_len.setter
def max_side_len(self, value):
assert isinstance(
value, int), "The value to set `max_side_len` must be type of int."
self._model.max_side_len = value
@property
def is_scale(self):
return self._model.max_wh
@is_scale.setter @is_scale.setter
def is_scale(self, value): def is_scale(self, value):
assert isinstance( assert isinstance(
value, bool), "The value to set `is_scale` must be type of bool." value, bool), "The value to set `is_scale` must be type of bool."
self._model.is_scale = value self._model.is_scale = value
'''
class Classifier(FastDeployModel): class Classifier(FastDeployModel):
@@ -139,6 +141,7 @@ class Classifier(FastDeployModel):
model_file, params_file, self._runtime_option, model_format) model_file, params_file, self._runtime_option, model_format)
assert self.initialized, "Classifier initialize failed." assert self.initialized, "Classifier initialize failed."
'''
@property @property
def cls_thresh(self): def cls_thresh(self):
return self._model.cls_thresh return self._model.cls_thresh
@@ -170,6 +173,7 @@ class Classifier(FastDeployModel):
value, value,
int), "The value to set `cls_batch_num` must be type of int." int), "The value to set `cls_batch_num` must be type of int."
self._model.cls_batch_num = value self._model.cls_batch_num = value
'''
class Recognizer(FastDeployModel): class Recognizer(FastDeployModel):
@@ -197,6 +201,7 @@ class Recognizer(FastDeployModel):
model_format) model_format)
assert self.initialized, "Recognizer initialize failed." assert self.initialized, "Recognizer initialize failed."
'''
@property @property
def rec_img_h(self): def rec_img_h(self):
return self._model.rec_img_h return self._model.rec_img_h
@@ -227,6 +232,7 @@ class Recognizer(FastDeployModel):
value, value,
int), "The value to set `rec_batch_num` must be type of int." int), "The value to set `rec_batch_num` must be type of int."
self._model.rec_batch_num = value self._model.rec_batch_num = value
'''
class PPOCRv3(FastDeployModel): class PPOCRv3(FastDeployModel):
@@ -253,6 +259,14 @@ class PPOCRv3(FastDeployModel):
""" """
return self.system.predict(input_image) return self.system.predict(input_image)
def batch_predict(self, images):
"""Predict a batch of input image
:param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
:return: OCRBatchResult
"""
return self.system.batch_predict(images)
class PPOCRSystemv3(PPOCRv3): class PPOCRSystemv3(PPOCRv3):
def __init__(self, det_model=None, cls_model=None, rec_model=None): def __init__(self, det_model=None, cls_model=None, rec_model=None):
@@ -289,6 +303,14 @@ class PPOCRv2(FastDeployModel):
""" """
return self.system.predict(input_image) return self.system.predict(input_image)
def batch_predict(self, images):
"""Predict a batch of input image
:param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
:return: OCRBatchResult
"""
return self.system.batch_predict(images)
class PPOCRSystemv2(PPOCRv2): class PPOCRSystemv2(PPOCRv2):
def __init__(self, det_model=None, cls_model=None, rec_model=None): def __init__(self, det_model=None, cls_model=None, rec_model=None):
@@ -299,3 +321,93 @@ class PPOCRSystemv2(PPOCRv2):
def predict(self, input_image): def predict(self, input_image):
return super(PPOCRSystemv2, self).predict(input_image) return super(PPOCRSystemv2, self).predict(input_image)
class DBDetectorPreprocessor:
def __init__(self):
"""Create a preprocessor for DBDetectorModel
"""
self._preprocessor = C.vision.ocr.DBDetectorPreprocessor()
def run(self, input_ims):
"""Preprocess input images for DBDetectorModel
:param: input_ims: (list of numpy.ndarray) The input image
:return: pair(list of FDTensor, list of std::array<int, 4>)
"""
return self._preprocessor.run(input_ims)
class DBDetectorPostprocessor:
def __init__(self):
"""Create a postprocessor for DBDetectorModel
"""
self._postprocessor = C.vision.ocr.DBDetectorPostprocessor()
def run(self, runtime_results, batch_det_img_info):
"""Postprocess the runtime results for DBDetectorModel
:param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime
:param: batch_det_img_info: (list of std::array<int, 4>)The output of det_preprocessor
:return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results, batch_det_img_info)
class RecognizerPreprocessor:
def __init__(self):
"""Create a preprocessor for RecognizerModel
"""
self._preprocessor = C.vision.ocr.RecognizerPreprocessor()
def run(self, input_ims):
"""Preprocess input images for RecognizerModel
:param: input_ims: (list of numpy.ndarray)The input image
:return: list of FDTensor
"""
return self._preprocessor.run(input_ims)
class RecognizerPostprocessor:
def __init__(self, label_path):
"""Create a postprocessor for RecognizerModel
:param label_path: (str)Path of label file
"""
self._postprocessor = C.vision.ocr.RecognizerPostprocessor(label_path)
def run(self, runtime_results):
"""Postprocess the runtime results for RecognizerModel
:param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime
:return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results)
class ClassifierPreprocessor:
def __init__(self):
"""Create a preprocessor for ClassifierModel
"""
self._preprocessor = C.vision.ocr.ClassifierPreprocessor()
def run(self, input_ims):
"""Preprocess input images for ClassifierModel
:param: input_ims: (list of numpy.ndarray)The input image
:return: list of FDTensor
"""
return self._preprocessor.run(input_ims)
class ClassifierPostprocessor:
def __init__(self):
"""Create a postprocessor for ClassifierModel
"""
self._postprocessor = C.vision.ocr.ClassifierPostprocessor()
def run(self, runtime_results):
"""Postprocess the runtime results for ClassifierModel
:param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime
:return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results)
def sort_boxes(boxes):
return C.vision.ocr.sort_boxes(boxes)