mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Model] change ocr pre and post (#568)
* change ocr pre and post * add pybind * change ocr * fix bug * fix bug * fix bug * fix bug * fix bug * fix bug * fix copy bug * fix code style * fix bug * add new function * fix windows ci bug
This commit is contained in:
@@ -22,6 +22,7 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
#ifdef FASTDEPLOY_LIB
|
#ifdef FASTDEPLOY_LIB
|
||||||
|
@@ -48,6 +48,7 @@
|
|||||||
#include "fastdeploy/vision/matting/ppmatting/ppmatting.h"
|
#include "fastdeploy/vision/matting/ppmatting/ppmatting.h"
|
||||||
#include "fastdeploy/vision/ocr/ppocr/classifier.h"
|
#include "fastdeploy/vision/ocr/ppocr/classifier.h"
|
||||||
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
|
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
|
||||||
#include "fastdeploy/vision/ocr/ppocr/ppocr_v2.h"
|
#include "fastdeploy/vision/ocr/ppocr/ppocr_v2.h"
|
||||||
#include "fastdeploy/vision/ocr/ppocr/ppocr_v3.h"
|
#include "fastdeploy/vision/ocr/ppocr/ppocr_v3.h"
|
||||||
#include "fastdeploy/vision/ocr/ppocr/recognizer.h"
|
#include "fastdeploy/vision/ocr/ppocr/recognizer.h"
|
||||||
|
93
fastdeploy/vision/ocr/ppocr/classifier.cc
Normal file → Executable file
93
fastdeploy/vision/ocr/ppocr/classifier.cc
Normal file → Executable file
@@ -41,16 +41,7 @@ Classifier::Classifier(const std::string& model_file,
|
|||||||
initialized = Initialize();
|
initialized = Initialize();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init
|
|
||||||
bool Classifier::Initialize() {
|
bool Classifier::Initialize() {
|
||||||
// pre&post process parameters
|
|
||||||
cls_thresh = 0.9;
|
|
||||||
cls_image_shape = {3, 48, 192};
|
|
||||||
cls_batch_num = 1;
|
|
||||||
mean = {0.5f, 0.5f, 0.5f};
|
|
||||||
scale = {0.5f, 0.5f, 0.5f};
|
|
||||||
is_scale = true;
|
|
||||||
|
|
||||||
if (!InitRuntime()) {
|
if (!InitRuntime()) {
|
||||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||||
return false;
|
return false;
|
||||||
@@ -59,85 +50,23 @@ bool Classifier::Initialize() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void OcrClassifierResizeImage(Mat* mat,
|
bool Classifier::BatchPredict(const std::vector<cv::Mat>& images,
|
||||||
const std::vector<int>& rec_image_shape) {
|
std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores) {
|
||||||
int imgC = rec_image_shape[0];
|
std::vector<FDMat> fd_images = WrapMat(images);
|
||||||
int imgH = rec_image_shape[1];
|
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
|
||||||
int imgW = rec_image_shape[2];
|
FDERROR << "Failed to preprocess the input image." << std::endl;
|
||||||
|
return false;
|
||||||
float ratio = float(mat->Width()) / float(mat->Height());
|
|
||||||
|
|
||||||
int resize_w;
|
|
||||||
if (ceilf(imgH * ratio) > imgW)
|
|
||||||
resize_w = imgW;
|
|
||||||
else
|
|
||||||
resize_w = int(ceilf(imgH * ratio));
|
|
||||||
|
|
||||||
Resize::Run(mat, resize_w, imgH);
|
|
||||||
|
|
||||||
std::vector<float> value = {0, 0, 0};
|
|
||||||
if (resize_w < imgW) {
|
|
||||||
Pad::Run(mat, 0, 0, 0, imgW - resize_w, value);
|
|
||||||
}
|
}
|
||||||
}
|
reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
|
||||||
|
if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
|
||||||
bool Classifier::Preprocess(Mat* mat, FDTensor* output) {
|
FDERROR << "Failed to inference by runtime." << std::endl;
|
||||||
// 1. cls resizes
|
|
||||||
// 2. normalize
|
|
||||||
// 3. batch_permute
|
|
||||||
OcrClassifierResizeImage(mat, cls_image_shape);
|
|
||||||
|
|
||||||
Normalize::Run(mat, mean, scale, true);
|
|
||||||
|
|
||||||
HWC2CHW::Run(mat);
|
|
||||||
Cast::Run(mat, "float");
|
|
||||||
|
|
||||||
mat->ShareWithTensor(output);
|
|
||||||
output->shape.insert(output->shape.begin(), 1);
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Classifier::Postprocess(FDTensor& infer_result,
|
|
||||||
std::tuple<int, float>* cls_result) {
|
|
||||||
std::vector<int64_t> output_shape = infer_result.shape;
|
|
||||||
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
|
|
||||||
|
|
||||||
float* out_data = static_cast<float*>(infer_result.Data());
|
|
||||||
|
|
||||||
int label = std::distance(
|
|
||||||
&out_data[0], std::max_element(&out_data[0], &out_data[output_shape[1]]));
|
|
||||||
|
|
||||||
float score =
|
|
||||||
float(*std::max_element(&out_data[0], &out_data[output_shape[1]]));
|
|
||||||
|
|
||||||
std::get<0>(*cls_result) = label;
|
|
||||||
std::get<1>(*cls_result) = score;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Classifier::Predict(cv::Mat* img, std::tuple<int, float>* cls_result) {
|
|
||||||
Mat mat(*img);
|
|
||||||
std::vector<FDTensor> input_tensors(1);
|
|
||||||
|
|
||||||
if (!Preprocess(&mat, &input_tensors[0])) {
|
|
||||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
if (!postprocessor_.Run(reused_output_tensors_, cls_labels, cls_scores)) {
|
||||||
std::vector<FDTensor> output_tensors;
|
FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
|
||||||
if (!Infer(input_tensors, &output_tensors)) {
|
|
||||||
FDERROR << "Failed to inference." << std::endl;
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!Postprocess(output_tensors[0], cls_result)) {
|
|
||||||
FDERROR << "Failed to post process." << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
27
fastdeploy/vision/ocr/ppocr/classifier.h
Normal file → Executable file
27
fastdeploy/vision/ocr/ppocr/classifier.h
Normal file → Executable file
@@ -17,6 +17,8 @@
|
|||||||
#include "fastdeploy/vision/common/processors/transform.h"
|
#include "fastdeploy/vision/common/processors/transform.h"
|
||||||
#include "fastdeploy/vision/common/result.h"
|
#include "fastdeploy/vision/common/result.h"
|
||||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
|
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/cls_postprocessor.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/cls_preprocessor.h"
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
@@ -41,29 +43,22 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel {
|
|||||||
const ModelFormat& model_format = ModelFormat::PADDLE);
|
const ModelFormat& model_format = ModelFormat::PADDLE);
|
||||||
/// Get model's name
|
/// Get model's name
|
||||||
std::string ModelName() const { return "ppocr/ocr_cls"; }
|
std::string ModelName() const { return "ppocr/ocr_cls"; }
|
||||||
/** \brief Predict the input image and get OCR classification model result.
|
|
||||||
|
/** \brief BatchPredict the input image and get OCR classification model cls_result.
|
||||||
*
|
*
|
||||||
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
|
* \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
|
||||||
* \param[in] result The output of OCR classification model result will be writen to this structure.
|
* \param[in] cls_results The output of OCR classification model cls_result will be writen to this structure.
|
||||||
* \return true if the prediction is successed, otherwise false.
|
* \return true if the prediction is successed, otherwise false.
|
||||||
*/
|
*/
|
||||||
virtual bool Predict(cv::Mat* img, std::tuple<int, float>* result);
|
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
|
||||||
|
std::vector<int32_t>* cls_labels,
|
||||||
|
std::vector<float>* cls_scores);
|
||||||
|
|
||||||
// Pre & Post parameters
|
ClassifierPreprocessor preprocessor_;
|
||||||
float cls_thresh;
|
ClassifierPostprocessor postprocessor_;
|
||||||
std::vector<int> cls_image_shape;
|
|
||||||
int cls_batch_num;
|
|
||||||
|
|
||||||
std::vector<float> mean;
|
|
||||||
std::vector<float> scale;
|
|
||||||
bool is_scale;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool Initialize();
|
bool Initialize();
|
||||||
/// Preprocess the input data, and set the preprocessed results to `outputs`
|
|
||||||
bool Preprocess(Mat* img, FDTensor* output);
|
|
||||||
/// Postprocess the inferenced results, and set the final result to `result`
|
|
||||||
bool Postprocess(FDTensor& infer_result, std::tuple<int, float>* result);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ocr
|
} // namespace ocr
|
||||||
|
65
fastdeploy/vision/ocr/ppocr/cls_postprocessor.cc
Normal file
65
fastdeploy/vision/ocr/ppocr/cls_postprocessor.cc
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/cls_postprocessor.h"
|
||||||
|
#include "fastdeploy/utils/perf.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
namespace ocr {
|
||||||
|
|
||||||
|
ClassifierPostprocessor::ClassifierPostprocessor() {
|
||||||
|
initialized_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SingleBatchPostprocessor(const float* out_data, const size_t& length, int* cls_label, float* cls_score) {
|
||||||
|
|
||||||
|
*cls_label = std::distance(
|
||||||
|
&out_data[0], std::max_element(&out_data[0], &out_data[length]));
|
||||||
|
|
||||||
|
*cls_score =
|
||||||
|
float(*std::max_element(&out_data[0], &out_data[length]));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ClassifierPostprocessor::Run(const std::vector<FDTensor>& tensors,
|
||||||
|
std::vector<int32_t>* cls_labels,
|
||||||
|
std::vector<float>* cls_scores) {
|
||||||
|
if (!initialized_) {
|
||||||
|
FDERROR << "Postprocessor is not initialized." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Classifier have only 1 output tensor.
|
||||||
|
const FDTensor& tensor = tensors[0];
|
||||||
|
|
||||||
|
// For Classifier, the output tensor shape = [batch,2]
|
||||||
|
size_t batch = tensor.shape[0];
|
||||||
|
size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies<int>());
|
||||||
|
|
||||||
|
cls_labels->resize(batch);
|
||||||
|
cls_scores->resize(batch);
|
||||||
|
const float* tensor_data = reinterpret_cast<const float*>(tensor.Data());
|
||||||
|
|
||||||
|
for (int i_batch = 0; i_batch < batch; ++i_batch) {
|
||||||
|
if(!SingleBatchPostprocessor(tensor_data, length, &cls_labels->at(i_batch),&cls_scores->at(i_batch))) return false;
|
||||||
|
tensor_data = tensor_data + length;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace classification
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
51
fastdeploy/vision/ocr/ppocr/cls_postprocessor.h
Normal file
51
fastdeploy/vision/ocr/ppocr/cls_postprocessor.h
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "fastdeploy/vision/common/processors/transform.h"
|
||||||
|
#include "fastdeploy/vision/common/result.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
|
||||||
|
namespace ocr {
|
||||||
|
/*! @brief Postprocessor object for Classifier serials model.
|
||||||
|
*/
|
||||||
|
class FASTDEPLOY_DECL ClassifierPostprocessor {
|
||||||
|
public:
|
||||||
|
/** \brief Create a postprocessor instance for Classifier serials model
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ClassifierPostprocessor();
|
||||||
|
|
||||||
|
/** \brief Process the result of runtime and fill to ClassifyResult structure
|
||||||
|
*
|
||||||
|
* \param[in] tensors The inference result from runtime
|
||||||
|
* \param[in] cls_labels The output result of classification
|
||||||
|
* \param[in] cls_scores The output result of classification
|
||||||
|
* \return true if the postprocess successed, otherwise false
|
||||||
|
*/
|
||||||
|
bool Run(const std::vector<FDTensor>& tensors,
|
||||||
|
std::vector<int32_t>* cls_labels, std::vector<float>* cls_scores);
|
||||||
|
|
||||||
|
float cls_thresh_ = 0.9;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool initialized_ = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ocr
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
88
fastdeploy/vision/ocr/ppocr/cls_preprocessor.cc
Normal file
88
fastdeploy/vision/ocr/ppocr/cls_preprocessor.cc
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/cls_preprocessor.h"
|
||||||
|
#include "fastdeploy/utils/perf.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
|
||||||
|
#include "fastdeploy/function/concat.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
namespace ocr {
|
||||||
|
|
||||||
|
ClassifierPreprocessor::ClassifierPreprocessor() {
|
||||||
|
initialized_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void OcrClassifierResizeImage(FDMat* mat,
|
||||||
|
const std::vector<int>& cls_image_shape) {
|
||||||
|
int imgC = cls_image_shape[0];
|
||||||
|
int imgH = cls_image_shape[1];
|
||||||
|
int imgW = cls_image_shape[2];
|
||||||
|
|
||||||
|
float ratio = float(mat->Width()) / float(mat->Height());
|
||||||
|
|
||||||
|
int resize_w;
|
||||||
|
if (ceilf(imgH * ratio) > imgW)
|
||||||
|
resize_w = imgW;
|
||||||
|
else
|
||||||
|
resize_w = int(ceilf(imgH * ratio));
|
||||||
|
|
||||||
|
Resize::Run(mat, resize_w, imgH);
|
||||||
|
|
||||||
|
std::vector<float> value = {0, 0, 0};
|
||||||
|
if (resize_w < imgW) {
|
||||||
|
Pad::Run(mat, 0, 0, 0, imgW - resize_w, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ClassifierPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
|
||||||
|
if (!initialized_) {
|
||||||
|
FDERROR << "The preprocessor is not initialized." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (images->size() == 0) {
|
||||||
|
FDERROR << "The size of input images should be greater than 0." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < images->size(); ++i) {
|
||||||
|
FDMat* mat = &(images->at(i));
|
||||||
|
OcrClassifierResizeImage(mat, cls_image_shape_);
|
||||||
|
NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
|
||||||
|
/*
|
||||||
|
Normalize::Run(mat, mean_, scale_, is_scale_);
|
||||||
|
HWC2CHW::Run(mat);
|
||||||
|
Cast::Run(mat, "float");
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
// Only have 1 output Tensor.
|
||||||
|
outputs->resize(1);
|
||||||
|
// Concat all the preprocessed data to a batch tensor
|
||||||
|
std::vector<FDTensor> tensors(images->size());
|
||||||
|
for (size_t i = 0; i < images->size(); ++i) {
|
||||||
|
(*images)[i].ShareWithTensor(&(tensors[i]));
|
||||||
|
tensors[i].ExpandDim(0);
|
||||||
|
}
|
||||||
|
if (tensors.size() == 1) {
|
||||||
|
(*outputs)[0] = std::move(tensors[0]);
|
||||||
|
} else {
|
||||||
|
function::Concat(tensors, &((*outputs)[0]), 0);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ocr
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
51
fastdeploy/vision/ocr/ppocr/cls_preprocessor.h
Normal file
51
fastdeploy/vision/ocr/ppocr/cls_preprocessor.h
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "fastdeploy/vision/common/processors/transform.h"
|
||||||
|
#include "fastdeploy/vision/common/result.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
|
||||||
|
namespace ocr {
|
||||||
|
/*! @brief Preprocessor object for Classifier serials model.
|
||||||
|
*/
|
||||||
|
class FASTDEPLOY_DECL ClassifierPreprocessor {
|
||||||
|
public:
|
||||||
|
/** \brief Create a preprocessor instance for Classifier serials model
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ClassifierPreprocessor();
|
||||||
|
|
||||||
|
/** \brief Process the input image and prepare input tensors for runtime
|
||||||
|
*
|
||||||
|
* \param[in] images The input image data list, all the elements are returned by cv::imread()
|
||||||
|
* \param[in] outputs The output tensors which will feed in runtime
|
||||||
|
* \return true if the preprocess successed, otherwise false
|
||||||
|
*/
|
||||||
|
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
|
||||||
|
|
||||||
|
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
|
||||||
|
std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
|
||||||
|
bool is_scale_ = true;
|
||||||
|
std::vector<int> cls_image_shape_ = {3, 48, 192};
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool initialized_ = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ocr
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
156
fastdeploy/vision/ocr/ppocr/dbdetector.cc
Normal file → Executable file
156
fastdeploy/vision/ocr/ppocr/dbdetector.cc
Normal file → Executable file
@@ -43,158 +43,50 @@ DBDetector::DBDetector(const std::string& model_file,
|
|||||||
|
|
||||||
// Init
|
// Init
|
||||||
bool DBDetector::Initialize() {
|
bool DBDetector::Initialize() {
|
||||||
// pre&post process parameters
|
|
||||||
max_side_len = 960;
|
|
||||||
|
|
||||||
det_db_thresh = 0.3;
|
|
||||||
det_db_box_thresh = 0.6;
|
|
||||||
det_db_unclip_ratio = 1.5;
|
|
||||||
det_db_score_mode = "slow";
|
|
||||||
use_dilation = false;
|
|
||||||
|
|
||||||
mean = {0.485f, 0.456f, 0.406f};
|
|
||||||
scale = {0.229f, 0.224f, 0.225f};
|
|
||||||
is_scale = true;
|
|
||||||
|
|
||||||
if (!InitRuntime()) {
|
if (!InitRuntime()) {
|
||||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void OcrDetectorResizeImage(Mat* img, int max_size_len, float* ratio_h,
|
|
||||||
float* ratio_w) {
|
|
||||||
int w = img->Width();
|
|
||||||
int h = img->Height();
|
|
||||||
|
|
||||||
float ratio = 1.f;
|
|
||||||
int max_wh = w >= h ? w : h;
|
|
||||||
if (max_wh > max_size_len) {
|
|
||||||
if (h > w) {
|
|
||||||
ratio = float(max_size_len) / float(h);
|
|
||||||
} else {
|
|
||||||
ratio = float(max_size_len) / float(w);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int resize_h = int(float(h) * ratio);
|
|
||||||
int resize_w = int(float(w) * ratio);
|
|
||||||
|
|
||||||
resize_h = std::max(int(std::round(float(resize_h) / 32) * 32), 32);
|
|
||||||
resize_w = std::max(int(std::round(float(resize_w) / 32) * 32), 32);
|
|
||||||
|
|
||||||
Resize::Run(img, resize_w, resize_h);
|
|
||||||
|
|
||||||
*ratio_h = float(resize_h) / float(h);
|
|
||||||
*ratio_w = float(resize_w) / float(w);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool DBDetector::Preprocess(
|
|
||||||
Mat* mat, FDTensor* output,
|
|
||||||
std::map<std::string, std::array<float, 2>>* im_info) {
|
|
||||||
// Resize
|
|
||||||
OcrDetectorResizeImage(mat, max_side_len, &ratio_h, &ratio_w);
|
|
||||||
// Normalize
|
|
||||||
Normalize::Run(mat, mean, scale, true);
|
|
||||||
|
|
||||||
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
|
|
||||||
static_cast<float>(mat->Width())};
|
|
||||||
//-CHW
|
|
||||||
HWC2CHW::Run(mat);
|
|
||||||
Cast::Run(mat, "float");
|
|
||||||
|
|
||||||
mat->ShareWithTensor(output);
|
|
||||||
output->shape.insert(output->shape.begin(), 1);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool DBDetector::Postprocess(
|
|
||||||
FDTensor& infer_result, std::vector<std::array<int, 8>>* boxes_result,
|
|
||||||
const std::map<std::string, std::array<float, 2>>& im_info) {
|
|
||||||
std::vector<int64_t> output_shape = infer_result.shape;
|
|
||||||
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
|
|
||||||
int n2 = output_shape[2];
|
|
||||||
int n3 = output_shape[3];
|
|
||||||
int n = n2 * n3;
|
|
||||||
|
|
||||||
float* out_data = static_cast<float*>(infer_result.Data());
|
|
||||||
// prepare bitmap
|
|
||||||
std::vector<float> pred(n, 0.0);
|
|
||||||
std::vector<unsigned char> cbuf(n, ' ');
|
|
||||||
|
|
||||||
for (int i = 0; i < n; i++) {
|
|
||||||
pred[i] = float(out_data[i]);
|
|
||||||
cbuf[i] = (unsigned char)((out_data[i]) * 255);
|
|
||||||
}
|
|
||||||
cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char*)cbuf.data());
|
|
||||||
cv::Mat pred_map(n2, n3, CV_32F, (float*)pred.data());
|
|
||||||
|
|
||||||
const double threshold = det_db_thresh * 255;
|
|
||||||
const double maxvalue = 255;
|
|
||||||
cv::Mat bit_map;
|
|
||||||
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
|
|
||||||
if (use_dilation) {
|
|
||||||
cv::Mat dila_ele =
|
|
||||||
cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
|
|
||||||
cv::dilate(bit_map, bit_map, dila_ele);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::vector<std::vector<int>>> boxes;
|
|
||||||
|
|
||||||
boxes =
|
|
||||||
post_processor_.BoxesFromBitmap(pred_map, bit_map, det_db_box_thresh,
|
|
||||||
det_db_unclip_ratio, det_db_score_mode);
|
|
||||||
|
|
||||||
boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, im_info);
|
|
||||||
|
|
||||||
// boxes to boxes_result
|
|
||||||
for (int i = 0; i < boxes.size(); i++) {
|
|
||||||
std::array<int, 8> new_box;
|
|
||||||
int k = 0;
|
|
||||||
for (auto& vec : boxes[i]) {
|
|
||||||
for (auto& e : vec) {
|
|
||||||
new_box[k++] = e;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
boxes_result->push_back(new_box);
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool DBDetector::Predict(cv::Mat* img,
|
bool DBDetector::Predict(cv::Mat* img,
|
||||||
std::vector<std::array<int, 8>>* boxes_result) {
|
std::vector<std::array<int, 8>>* boxes_result) {
|
||||||
Mat mat(*img);
|
if (!Predict(*img, boxes_result)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<FDTensor> input_tensors(1);
|
bool DBDetector::Predict(const cv::Mat& img,
|
||||||
|
std::vector<std::array<int, 8>>* boxes_result) {
|
||||||
|
std::vector<std::vector<std::array<int, 8>>> det_results;
|
||||||
|
if (!BatchPredict({img}, &det_results)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*boxes_result = std::move(det_results[0]);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
std::map<std::string, std::array<float, 2>> im_info;
|
bool DBDetector::BatchPredict(const std::vector<cv::Mat>& images,
|
||||||
|
std::vector<std::vector<std::array<int, 8>>>* det_results) {
|
||||||
// Record the shape of image and the shape of preprocessed image
|
std::vector<FDMat> fd_images = WrapMat(images);
|
||||||
im_info["input_shape"] = {static_cast<float>(mat.Height()),
|
std::vector<std::array<int, 4>> batch_det_img_info;
|
||||||
static_cast<float>(mat.Width())};
|
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &batch_det_img_info)) {
|
||||||
im_info["output_shape"] = {static_cast<float>(mat.Height()),
|
|
||||||
static_cast<float>(mat.Width())};
|
|
||||||
|
|
||||||
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
|
|
||||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
|
||||||
std::vector<FDTensor> output_tensors;
|
if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
|
||||||
if (!Infer(input_tensors, &output_tensors)) {
|
FDERROR << "Failed to inference by runtime." << std::endl;
|
||||||
FDERROR << "Failed to inference." << std::endl;
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!Postprocess(output_tensors[0], boxes_result, im_info)) {
|
if (!postprocessor_.Run(reused_output_tensors_, det_results, batch_det_img_info)) {
|
||||||
FDERROR << "Failed to post process." << std::endl;
|
FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
48
fastdeploy/vision/ocr/ppocr/dbdetector.h
Normal file → Executable file
48
fastdeploy/vision/ocr/ppocr/dbdetector.h
Normal file → Executable file
@@ -17,6 +17,8 @@
|
|||||||
#include "fastdeploy/vision/common/processors/transform.h"
|
#include "fastdeploy/vision/common/processors/transform.h"
|
||||||
#include "fastdeploy/vision/common/result.h"
|
#include "fastdeploy/vision/common/result.h"
|
||||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
|
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/det_postprocessor.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/det_preprocessor.h"
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
@@ -44,40 +46,34 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel {
|
|||||||
std::string ModelName() const { return "ppocr/ocr_det"; }
|
std::string ModelName() const { return "ppocr/ocr_det"; }
|
||||||
/** \brief Predict the input image and get OCR detection model result.
|
/** \brief Predict the input image and get OCR detection model result.
|
||||||
*
|
*
|
||||||
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
|
* \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
|
||||||
* \param[in] boxes_result The output of OCR detection model result will be writen to this structure.
|
* \param[in] boxes_result The output of OCR detection model result will be writen to this structure.
|
||||||
* \return true if the prediction is successed, otherwise false.
|
* \return true if the prediction is successed, otherwise false.
|
||||||
*/
|
*/
|
||||||
virtual bool Predict(cv::Mat* im,
|
virtual bool Predict(cv::Mat* img,
|
||||||
std::vector<std::array<int, 8>>* boxes_result);
|
std::vector<std::array<int, 8>>* boxes_result);
|
||||||
|
/** \brief Predict the input image and get OCR detection model result.
|
||||||
|
*
|
||||||
|
* \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
|
||||||
|
* \param[in] boxes_result The output of OCR detection model result will be writen to this structure.
|
||||||
|
* \return true if the prediction is successed, otherwise false.
|
||||||
|
*/
|
||||||
|
virtual bool Predict(const cv::Mat& img,
|
||||||
|
std::vector<std::array<int, 8>>* boxes_result);
|
||||||
|
/** \brief BatchPredict the input image and get OCR detection model result.
|
||||||
|
*
|
||||||
|
* \param[in] images The list input of image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
|
||||||
|
* \param[in] det_results The output of OCR detection model result will be writen to this structure.
|
||||||
|
* \return true if the prediction is successed, otherwise false.
|
||||||
|
*/
|
||||||
|
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
|
||||||
|
std::vector<std::vector<std::array<int, 8>>>* det_results);
|
||||||
|
|
||||||
// Pre & Post process parameters
|
DBDetectorPreprocessor preprocessor_;
|
||||||
int max_side_len;
|
DBDetectorPostprocessor postprocessor_;
|
||||||
|
|
||||||
float ratio_h{};
|
|
||||||
float ratio_w{};
|
|
||||||
|
|
||||||
double det_db_thresh;
|
|
||||||
double det_db_box_thresh;
|
|
||||||
double det_db_unclip_ratio;
|
|
||||||
std::string det_db_score_mode;
|
|
||||||
bool use_dilation;
|
|
||||||
|
|
||||||
std::vector<float> mean;
|
|
||||||
std::vector<float> scale;
|
|
||||||
bool is_scale;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool Initialize();
|
bool Initialize();
|
||||||
/// Preprocess the input data, and set the preprocessed results to `outputs`
|
|
||||||
bool Preprocess(Mat* mat, FDTensor* outputs,
|
|
||||||
std::map<std::string, std::array<float, 2>>* im_info);
|
|
||||||
/*! @brief Postprocess the inferenced results, and set the final result to `boxes_result`
|
|
||||||
*/
|
|
||||||
bool Postprocess(FDTensor& infer_result,
|
|
||||||
std::vector<std::array<int, 8>>* boxes_result,
|
|
||||||
const std::map<std::string, std::array<float, 2>>& im_info);
|
|
||||||
PostProcessor post_processor_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ocr
|
} // namespace ocr
|
||||||
|
110
fastdeploy/vision/ocr/ppocr/det_postprocessor.cc
Normal file
110
fastdeploy/vision/ocr/ppocr/det_postprocessor.cc
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/det_postprocessor.h"
|
||||||
|
#include "fastdeploy/utils/perf.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
namespace ocr {
|
||||||
|
|
||||||
|
DBDetectorPostprocessor::DBDetectorPostprocessor() {
|
||||||
|
initialized_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DBDetectorPostprocessor::SingleBatchPostprocessor(
|
||||||
|
const float* out_data,
|
||||||
|
int n2,
|
||||||
|
int n3,
|
||||||
|
const std::array<int,4>& det_img_info,
|
||||||
|
std::vector<std::array<int, 8>>* boxes_result
|
||||||
|
) {
|
||||||
|
int n = n2 * n3;
|
||||||
|
|
||||||
|
// prepare bitmap
|
||||||
|
std::vector<float> pred(n, 0.0);
|
||||||
|
std::vector<unsigned char> cbuf(n, ' ');
|
||||||
|
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
pred[i] = float(out_data[i]);
|
||||||
|
cbuf[i] = (unsigned char)((out_data[i]) * 255);
|
||||||
|
}
|
||||||
|
cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char*)cbuf.data());
|
||||||
|
cv::Mat pred_map(n2, n3, CV_32F, (float*)pred.data());
|
||||||
|
|
||||||
|
const double threshold = det_db_thresh_ * 255;
|
||||||
|
const double maxvalue = 255;
|
||||||
|
cv::Mat bit_map;
|
||||||
|
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
|
||||||
|
if (use_dilation_) {
|
||||||
|
cv::Mat dila_ele =
|
||||||
|
cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
|
||||||
|
cv::dilate(bit_map, bit_map, dila_ele);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<std::vector<int>>> boxes;
|
||||||
|
|
||||||
|
boxes =
|
||||||
|
post_processor_.BoxesFromBitmap(pred_map, bit_map, det_db_box_thresh_,
|
||||||
|
det_db_unclip_ratio_, det_db_score_mode_);
|
||||||
|
|
||||||
|
boxes = post_processor_.FilterTagDetRes(boxes, det_img_info);
|
||||||
|
|
||||||
|
// boxes to boxes_result
|
||||||
|
for (int i = 0; i < boxes.size(); i++) {
|
||||||
|
std::array<int, 8> new_box;
|
||||||
|
int k = 0;
|
||||||
|
for (auto& vec : boxes[i]) {
|
||||||
|
for (auto& e : vec) {
|
||||||
|
new_box[k++] = e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
boxes_result->push_back(new_box);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DBDetectorPostprocessor::Run(const std::vector<FDTensor>& tensors,
|
||||||
|
std::vector<std::vector<std::array<int, 8>>>* results,
|
||||||
|
const std::vector<std::array<int,4>>& batch_det_img_info) {
|
||||||
|
if (!initialized_) {
|
||||||
|
FDERROR << "Postprocessor is not initialized." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// DBDetector have only 1 output tensor.
|
||||||
|
const FDTensor& tensor = tensors[0];
|
||||||
|
|
||||||
|
// For DBDetector, the output tensor shape = [batch, 1, ?, ?]
|
||||||
|
size_t batch = tensor.shape[0];
|
||||||
|
size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies<int>());
|
||||||
|
const float* tensor_data = reinterpret_cast<const float*>(tensor.Data());
|
||||||
|
|
||||||
|
results->resize(batch);
|
||||||
|
for (int i_batch = 0; i_batch < batch; ++i_batch) {
|
||||||
|
if(!SingleBatchPostprocessor(tensor_data,
|
||||||
|
tensor.shape[2],
|
||||||
|
tensor.shape[3],
|
||||||
|
batch_det_img_info[i_batch],
|
||||||
|
&results->at(i_batch)
|
||||||
|
))return false;
|
||||||
|
tensor_data = tensor_data + length;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ocr
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
62
fastdeploy/vision/ocr/ppocr/det_postprocessor.h
Normal file
62
fastdeploy/vision/ocr/ppocr/det_postprocessor.h
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "fastdeploy/vision/common/processors/transform.h"
|
||||||
|
#include "fastdeploy/vision/common/result.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
|
||||||
|
namespace ocr {
|
||||||
|
/*! @brief Postprocessor object for DBDetector serials model.
|
||||||
|
*/
|
||||||
|
class FASTDEPLOY_DECL DBDetectorPostprocessor {
|
||||||
|
public:
|
||||||
|
/** \brief Create a postprocessor instance for DBDetector serials model
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
DBDetectorPostprocessor();
|
||||||
|
|
||||||
|
/** \brief Process the result of runtime and fill to results structure
|
||||||
|
*
|
||||||
|
* \param[in] tensors The inference result from runtime
|
||||||
|
* \param[in] results The output result of detector
|
||||||
|
* \param[in] batch_det_img_info The detector_preprocess result
|
||||||
|
* \return true if the postprocess successed, otherwise false
|
||||||
|
*/
|
||||||
|
bool Run(const std::vector<FDTensor>& tensors,
|
||||||
|
std::vector<std::vector<std::array<int, 8>>>* results,
|
||||||
|
const std::vector<std::array<int, 4>>& batch_det_img_info);
|
||||||
|
|
||||||
|
double det_db_thresh_ = 0.3;
|
||||||
|
double det_db_box_thresh_ = 0.6;
|
||||||
|
double det_db_unclip_ratio_ = 1.5;
|
||||||
|
std::string det_db_score_mode_ = "slow";
|
||||||
|
bool use_dilation_ = false;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool initialized_ = false;
|
||||||
|
PostProcessor post_processor_;
|
||||||
|
bool SingleBatchPostprocessor(const float* out_data,
|
||||||
|
int n2,
|
||||||
|
int n3,
|
||||||
|
const std::array<int, 4>& det_img_info,
|
||||||
|
std::vector<std::array<int, 8>>* boxes_result);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ocr
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
113
fastdeploy/vision/ocr/ppocr/det_preprocessor.cc
Normal file
113
fastdeploy/vision/ocr/ppocr/det_preprocessor.cc
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/det_preprocessor.h"
|
||||||
|
#include "fastdeploy/utils/perf.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
|
||||||
|
#include "fastdeploy/function/concat.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
namespace ocr {
|
||||||
|
|
||||||
|
DBDetectorPreprocessor::DBDetectorPreprocessor() {
|
||||||
|
initialized_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<int, 4> OcrDetectorGetInfo(FDMat* img, int max_size_len) {
|
||||||
|
int w = img->Width();
|
||||||
|
int h = img->Height();
|
||||||
|
|
||||||
|
float ratio = 1.f;
|
||||||
|
int max_wh = w >= h ? w : h;
|
||||||
|
if (max_wh > max_size_len) {
|
||||||
|
if (h > w) {
|
||||||
|
ratio = float(max_size_len) / float(h);
|
||||||
|
} else {
|
||||||
|
ratio = float(max_size_len) / float(w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int resize_h = int(float(h) * ratio);
|
||||||
|
int resize_w = int(float(w) * ratio);
|
||||||
|
resize_h = std::max(int(std::round(float(resize_h) / 32) * 32), 32);
|
||||||
|
resize_w = std::max(int(std::round(float(resize_w) / 32) * 32), 32);
|
||||||
|
|
||||||
|
return {w,h,resize_w,resize_h};
|
||||||
|
/*
|
||||||
|
*ratio_h = float(resize_h) / float(h);
|
||||||
|
*ratio_w = float(resize_w) / float(w);
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
bool OcrDetectorResizeImage(FDMat* img,
|
||||||
|
int resize_w,
|
||||||
|
int resize_h,
|
||||||
|
int max_resize_w,
|
||||||
|
int max_resize_h) {
|
||||||
|
Resize::Run(img, resize_w, resize_h);
|
||||||
|
std::vector<float> value = {0, 0, 0};
|
||||||
|
Pad::Run(img, 0, max_resize_h-resize_h, 0, max_resize_w - resize_w, value);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DBDetectorPreprocessor::Run(std::vector<FDMat>* images,
|
||||||
|
std::vector<FDTensor>* outputs,
|
||||||
|
std::vector<std::array<int, 4>>* batch_det_img_info_ptr) {
|
||||||
|
if (!initialized_) {
|
||||||
|
FDERROR << "The preprocessor is not initialized." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (images->size() == 0) {
|
||||||
|
FDERROR << "The size of input images should be greater than 0." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
int max_resize_w = 0;
|
||||||
|
int max_resize_h = 0;
|
||||||
|
std::vector<std::array<int, 4>>& batch_det_img_info = *batch_det_img_info_ptr;
|
||||||
|
batch_det_img_info.clear();
|
||||||
|
batch_det_img_info.resize(images->size());
|
||||||
|
for (size_t i = 0; i < images->size(); ++i) {
|
||||||
|
FDMat* mat = &(images->at(i));
|
||||||
|
batch_det_img_info[i] = OcrDetectorGetInfo(mat,max_side_len_);
|
||||||
|
max_resize_w = std::max(max_resize_w,batch_det_img_info[i][2]);
|
||||||
|
max_resize_h = std::max(max_resize_h,batch_det_img_info[i][3]);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < images->size(); ++i) {
|
||||||
|
FDMat* mat = &(images->at(i));
|
||||||
|
OcrDetectorResizeImage(mat, batch_det_img_info[i][2],batch_det_img_info[i][3],max_resize_w,max_resize_h);
|
||||||
|
NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
|
||||||
|
/*
|
||||||
|
Normalize::Run(mat, mean_, scale_, is_scale_);
|
||||||
|
HWC2CHW::Run(mat);
|
||||||
|
Cast::Run(mat, "float");
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
// Only have 1 output Tensor.
|
||||||
|
outputs->resize(1);
|
||||||
|
// Concat all the preprocessed data to a batch tensor
|
||||||
|
std::vector<FDTensor> tensors(images->size());
|
||||||
|
for (size_t i = 0; i < images->size(); ++i) {
|
||||||
|
(*images)[i].ShareWithTensor(&(tensors[i]));
|
||||||
|
tensors[i].ExpandDim(0);
|
||||||
|
}
|
||||||
|
if (tensors.size() == 1) {
|
||||||
|
(*outputs)[0] = std::move(tensors[0]);
|
||||||
|
} else {
|
||||||
|
function::Concat(tensors, &((*outputs)[0]), 0);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ocr
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
54
fastdeploy/vision/ocr/ppocr/det_preprocessor.h
Normal file
54
fastdeploy/vision/ocr/ppocr/det_preprocessor.h
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "fastdeploy/vision/common/processors/transform.h"
|
||||||
|
#include "fastdeploy/vision/common/result.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
|
||||||
|
namespace ocr {
|
||||||
|
/*! @brief Preprocessor object for DBDetector serials model.
|
||||||
|
*/
|
||||||
|
class FASTDEPLOY_DECL DBDetectorPreprocessor {
|
||||||
|
public:
|
||||||
|
/** \brief Create a preprocessor instance for DBDetector serials model
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
DBDetectorPreprocessor();
|
||||||
|
|
||||||
|
/** \brief Process the input image and prepare input tensors for runtime
|
||||||
|
*
|
||||||
|
* \param[in] images The input image data list, all the elements are returned by cv::imread()
|
||||||
|
* \param[in] outputs The output tensors which will feed in runtime
|
||||||
|
* \param[in] batch_det_img_info_ptr The output of preprocess
|
||||||
|
* \return true if the preprocess successed, otherwise false
|
||||||
|
*/
|
||||||
|
bool Run(std::vector<FDMat>* images,
|
||||||
|
std::vector<FDTensor>* outputs,
|
||||||
|
std::vector<std::array<int, 4>>* batch_det_img_info_ptr);
|
||||||
|
|
||||||
|
int max_side_len_ = 960;
|
||||||
|
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
|
||||||
|
std::vector<float> scale_ = {0.229f, 0.224f, 0.225f};
|
||||||
|
bool is_scale_ = true;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool initialized_ = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ocr
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
166
fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc
Normal file → Executable file
166
fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc
Normal file → Executable file
@@ -16,45 +16,171 @@
|
|||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
void BindPPOCRModel(pybind11::module& m) {
|
void BindPPOCRModel(pybind11::module& m) {
|
||||||
|
m.def("sort_boxes", [](std::vector<std::array<int, 8>>& boxes) {
|
||||||
|
vision::ocr::SortBoxes(&boxes);
|
||||||
|
return boxes;
|
||||||
|
});
|
||||||
// DBDetector
|
// DBDetector
|
||||||
pybind11::class_<vision::ocr::DBDetector, FastDeployModel>(m, "DBDetector")
|
pybind11::class_<vision::ocr::DBDetector, FastDeployModel>(m, "DBDetector")
|
||||||
.def(pybind11::init<std::string, std::string, RuntimeOption,
|
.def(pybind11::init<std::string, std::string, RuntimeOption,
|
||||||
ModelFormat>())
|
ModelFormat>())
|
||||||
.def(pybind11::init<>())
|
.def(pybind11::init<>())
|
||||||
|
.def_readwrite("preprocessor", &vision::ocr::DBDetector::preprocessor_)
|
||||||
|
.def_readwrite("postprocessor", &vision::ocr::DBDetector::postprocessor_);
|
||||||
|
|
||||||
.def_readwrite("max_side_len", &vision::ocr::DBDetector::max_side_len)
|
pybind11::class_<vision::ocr::DBDetectorPreprocessor>(m, "DBDetectorPreprocessor")
|
||||||
.def_readwrite("det_db_thresh", &vision::ocr::DBDetector::det_db_thresh)
|
.def(pybind11::init<>())
|
||||||
.def_readwrite("det_db_box_thresh",
|
.def_readwrite("max_side_len", &vision::ocr::DBDetectorPreprocessor::max_side_len_)
|
||||||
&vision::ocr::DBDetector::det_db_box_thresh)
|
.def_readwrite("mean", &vision::ocr::DBDetectorPreprocessor::mean_)
|
||||||
.def_readwrite("det_db_unclip_ratio",
|
.def_readwrite("scale", &vision::ocr::DBDetectorPreprocessor::scale_)
|
||||||
&vision::ocr::DBDetector::det_db_unclip_ratio)
|
.def_readwrite("is_scale", &vision::ocr::DBDetectorPreprocessor::is_scale_)
|
||||||
.def_readwrite("det_db_score_mode",
|
.def("run", [](vision::ocr::DBDetectorPreprocessor& self, std::vector<pybind11::array>& im_list) {
|
||||||
&vision::ocr::DBDetector::det_db_score_mode)
|
std::vector<vision::FDMat> images;
|
||||||
.def_readwrite("use_dilation", &vision::ocr::DBDetector::use_dilation)
|
for (size_t i = 0; i < im_list.size(); ++i) {
|
||||||
.def_readwrite("mean", &vision::ocr::DBDetector::mean)
|
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
|
||||||
.def_readwrite("scale", &vision::ocr::DBDetector::scale)
|
}
|
||||||
.def_readwrite("is_scale", &vision::ocr::DBDetector::is_scale);
|
std::vector<FDTensor> outputs;
|
||||||
|
std::vector<std::array<int, 4>> batch_det_img_info;
|
||||||
|
self.Run(&images, &outputs, &batch_det_img_info);
|
||||||
|
for(size_t i = 0; i< outputs.size(); ++i){
|
||||||
|
outputs[i].StopSharing();
|
||||||
|
}
|
||||||
|
return make_pair(outputs, batch_det_img_info);
|
||||||
|
});
|
||||||
|
|
||||||
|
pybind11::class_<vision::ocr::DBDetectorPostprocessor>(m, "DBDetectorPostprocessor")
|
||||||
|
.def(pybind11::init<>())
|
||||||
|
.def_readwrite("det_db_thresh", &vision::ocr::DBDetectorPostprocessor::det_db_thresh_)
|
||||||
|
.def_readwrite("det_db_box_thresh", &vision::ocr::DBDetectorPostprocessor::det_db_box_thresh_)
|
||||||
|
.def_readwrite("det_db_unclip_ratio", &vision::ocr::DBDetectorPostprocessor::det_db_unclip_ratio_)
|
||||||
|
.def_readwrite("det_db_score_mode", &vision::ocr::DBDetectorPostprocessor::det_db_score_mode_)
|
||||||
|
.def_readwrite("use_dilation", &vision::ocr::DBDetectorPostprocessor::use_dilation_)
|
||||||
|
.def("run", [](vision::ocr::DBDetectorPostprocessor& self,
|
||||||
|
std::vector<FDTensor>& inputs,
|
||||||
|
const std::vector<std::array<int, 4>>& batch_det_img_info) {
|
||||||
|
std::vector<std::vector<std::array<int, 8>>> results;
|
||||||
|
|
||||||
|
if (!self.Run(inputs, &results, batch_det_img_info)) {
|
||||||
|
pybind11::eval("raise Exception('Failed to preprocess the input data in DBDetectorPostprocessor.')");
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
})
|
||||||
|
.def("run", [](vision::ocr::DBDetectorPostprocessor& self,
|
||||||
|
std::vector<pybind11::array>& input_array,
|
||||||
|
const std::vector<std::array<int, 4>>& batch_det_img_info) {
|
||||||
|
std::vector<std::vector<std::array<int, 8>>> results;
|
||||||
|
std::vector<FDTensor> inputs;
|
||||||
|
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
|
||||||
|
if (!self.Run(inputs, &results, batch_det_img_info)) {
|
||||||
|
pybind11::eval("raise Exception('Failed to preprocess the input data in DBDetectorPostprocessor.')");
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
});
|
||||||
|
|
||||||
// Classifier
|
// Classifier
|
||||||
pybind11::class_<vision::ocr::Classifier, FastDeployModel>(m, "Classifier")
|
pybind11::class_<vision::ocr::Classifier, FastDeployModel>(m, "Classifier")
|
||||||
.def(pybind11::init<std::string, std::string, RuntimeOption,
|
.def(pybind11::init<std::string, std::string, RuntimeOption,
|
||||||
ModelFormat>())
|
ModelFormat>())
|
||||||
.def(pybind11::init<>())
|
.def(pybind11::init<>())
|
||||||
|
.def_readwrite("preprocessor", &vision::ocr::Classifier::preprocessor_)
|
||||||
|
.def_readwrite("postprocessor", &vision::ocr::Classifier::postprocessor_);
|
||||||
|
|
||||||
|
pybind11::class_<vision::ocr::ClassifierPreprocessor>(m, "ClassifierPreprocessor")
|
||||||
|
.def(pybind11::init<>())
|
||||||
|
.def_readwrite("cls_image_shape", &vision::ocr::ClassifierPreprocessor::cls_image_shape_)
|
||||||
|
.def_readwrite("mean", &vision::ocr::ClassifierPreprocessor::mean_)
|
||||||
|
.def_readwrite("scale", &vision::ocr::ClassifierPreprocessor::scale_)
|
||||||
|
.def_readwrite("is_scale", &vision::ocr::ClassifierPreprocessor::is_scale_)
|
||||||
|
.def("run", [](vision::ocr::ClassifierPreprocessor& self, std::vector<pybind11::array>& im_list) {
|
||||||
|
std::vector<vision::FDMat> images;
|
||||||
|
for (size_t i = 0; i < im_list.size(); ++i) {
|
||||||
|
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
|
||||||
|
}
|
||||||
|
std::vector<FDTensor> outputs;
|
||||||
|
if (!self.Run(&images, &outputs)) {
|
||||||
|
pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPreprocessor.')");
|
||||||
|
}
|
||||||
|
for(size_t i = 0; i< outputs.size(); ++i){
|
||||||
|
outputs[i].StopSharing();
|
||||||
|
}
|
||||||
|
return outputs;
|
||||||
|
});
|
||||||
|
|
||||||
|
pybind11::class_<vision::ocr::ClassifierPostprocessor>(m, "ClassifierPostprocessor")
|
||||||
|
.def(pybind11::init<>())
|
||||||
|
.def_readwrite("cls_thresh", &vision::ocr::ClassifierPostprocessor::cls_thresh_)
|
||||||
|
.def("run", [](vision::ocr::ClassifierPostprocessor& self,
|
||||||
|
std::vector<FDTensor>& inputs) {
|
||||||
|
std::vector<int> cls_labels;
|
||||||
|
std::vector<float> cls_scores;
|
||||||
|
if (!self.Run(inputs, &cls_labels, &cls_scores)) {
|
||||||
|
pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPostprocessor.')");
|
||||||
|
}
|
||||||
|
return make_pair(cls_labels,cls_scores);
|
||||||
|
})
|
||||||
|
.def("run", [](vision::ocr::ClassifierPostprocessor& self,
|
||||||
|
std::vector<pybind11::array>& input_array) {
|
||||||
|
std::vector<FDTensor> inputs;
|
||||||
|
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
|
||||||
|
std::vector<int> cls_labels;
|
||||||
|
std::vector<float> cls_scores;
|
||||||
|
if (!self.Run(inputs, &cls_labels, &cls_scores)) {
|
||||||
|
pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPostprocessor.')");
|
||||||
|
}
|
||||||
|
return make_pair(cls_labels,cls_scores);
|
||||||
|
});
|
||||||
|
|
||||||
.def_readwrite("cls_thresh", &vision::ocr::Classifier::cls_thresh)
|
|
||||||
.def_readwrite("cls_image_shape",
|
|
||||||
&vision::ocr::Classifier::cls_image_shape)
|
|
||||||
.def_readwrite("cls_batch_num", &vision::ocr::Classifier::cls_batch_num);
|
|
||||||
|
|
||||||
// Recognizer
|
// Recognizer
|
||||||
pybind11::class_<vision::ocr::Recognizer, FastDeployModel>(m, "Recognizer")
|
pybind11::class_<vision::ocr::Recognizer, FastDeployModel>(m, "Recognizer")
|
||||||
|
|
||||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||||
ModelFormat>())
|
ModelFormat>())
|
||||||
.def(pybind11::init<>())
|
.def(pybind11::init<>())
|
||||||
|
.def_readwrite("preprocessor", &vision::ocr::Recognizer::preprocessor_)
|
||||||
|
.def_readwrite("postprocessor", &vision::ocr::Recognizer::postprocessor_);
|
||||||
|
|
||||||
.def_readwrite("rec_img_h", &vision::ocr::Recognizer::rec_img_h)
|
pybind11::class_<vision::ocr::RecognizerPreprocessor>(m, "RecognizerPreprocessor")
|
||||||
.def_readwrite("rec_img_w", &vision::ocr::Recognizer::rec_img_w)
|
.def(pybind11::init<>())
|
||||||
.def_readwrite("rec_batch_num", &vision::ocr::Recognizer::rec_batch_num);
|
.def_readwrite("rec_image_shape", &vision::ocr::RecognizerPreprocessor::rec_image_shape_)
|
||||||
|
.def_readwrite("mean", &vision::ocr::RecognizerPreprocessor::mean_)
|
||||||
|
.def_readwrite("scale", &vision::ocr::RecognizerPreprocessor::scale_)
|
||||||
|
.def_readwrite("is_scale", &vision::ocr::RecognizerPreprocessor::is_scale_)
|
||||||
|
.def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector<pybind11::array>& im_list) {
|
||||||
|
std::vector<vision::FDMat> images;
|
||||||
|
for (size_t i = 0; i < im_list.size(); ++i) {
|
||||||
|
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
|
||||||
|
}
|
||||||
|
std::vector<FDTensor> outputs;
|
||||||
|
if (!self.Run(&images, &outputs)) {
|
||||||
|
pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPreprocessor.')");
|
||||||
|
}
|
||||||
|
for(size_t i = 0; i< outputs.size(); ++i){
|
||||||
|
outputs[i].StopSharing();
|
||||||
|
}
|
||||||
|
return outputs;
|
||||||
|
});
|
||||||
|
|
||||||
|
pybind11::class_<vision::ocr::RecognizerPostprocessor>(m, "RecognizerPostprocessor")
|
||||||
|
.def(pybind11::init<std::string>())
|
||||||
|
.def("run", [](vision::ocr::RecognizerPostprocessor& self,
|
||||||
|
std::vector<FDTensor>& inputs) {
|
||||||
|
std::vector<std::string> texts;
|
||||||
|
std::vector<float> rec_scores;
|
||||||
|
if (!self.Run(inputs, &texts, &rec_scores)) {
|
||||||
|
pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPostprocessor.')");
|
||||||
|
}
|
||||||
|
return make_pair(texts, rec_scores);
|
||||||
|
})
|
||||||
|
.def("run", [](vision::ocr::RecognizerPostprocessor& self,
|
||||||
|
std::vector<pybind11::array>& input_array) {
|
||||||
|
std::vector<FDTensor> inputs;
|
||||||
|
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
|
||||||
|
std::vector<std::string> texts;
|
||||||
|
std::vector<float> rec_scores;
|
||||||
|
if (!self.Run(inputs, &texts, &rec_scores)) {
|
||||||
|
pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPostprocessor.')");
|
||||||
|
}
|
||||||
|
return make_pair(texts, rec_scores);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
||||||
|
18
fastdeploy/vision/ocr/ppocr/ppocr_pybind.cc
Normal file → Executable file
18
fastdeploy/vision/ocr/ppocr/ppocr_pybind.cc
Normal file → Executable file
@@ -31,6 +31,15 @@ void BindPPOCRv3(pybind11::module& m) {
|
|||||||
vision::OCRResult res;
|
vision::OCRResult res;
|
||||||
self.Predict(&mat, &res);
|
self.Predict(&mat, &res);
|
||||||
return res;
|
return res;
|
||||||
|
})
|
||||||
|
.def("batch_predict", [](pipeline::PPOCRv3& self, std::vector<pybind11::array>& data) {
|
||||||
|
std::vector<cv::Mat> images;
|
||||||
|
for (size_t i = 0; i < data.size(); ++i) {
|
||||||
|
images.push_back(PyArrayToCvMat(data[i]));
|
||||||
|
}
|
||||||
|
std::vector<vision::OCRResult> results;
|
||||||
|
self.BatchPredict(images, &results);
|
||||||
|
return results;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,6 +58,15 @@ void BindPPOCRv2(pybind11::module& m) {
|
|||||||
vision::OCRResult res;
|
vision::OCRResult res;
|
||||||
self.Predict(&mat, &res);
|
self.Predict(&mat, &res);
|
||||||
return res;
|
return res;
|
||||||
|
})
|
||||||
|
.def("batch_predict", [](pipeline::PPOCRv2& self, std::vector<pybind11::array>& data) {
|
||||||
|
std::vector<cv::Mat> images;
|
||||||
|
for (size_t i = 0; i < data.size(); ++i) {
|
||||||
|
images.push_back(PyArrayToCvMat(data[i]));
|
||||||
|
}
|
||||||
|
std::vector<vision::OCRResult> results;
|
||||||
|
self.BatchPredict(images, &results);
|
||||||
|
return results;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
136
fastdeploy/vision/ocr/ppocr/ppocr_v2.cc
Normal file → Executable file
136
fastdeploy/vision/ocr/ppocr/ppocr_v2.cc
Normal file → Executable file
@@ -22,101 +22,95 @@ PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model,
|
|||||||
fastdeploy::vision::ocr::Classifier* cls_model,
|
fastdeploy::vision::ocr::Classifier* cls_model,
|
||||||
fastdeploy::vision::ocr::Recognizer* rec_model)
|
fastdeploy::vision::ocr::Recognizer* rec_model)
|
||||||
: detector_(det_model), classifier_(cls_model), recognizer_(rec_model) {
|
: detector_(det_model), classifier_(cls_model), recognizer_(rec_model) {
|
||||||
recognizer_->rec_image_shape[1] = 32;
|
Initialized();
|
||||||
|
recognizer_->preprocessor_.rec_image_shape_[1] = 32;
|
||||||
}
|
}
|
||||||
|
|
||||||
PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model,
|
PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model,
|
||||||
fastdeploy::vision::ocr::Recognizer* rec_model)
|
fastdeploy::vision::ocr::Recognizer* rec_model)
|
||||||
: detector_(det_model), recognizer_(rec_model) {
|
: detector_(det_model), recognizer_(rec_model) {
|
||||||
recognizer_->rec_image_shape[1] = 32;
|
Initialized();
|
||||||
|
recognizer_->preprocessor_.rec_image_shape_[1] = 32;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool PPOCRv2::Initialized() const {
|
bool PPOCRv2::Initialized() const {
|
||||||
|
|
||||||
if (detector_ != nullptr && !detector_->Initialized()){
|
if (detector_ != nullptr && !detector_->Initialized()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (classifier_ != nullptr && !classifier_->Initialized()){
|
if (classifier_ != nullptr && !classifier_->Initialized()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (recognizer_ != nullptr && !recognizer_->Initialized()){
|
if (recognizer_ != nullptr && !recognizer_->Initialized()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool PPOCRv2::Detect(cv::Mat* img,
|
|
||||||
fastdeploy::vision::OCRResult* result) {
|
|
||||||
if (!detector_->Predict(img, &(result->boxes))) {
|
|
||||||
FDERROR << "There's error while detecting image in PPOCR." << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
vision::ocr::SortBoxes(result);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool PPOCRv2::Recognize(cv::Mat* img,
|
|
||||||
fastdeploy::vision::OCRResult* result) {
|
|
||||||
std::tuple<std::string, float> rec_result;
|
|
||||||
if (!recognizer_->Predict(img, &rec_result)) {
|
|
||||||
FDERROR << "There's error while recognizing image in PPOCR." << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
result->text.push_back(std::get<0>(rec_result));
|
|
||||||
result->rec_scores.push_back(std::get<1>(rec_result));
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool PPOCRv2::Classify(cv::Mat* img,
|
|
||||||
fastdeploy::vision::OCRResult* result) {
|
|
||||||
std::tuple<int, float> cls_result;
|
|
||||||
|
|
||||||
if (!classifier_->Predict(img, &cls_result)) {
|
|
||||||
FDERROR << "There's error while classifying image in PPOCR." << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
result->cls_labels.push_back(std::get<0>(cls_result));
|
|
||||||
result->cls_scores.push_back(std::get<1>(cls_result));
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool PPOCRv2::Predict(cv::Mat* img,
|
bool PPOCRv2::Predict(cv::Mat* img,
|
||||||
fastdeploy::vision::OCRResult* result) {
|
fastdeploy::vision::OCRResult* result) {
|
||||||
result->Clear();
|
std::vector<fastdeploy::vision::OCRResult> batch_result(1);
|
||||||
if (nullptr != detector_ && !Detect(img, result)) {
|
BatchPredict({*img},&batch_result);
|
||||||
FDERROR << "Failed to detect image." << std::endl;
|
*result = std::move(batch_result[0]);
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get croped images by detection result
|
|
||||||
std::vector<cv::Mat> image_list;
|
|
||||||
for (size_t i = 0; i < result->boxes.size(); ++i) {
|
|
||||||
auto crop_im = vision::ocr::GetRotateCropImage(*img, (result->boxes)[i]);
|
|
||||||
image_list.push_back(crop_im);
|
|
||||||
}
|
|
||||||
if (result->boxes.size() == 0) {
|
|
||||||
image_list.push_back(*img);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < image_list.size(); ++i) {
|
|
||||||
if (nullptr != classifier_ && !Classify(&(image_list[i]), result)) {
|
|
||||||
FDERROR << "Failed to classify croped image of index " << i << "." << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (nullptr != classifier_ && result->cls_labels[i] % 2 == 1 && result->cls_scores[i] > classifier_->cls_thresh) {
|
|
||||||
cv::rotate(image_list[i], image_list[i], 1);
|
|
||||||
}
|
|
||||||
if (nullptr != recognizer_ && !Recognize(&(image_list[i]), result)) {
|
|
||||||
FDERROR << "Failed to recgnize croped image of index " << i << "." << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
bool PPOCRv2::BatchPredict(const std::vector<cv::Mat>& images,
|
||||||
|
std::vector<fastdeploy::vision::OCRResult>* batch_result) {
|
||||||
|
batch_result->clear();
|
||||||
|
batch_result->resize(images.size());
|
||||||
|
std::vector<std::vector<std::array<int, 8>>> batch_boxes(images.size());
|
||||||
|
|
||||||
|
if (!detector_->BatchPredict(images, &batch_boxes)) {
|
||||||
|
FDERROR << "There's error while detecting image in PPOCR." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for(int i_batch = 0; i_batch < batch_boxes.size(); ++i_batch) {
|
||||||
|
vision::ocr::SortBoxes(&(batch_boxes[i_batch]));
|
||||||
|
(*batch_result)[i_batch].boxes = batch_boxes[i_batch];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
for(int i_batch = 0; i_batch < images.size(); ++i_batch) {
|
||||||
|
fastdeploy::vision::OCRResult& ocr_result = (*batch_result)[i_batch];
|
||||||
|
// Get croped images by detection result
|
||||||
|
const std::vector<std::array<int, 8>>& boxes = ocr_result.boxes;
|
||||||
|
const cv::Mat& img = images[i_batch];
|
||||||
|
std::vector<cv::Mat> image_list;
|
||||||
|
if (boxes.size() == 0) {
|
||||||
|
image_list.emplace_back(img);
|
||||||
|
}else{
|
||||||
|
image_list.resize(boxes.size());
|
||||||
|
for (size_t i_box = 0; i_box < boxes.size(); ++i_box) {
|
||||||
|
image_list[i_box] = vision::ocr::GetRotateCropImage(img, boxes[i_box]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::vector<int32_t>* cls_labels_ptr = &ocr_result.cls_labels;
|
||||||
|
std::vector<float>* cls_scores_ptr = &ocr_result.cls_scores;
|
||||||
|
|
||||||
|
std::vector<std::string>* text_ptr = &ocr_result.text;
|
||||||
|
std::vector<float>* rec_scores_ptr = &ocr_result.rec_scores;
|
||||||
|
|
||||||
|
if (!classifier_->BatchPredict(image_list, cls_labels_ptr, cls_scores_ptr)) {
|
||||||
|
FDERROR << "There's error while recognizing image in PPOCR." << std::endl;
|
||||||
|
return false;
|
||||||
|
}else{
|
||||||
|
for (size_t i_img = 0; i_img < image_list.size(); ++i_img) {
|
||||||
|
if(cls_labels_ptr->at(i_img) % 2 == 1 && cls_scores_ptr->at(i_img) > classifier_->postprocessor_.cls_thresh_) {
|
||||||
|
cv::rotate(image_list[i_img], image_list[i_img], 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!recognizer_->BatchPredict(image_list, text_ptr, rec_scores_ptr)) {
|
||||||
|
FDERROR << "There's error while recognizing image in PPOCR." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
} // namesapce pipeline
|
} // namesapce pipeline
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
||||||
|
13
fastdeploy/vision/ocr/ppocr/ppocr_v2.h
Normal file → Executable file
13
fastdeploy/vision/ocr/ppocr/ppocr_v2.h
Normal file → Executable file
@@ -59,6 +59,14 @@ class FASTDEPLOY_DECL PPOCRv2 : public FastDeployModel {
|
|||||||
* \return true if the prediction successed, otherwise false.
|
* \return true if the prediction successed, otherwise false.
|
||||||
*/
|
*/
|
||||||
virtual bool Predict(cv::Mat* img, fastdeploy::vision::OCRResult* result);
|
virtual bool Predict(cv::Mat* img, fastdeploy::vision::OCRResult* result);
|
||||||
|
/** \brief BatchPredict the input image and get OCR result.
|
||||||
|
*
|
||||||
|
* \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
|
||||||
|
* \param[in] batch_result The output list of OCR result will be writen to this structure.
|
||||||
|
* \return true if the prediction successed, otherwise false.
|
||||||
|
*/
|
||||||
|
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
|
||||||
|
std::vector<fastdeploy::vision::OCRResult>* batch_result);
|
||||||
bool Initialized() const override;
|
bool Initialized() const override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@@ -66,11 +74,6 @@ class FASTDEPLOY_DECL PPOCRv2 : public FastDeployModel {
|
|||||||
fastdeploy::vision::ocr::Classifier* classifier_ = nullptr;
|
fastdeploy::vision::ocr::Classifier* classifier_ = nullptr;
|
||||||
fastdeploy::vision::ocr::Recognizer* recognizer_ = nullptr;
|
fastdeploy::vision::ocr::Recognizer* recognizer_ = nullptr;
|
||||||
/// Launch the detection process in OCR.
|
/// Launch the detection process in OCR.
|
||||||
virtual bool Detect(cv::Mat* img, fastdeploy::vision::OCRResult* result);
|
|
||||||
/// Launch the recognition process in OCR.
|
|
||||||
virtual bool Recognize(cv::Mat* img, fastdeploy::vision::OCRResult* result);
|
|
||||||
/// Launch the classification process in OCR.
|
|
||||||
virtual bool Classify(cv::Mat* img, fastdeploy::vision::OCRResult* result);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace application {
|
namespace application {
|
||||||
|
4
fastdeploy/vision/ocr/ppocr/ppocr_v3.h
Normal file → Executable file
4
fastdeploy/vision/ocr/ppocr/ppocr_v3.h
Normal file → Executable file
@@ -36,7 +36,7 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 {
|
|||||||
fastdeploy::vision::ocr::Recognizer* rec_model)
|
fastdeploy::vision::ocr::Recognizer* rec_model)
|
||||||
: PPOCRv2(det_model, cls_model, rec_model) {
|
: PPOCRv2(det_model, cls_model, rec_model) {
|
||||||
// The only difference between v2 and v3
|
// The only difference between v2 and v3
|
||||||
recognizer_->rec_image_shape[1] = 48;
|
recognizer_->preprocessor_.rec_image_shape_[1] = 48;
|
||||||
}
|
}
|
||||||
/** \brief Classification model is optional, so this function is set up the detection model path and recognition model path respectively.
|
/** \brief Classification model is optional, so this function is set up the detection model path and recognition model path respectively.
|
||||||
*
|
*
|
||||||
@@ -47,7 +47,7 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 {
|
|||||||
fastdeploy::vision::ocr::Recognizer* rec_model)
|
fastdeploy::vision::ocr::Recognizer* rec_model)
|
||||||
: PPOCRv2(det_model, rec_model) {
|
: PPOCRv2(det_model, rec_model) {
|
||||||
// The only difference between v2 and v3
|
// The only difference between v2 and v3
|
||||||
recognizer_->rec_image_shape[1] = 48;
|
recognizer_->preprocessor_.rec_image_shape_[1] = 48;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
112
fastdeploy/vision/ocr/ppocr/rec_postprocessor.cc
Normal file
112
fastdeploy/vision/ocr/ppocr/rec_postprocessor.cc
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/rec_postprocessor.h"
|
||||||
|
#include "fastdeploy/utils/perf.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
namespace ocr {
|
||||||
|
|
||||||
|
std::vector<std::string> ReadDict(const std::string& path) {
|
||||||
|
std::ifstream in(path);
|
||||||
|
FDASSERT(in, "Cannot open file %s to read.", path.c_str());
|
||||||
|
std::string line;
|
||||||
|
std::vector<std::string> m_vec;
|
||||||
|
while (getline(in, line)) {
|
||||||
|
m_vec.push_back(line);
|
||||||
|
}
|
||||||
|
m_vec.insert(m_vec.begin(), "#"); // blank char for ctc
|
||||||
|
m_vec.push_back(" ");
|
||||||
|
return m_vec;
|
||||||
|
}
|
||||||
|
|
||||||
|
RecognizerPostprocessor::RecognizerPostprocessor(){
|
||||||
|
initialized_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
RecognizerPostprocessor::RecognizerPostprocessor(const std::string& label_path) {
|
||||||
|
// init label_lsit
|
||||||
|
label_list_ = ReadDict(label_path);
|
||||||
|
initialized_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RecognizerPostprocessor::SingleBatchPostprocessor(const float* out_data,
|
||||||
|
const std::vector<int64_t>& output_shape,
|
||||||
|
std::string* text, float* rec_score) {
|
||||||
|
std::string& str_res = *text;
|
||||||
|
float& score = *rec_score;
|
||||||
|
score = 0.f;
|
||||||
|
int argmax_idx;
|
||||||
|
int last_index = 0;
|
||||||
|
int count = 0;
|
||||||
|
float max_value = 0.0f;
|
||||||
|
|
||||||
|
for (int n = 0; n < output_shape[1]; n++) {
|
||||||
|
argmax_idx = int(
|
||||||
|
std::distance(&out_data[n * output_shape[2]],
|
||||||
|
std::max_element(&out_data[n * output_shape[2]],
|
||||||
|
&out_data[(n + 1) * output_shape[2]])));
|
||||||
|
|
||||||
|
max_value = float(*std::max_element(&out_data[n * output_shape[2]],
|
||||||
|
&out_data[(n + 1) * output_shape[2]]));
|
||||||
|
|
||||||
|
if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
|
||||||
|
score += max_value;
|
||||||
|
count += 1;
|
||||||
|
if(argmax_idx > label_list_.size()) {
|
||||||
|
FDERROR << "The output index: " << argmax_idx << " is larger than the size of label_list: "
|
||||||
|
<< label_list_.size() << ". Please check the label file!" << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
str_res += label_list_[argmax_idx];
|
||||||
|
}
|
||||||
|
last_index = argmax_idx;
|
||||||
|
}
|
||||||
|
score /= (count + 1e-6);
|
||||||
|
if (count == 0 || std::isnan(score)) {
|
||||||
|
score = 0.f;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RecognizerPostprocessor::Run(const std::vector<FDTensor>& tensors,
|
||||||
|
std::vector<std::string>* texts, std::vector<float>* rec_scores) {
|
||||||
|
if (!initialized_) {
|
||||||
|
FDERROR << "Postprocessor is not initialized." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Recognizer have only 1 output tensor.
|
||||||
|
const FDTensor& tensor = tensors[0];
|
||||||
|
// For Recognizer, the output tensor shape = [batch, ?, 6625]
|
||||||
|
size_t batch = tensor.shape[0];
|
||||||
|
size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies<int>());
|
||||||
|
|
||||||
|
texts->resize(batch);
|
||||||
|
rec_scores->resize(batch);
|
||||||
|
const float* tensor_data = reinterpret_cast<const float*>(tensor.Data());
|
||||||
|
for (int i_batch = 0; i_batch < batch; ++i_batch) {
|
||||||
|
if(!SingleBatchPostprocessor(tensor_data, tensor.shape, &texts->at(i_batch), &rec_scores->at(i_batch))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
tensor_data = tensor_data + length;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ocr
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
55
fastdeploy/vision/ocr/ppocr/rec_postprocessor.h
Normal file
55
fastdeploy/vision/ocr/ppocr/rec_postprocessor.h
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "fastdeploy/vision/common/processors/transform.h"
|
||||||
|
#include "fastdeploy/vision/common/result.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
|
||||||
|
namespace ocr {
|
||||||
|
/*! @brief Postprocessor object for Recognizer serials model.
|
||||||
|
*/
|
||||||
|
class FASTDEPLOY_DECL RecognizerPostprocessor {
|
||||||
|
public:
|
||||||
|
RecognizerPostprocessor();
|
||||||
|
/** \brief Create a postprocessor instance for Recognizer serials model
|
||||||
|
*
|
||||||
|
* \param[in] label_path The path of label_dict
|
||||||
|
*/
|
||||||
|
explicit RecognizerPostprocessor(const std::string& label_path);
|
||||||
|
|
||||||
|
/** \brief Process the result of runtime and fill to ClassifyResult structure
|
||||||
|
*
|
||||||
|
* \param[in] tensors The inference result from runtime
|
||||||
|
* \param[in] texts The output result of recognizer
|
||||||
|
* \param[in] rec_scores The output result of recognizer
|
||||||
|
* \return true if the postprocess successed, otherwise false
|
||||||
|
*/
|
||||||
|
bool Run(const std::vector<FDTensor>& tensors,
|
||||||
|
std::vector<std::string>* texts, std::vector<float>* rec_scores);
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool SingleBatchPostprocessor(const float* out_data,
|
||||||
|
const std::vector<int64_t>& output_shape,
|
||||||
|
std::string* text, float* rec_score);
|
||||||
|
bool initialized_ = false;
|
||||||
|
std::vector<std::string> label_list_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ocr
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
99
fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc
Normal file
99
fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/rec_preprocessor.h"
|
||||||
|
#include "fastdeploy/utils/perf.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h"
|
||||||
|
#include "fastdeploy/function/concat.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
namespace ocr {
|
||||||
|
|
||||||
|
RecognizerPreprocessor::RecognizerPreprocessor() {
|
||||||
|
initialized_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio,
|
||||||
|
const std::vector<int>& rec_image_shape) {
|
||||||
|
int imgC, imgH, imgW;
|
||||||
|
imgC = rec_image_shape[0];
|
||||||
|
imgH = rec_image_shape[1];
|
||||||
|
imgW = rec_image_shape[2];
|
||||||
|
|
||||||
|
imgW = int(imgH * max_wh_ratio);
|
||||||
|
|
||||||
|
float ratio = float(mat->Width()) / float(mat->Height());
|
||||||
|
int resize_w;
|
||||||
|
if (ceilf(imgH * ratio) > imgW) {
|
||||||
|
resize_w = imgW;
|
||||||
|
}else{
|
||||||
|
resize_w = int(ceilf(imgH * ratio));
|
||||||
|
}
|
||||||
|
Resize::Run(mat, resize_w, imgH);
|
||||||
|
|
||||||
|
std::vector<float> value = {0, 0, 0};
|
||||||
|
Pad::Run(mat, 0, 0, 0, int(imgW - mat->Width()), value);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RecognizerPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
|
||||||
|
if (!initialized_) {
|
||||||
|
FDERROR << "The preprocessor is not initialized." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (images->size() == 0) {
|
||||||
|
FDERROR << "The size of input images should be greater than 0." << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int imgH = rec_image_shape_[1];
|
||||||
|
int imgW = rec_image_shape_[2];
|
||||||
|
float max_wh_ratio = imgW * 1.0 / imgH;
|
||||||
|
float ori_wh_ratio;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < images->size(); ++i) {
|
||||||
|
FDMat* mat = &(images->at(i));
|
||||||
|
ori_wh_ratio = mat->Width() * 1.0 / mat->Height();
|
||||||
|
max_wh_ratio = std::max(max_wh_ratio, ori_wh_ratio);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < images->size(); ++i) {
|
||||||
|
FDMat* mat = &(images->at(i));
|
||||||
|
OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_);
|
||||||
|
NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_);
|
||||||
|
/*
|
||||||
|
Normalize::Run(mat, mean_, scale_, is_scale_);
|
||||||
|
HWC2CHW::Run(mat);
|
||||||
|
Cast::Run(mat, "float");
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
// Only have 1 output Tensor.
|
||||||
|
outputs->resize(1);
|
||||||
|
// Concat all the preprocessed data to a batch tensor
|
||||||
|
std::vector<FDTensor> tensors(images->size());
|
||||||
|
for (size_t i = 0; i < images->size(); ++i) {
|
||||||
|
(*images)[i].ShareWithTensor(&(tensors[i]));
|
||||||
|
tensors[i].ExpandDim(0);
|
||||||
|
}
|
||||||
|
if (tensors.size() == 1) {
|
||||||
|
(*outputs)[0] = std::move(tensors[0]);
|
||||||
|
} else {
|
||||||
|
function::Concat(tensors, &((*outputs)[0]), 0);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ocr
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
52
fastdeploy/vision/ocr/ppocr/rec_preprocessor.h
Normal file
52
fastdeploy/vision/ocr/ppocr/rec_preprocessor.h
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "fastdeploy/vision/common/processors/transform.h"
|
||||||
|
#include "fastdeploy/vision/common/result.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace vision {
|
||||||
|
|
||||||
|
namespace ocr {
|
||||||
|
/*! @brief Preprocessor object for PaddleClas serials model.
|
||||||
|
*/
|
||||||
|
class FASTDEPLOY_DECL RecognizerPreprocessor {
|
||||||
|
public:
|
||||||
|
/** \brief Create a preprocessor instance for PaddleClas serials model
|
||||||
|
*
|
||||||
|
* \param[in] config_file Path of configuration file for deployment, e.g resnet/infer_cfg.yml
|
||||||
|
*/
|
||||||
|
RecognizerPreprocessor();
|
||||||
|
|
||||||
|
/** \brief Process the input image and prepare input tensors for runtime
|
||||||
|
*
|
||||||
|
* \param[in] images The input image data list, all the elements are returned by cv::imread()
|
||||||
|
* \param[in] outputs The output tensors which will feed in runtime
|
||||||
|
* \return true if the preprocess successed, otherwise false
|
||||||
|
*/
|
||||||
|
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
|
||||||
|
|
||||||
|
std::vector<int> rec_image_shape_ = {3, 48, 320};
|
||||||
|
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
|
||||||
|
std::vector<float> scale_ = {0.5f, 0.5f, 0.5f};
|
||||||
|
bool is_scale_ = true;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool initialized_ = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ocr
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace fastdeploy
|
152
fastdeploy/vision/ocr/ppocr/recognizer.cc
Normal file → Executable file
152
fastdeploy/vision/ocr/ppocr/recognizer.cc
Normal file → Executable file
@@ -20,29 +20,13 @@ namespace fastdeploy {
|
|||||||
namespace vision {
|
namespace vision {
|
||||||
namespace ocr {
|
namespace ocr {
|
||||||
|
|
||||||
std::vector<std::string> ReadDict(const std::string& path) {
|
|
||||||
std::ifstream in(path);
|
|
||||||
std::string line;
|
|
||||||
std::vector<std::string> m_vec;
|
|
||||||
if (in) {
|
|
||||||
while (getline(in, line)) {
|
|
||||||
m_vec.push_back(line);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
std::cout << "no such label file: " << path << ", exit the program..."
|
|
||||||
<< std::endl;
|
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
return m_vec;
|
|
||||||
}
|
|
||||||
|
|
||||||
Recognizer::Recognizer() {}
|
Recognizer::Recognizer() {}
|
||||||
|
|
||||||
Recognizer::Recognizer(const std::string& model_file,
|
Recognizer::Recognizer(const std::string& model_file,
|
||||||
const std::string& params_file,
|
const std::string& params_file,
|
||||||
const std::string& label_path,
|
const std::string& label_path,
|
||||||
const RuntimeOption& custom_option,
|
const RuntimeOption& custom_option,
|
||||||
const ModelFormat& model_format) {
|
const ModelFormat& model_format):postprocessor_(label_path) {
|
||||||
if (model_format == ModelFormat::ONNX) {
|
if (model_format == ModelFormat::ONNX) {
|
||||||
valid_cpu_backends = {Backend::ORT,
|
valid_cpu_backends = {Backend::ORT,
|
||||||
Backend::OPENVINO};
|
Backend::OPENVINO};
|
||||||
@@ -56,27 +40,11 @@ Recognizer::Recognizer(const std::string& model_file,
|
|||||||
runtime_option.model_format = model_format;
|
runtime_option.model_format = model_format;
|
||||||
runtime_option.model_file = model_file;
|
runtime_option.model_file = model_file;
|
||||||
runtime_option.params_file = params_file;
|
runtime_option.params_file = params_file;
|
||||||
|
|
||||||
initialized = Initialize();
|
initialized = Initialize();
|
||||||
|
|
||||||
// init label_lsit
|
|
||||||
label_list = ReadDict(label_path);
|
|
||||||
label_list.insert(label_list.begin(), "#"); // blank char for ctc
|
|
||||||
label_list.push_back(" ");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init
|
// Init
|
||||||
bool Recognizer::Initialize() {
|
bool Recognizer::Initialize() {
|
||||||
// pre&post process parameters
|
|
||||||
rec_batch_num = 1;
|
|
||||||
rec_img_h = 48;
|
|
||||||
rec_img_w = 320;
|
|
||||||
rec_image_shape = {3, rec_img_h, rec_img_w};
|
|
||||||
|
|
||||||
mean = {0.5f, 0.5f, 0.5f};
|
|
||||||
scale = {0.5f, 0.5f, 0.5f};
|
|
||||||
is_scale = true;
|
|
||||||
|
|
||||||
if (!InitRuntime()) {
|
if (!InitRuntime()) {
|
||||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||||
return false;
|
return false;
|
||||||
@@ -85,119 +53,23 @@ bool Recognizer::Initialize() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void OcrRecognizerResizeImage(Mat* mat, const float& wh_ratio,
|
bool Recognizer::BatchPredict(const std::vector<cv::Mat>& images,
|
||||||
const std::vector<int>& rec_image_shape) {
|
std::vector<std::string>* texts, std::vector<float>* rec_scores) {
|
||||||
int imgC, imgH, imgW;
|
std::vector<FDMat> fd_images = WrapMat(images);
|
||||||
imgC = rec_image_shape[0];
|
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
|
||||||
imgH = rec_image_shape[1];
|
FDERROR << "Failed to preprocess the input image." << std::endl;
|
||||||
imgW = rec_image_shape[2];
|
return false;
|
||||||
|
|
||||||
imgW = int(imgH * wh_ratio);
|
|
||||||
|
|
||||||
float ratio = float(mat->Width()) / float(mat->Height());
|
|
||||||
int resize_w;
|
|
||||||
if (ceilf(imgH * ratio) > imgW)
|
|
||||||
resize_w = imgW;
|
|
||||||
else
|
|
||||||
resize_w = int(ceilf(imgH * ratio));
|
|
||||||
|
|
||||||
Resize::Run(mat, resize_w, imgH);
|
|
||||||
|
|
||||||
std::vector<float> value = {127, 127, 127};
|
|
||||||
Pad::Run(mat, 0, 0, 0, int(imgW - mat->Width()), value);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Recognizer::Preprocess(Mat* mat, FDTensor* output,
|
|
||||||
const std::vector<int>& rec_image_shape) {
|
|
||||||
int imgH = rec_image_shape[1];
|
|
||||||
int imgW = rec_image_shape[2];
|
|
||||||
float wh_ratio = imgW * 1.0 / imgH;
|
|
||||||
|
|
||||||
float ori_wh_ratio = mat->Width() * 1.0 / mat->Height();
|
|
||||||
wh_ratio = std::max(wh_ratio, ori_wh_ratio);
|
|
||||||
|
|
||||||
OcrRecognizerResizeImage(mat, wh_ratio, rec_image_shape);
|
|
||||||
|
|
||||||
Normalize::Run(mat, mean, scale, true);
|
|
||||||
|
|
||||||
HWC2CHW::Run(mat);
|
|
||||||
Cast::Run(mat, "float");
|
|
||||||
|
|
||||||
mat->ShareWithTensor(output);
|
|
||||||
output->shape.insert(output->shape.begin(), 1);
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Recognizer::Postprocess(FDTensor& infer_result,
|
|
||||||
std::tuple<std::string, float>* rec_result) {
|
|
||||||
std::vector<int64_t> output_shape = infer_result.shape;
|
|
||||||
FDASSERT(output_shape[0] == 1, "Only support batch =1 now.");
|
|
||||||
|
|
||||||
float* out_data = static_cast<float*>(infer_result.Data());
|
|
||||||
|
|
||||||
std::string str_res;
|
|
||||||
int argmax_idx;
|
|
||||||
int last_index = 0;
|
|
||||||
float score = 0.f;
|
|
||||||
int count = 0;
|
|
||||||
float max_value = 0.0f;
|
|
||||||
|
|
||||||
for (int n = 0; n < output_shape[1]; n++) {
|
|
||||||
argmax_idx = int(
|
|
||||||
std::distance(&out_data[n * output_shape[2]],
|
|
||||||
std::max_element(&out_data[n * output_shape[2]],
|
|
||||||
&out_data[(n + 1) * output_shape[2]])));
|
|
||||||
|
|
||||||
max_value = float(*std::max_element(&out_data[n * output_shape[2]],
|
|
||||||
&out_data[(n + 1) * output_shape[2]]));
|
|
||||||
|
|
||||||
if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
|
|
||||||
score += max_value;
|
|
||||||
count += 1;
|
|
||||||
if(argmax_idx > label_list.size()){
|
|
||||||
FDERROR << "The output index: " << argmax_idx << " is larger than the size of label_list: "
|
|
||||||
<< label_list.size() << ". Please check the label file!" << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
str_res += label_list[argmax_idx];
|
|
||||||
}
|
|
||||||
last_index = argmax_idx;
|
|
||||||
}
|
}
|
||||||
score /= (count + 1e-6);
|
reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
|
||||||
if (count == 0 || std::isnan(score)) {
|
if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
|
||||||
score = 0.f;
|
FDERROR << "Failed to inference by runtime." << std::endl;
|
||||||
}
|
|
||||||
std::get<0>(*rec_result) = str_res;
|
|
||||||
std::get<1>(*rec_result) = score;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Recognizer::Predict(cv::Mat* img,
|
|
||||||
std::tuple<std::string, float>* rec_result) {
|
|
||||||
Mat mat(*img);
|
|
||||||
|
|
||||||
std::vector<FDTensor> input_tensors(1);
|
|
||||||
|
|
||||||
if (!Preprocess(&mat, &input_tensors[0], rec_image_shape)) {
|
|
||||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
if (!postprocessor_.Run(reused_output_tensors_, texts, rec_scores)) {
|
||||||
std::vector<FDTensor> output_tensors;
|
FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl;
|
||||||
|
|
||||||
if (!Infer(input_tensors, &output_tensors)) {
|
|
||||||
FDERROR << "Failed to inference." << std::endl;
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!Postprocess(output_tensors[0], rec_result)) {
|
|
||||||
FDERROR << "Failed to post process." << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
31
fastdeploy/vision/ocr/ppocr/recognizer.h
Normal file → Executable file
31
fastdeploy/vision/ocr/ppocr/recognizer.h
Normal file → Executable file
@@ -17,6 +17,8 @@
|
|||||||
#include "fastdeploy/vision/common/processors/transform.h"
|
#include "fastdeploy/vision/common/processors/transform.h"
|
||||||
#include "fastdeploy/vision/common/result.h"
|
#include "fastdeploy/vision/common/result.h"
|
||||||
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
|
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/rec_preprocessor.h"
|
||||||
|
#include "fastdeploy/vision/ocr/ppocr/rec_postprocessor.h"
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
@@ -43,35 +45,20 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
|
|||||||
const ModelFormat& model_format = ModelFormat::PADDLE);
|
const ModelFormat& model_format = ModelFormat::PADDLE);
|
||||||
/// Get model's name
|
/// Get model's name
|
||||||
std::string ModelName() const { return "ppocr/ocr_rec"; }
|
std::string ModelName() const { return "ppocr/ocr_rec"; }
|
||||||
/** \brief Predict the input image and get OCR recognition model result.
|
/** \brief BatchPredict the input image and get OCR recognition model result.
|
||||||
*
|
*
|
||||||
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
|
* \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
|
||||||
* \param[in] rec_result The output of OCR recognition model result will be writen to this structure.
|
* \param[in] rec_results The output of OCR recognition model result will be writen to this structure.
|
||||||
* \return true if the prediction is successed, otherwise false.
|
* \return true if the prediction is successed, otherwise false.
|
||||||
*/
|
*/
|
||||||
virtual bool Predict(cv::Mat* img,
|
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
|
||||||
std::tuple<std::string, float>* rec_result);
|
std::vector<std::string>* texts, std::vector<float>* rec_scores);
|
||||||
|
|
||||||
// Pre & Post parameters
|
RecognizerPreprocessor preprocessor_;
|
||||||
std::vector<std::string> label_list;
|
RecognizerPostprocessor postprocessor_;
|
||||||
int rec_batch_num;
|
|
||||||
int rec_img_h;
|
|
||||||
int rec_img_w;
|
|
||||||
std::vector<int> rec_image_shape;
|
|
||||||
|
|
||||||
std::vector<float> mean;
|
|
||||||
std::vector<float> scale;
|
|
||||||
bool is_scale;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool Initialize();
|
bool Initialize();
|
||||||
/// Preprocess the input data, and set the preprocessed results to `outputs`
|
|
||||||
bool Preprocess(Mat* img, FDTensor* outputs,
|
|
||||||
const std::vector<int>& rec_image_shape);
|
|
||||||
/*! @brief Postprocess the inferenced results, and set the final result to `rec_result`
|
|
||||||
*/
|
|
||||||
bool Postprocess(FDTensor& infer_result,
|
|
||||||
std::tuple<std::string, float>* rec_result);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ocr
|
} // namespace ocr
|
||||||
|
10
fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.cc
Normal file → Executable file
10
fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.cc
Normal file → Executable file
@@ -318,10 +318,12 @@ std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<std::vector<int>>> PostProcessor::FilterTagDetRes(
|
std::vector<std::vector<std::vector<int>>> PostProcessor::FilterTagDetRes(
|
||||||
std::vector<std::vector<std::vector<int>>> boxes, float ratio_h,
|
std::vector<std::vector<std::vector<int>>> boxes,
|
||||||
float ratio_w, const std::map<std::string, std::array<float, 2>> &im_info) {
|
const std::array<int,4>& det_img_info) {
|
||||||
int oriimg_h = im_info.at("input_shape")[0];
|
int oriimg_w = det_img_info[0];
|
||||||
int oriimg_w = im_info.at("input_shape")[1];
|
int oriimg_h = det_img_info[1];
|
||||||
|
float ratio_w = float(det_img_info[2])/float(oriimg_w);
|
||||||
|
float ratio_h = float(det_img_info[3])/float(oriimg_h);
|
||||||
|
|
||||||
std::vector<std::vector<std::vector<int>>> root_points;
|
std::vector<std::vector<std::vector<int>>> root_points;
|
||||||
for (int n = 0; n < boxes.size(); n++) {
|
for (int n = 0; n < boxes.size(); n++) {
|
||||||
|
@@ -57,9 +57,8 @@ class PostProcessor {
|
|||||||
const float &det_db_unclip_ratio, const std::string &det_db_score_mode);
|
const float &det_db_unclip_ratio, const std::string &det_db_score_mode);
|
||||||
|
|
||||||
std::vector<std::vector<std::vector<int>>> FilterTagDetRes(
|
std::vector<std::vector<std::vector<int>>> FilterTagDetRes(
|
||||||
std::vector<std::vector<std::vector<int>>> boxes, float ratio_h,
|
std::vector<std::vector<std::vector<int>>> boxes,
|
||||||
float ratio_w,
|
const std::array<int, 4>& det_img_info);
|
||||||
const std::map<std::string, std::array<float, 2>> &im_info);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static bool XsortInt(std::vector<int> a, std::vector<int> b);
|
static bool XsortInt(std::vector<int> a, std::vector<int> b);
|
||||||
|
@@ -28,10 +28,10 @@ namespace fastdeploy {
|
|||||||
namespace vision {
|
namespace vision {
|
||||||
namespace ocr {
|
namespace ocr {
|
||||||
|
|
||||||
cv::Mat GetRotateCropImage(const cv::Mat& srcimage,
|
FASTDEPLOY_DECL cv::Mat GetRotateCropImage(const cv::Mat& srcimage,
|
||||||
const std::array<int, 8>& box);
|
const std::array<int, 8>& box);
|
||||||
|
|
||||||
void SortBoxes(OCRResult* result);
|
FASTDEPLOY_DECL void SortBoxes(std::vector<std::array<int, 8>>* boxes);
|
||||||
|
|
||||||
} // namespace ocr
|
} // namespace ocr
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
|
@@ -29,17 +29,17 @@ bool CompareBox(const std::array<int, 8>& result1,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SortBoxes(OCRResult* result) {
|
void SortBoxes(std::vector<std::array<int, 8>>* boxes) {
|
||||||
std::sort(result->boxes.begin(), result->boxes.end(), CompareBox);
|
std::sort(boxes->begin(), boxes->end(), CompareBox);
|
||||||
|
|
||||||
if (result->boxes.size() == 0) {
|
if (boxes->size() == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < result->boxes.size() - 1; i++) {
|
for (int i = 0; i < boxes->size() - 1; i++) {
|
||||||
if (abs(result->boxes[i + 1][1] - result->boxes[i][1]) < 10 &&
|
if (abs((*boxes)[i + 1][1] - (*boxes)[i][1]) < 10 &&
|
||||||
(result->boxes[i + 1][0] < result->boxes[i][0])) {
|
((*boxes)[i + 1][0] < (*boxes)[i][0])) {
|
||||||
std::swap(result->boxes[i], result->boxes[i + 1]);
|
std::swap((*boxes)[i], (*boxes)[i + 1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
8
python/fastdeploy/vision/ocr/__init__.py
Normal file → Executable file
8
python/fastdeploy/vision/ocr/__init__.py
Normal file → Executable file
@@ -13,10 +13,4 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
from .ppocr import PPOCRv3
|
from .ppocr import *
|
||||||
from .ppocr import PPOCRv2
|
|
||||||
from .ppocr import PPOCRSystemv3
|
|
||||||
from .ppocr import PPOCRSystemv2
|
|
||||||
from .ppocr import DBDetector
|
|
||||||
from .ppocr import Classifier
|
|
||||||
from .ppocr import Recognizer
|
|
||||||
|
172
python/fastdeploy/vision/ocr/ppocr/__init__.py
Normal file → Executable file
172
python/fastdeploy/vision/ocr/ppocr/__init__.py
Normal file → Executable file
@@ -41,40 +41,11 @@ class DBDetector(FastDeployModel):
|
|||||||
assert self.initialized, "DBDetector initialize failed."
|
assert self.initialized, "DBDetector initialize failed."
|
||||||
|
|
||||||
# 一些跟DBDetector模型有关的属性封装
|
# 一些跟DBDetector模型有关的属性封装
|
||||||
@property
|
'''
|
||||||
def max_side_len(self):
|
|
||||||
return self._model.max_side_len
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def det_db_thresh(self):
|
def det_db_thresh(self):
|
||||||
return self._model.det_db_thresh
|
return self._model.det_db_thresh
|
||||||
|
|
||||||
@property
|
|
||||||
def det_db_box_thresh(self):
|
|
||||||
return self._model.det_db_box_thresh
|
|
||||||
|
|
||||||
@property
|
|
||||||
def det_db_unclip_ratio(self):
|
|
||||||
return self._model.det_db_unclip_ratio
|
|
||||||
|
|
||||||
@property
|
|
||||||
def det_db_score_mode(self):
|
|
||||||
return self._model.det_db_score_mode
|
|
||||||
|
|
||||||
@property
|
|
||||||
def use_dilation(self):
|
|
||||||
return self._model.use_dilation
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_scale(self):
|
|
||||||
return self._model.max_wh
|
|
||||||
|
|
||||||
@max_side_len.setter
|
|
||||||
def max_side_len(self, value):
|
|
||||||
assert isinstance(
|
|
||||||
value, int), "The value to set `max_side_len` must be type of int."
|
|
||||||
self._model.max_side_len = value
|
|
||||||
|
|
||||||
@det_db_thresh.setter
|
@det_db_thresh.setter
|
||||||
def det_db_thresh(self, value):
|
def det_db_thresh(self, value):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
@@ -82,6 +53,10 @@ class DBDetector(FastDeployModel):
|
|||||||
float), "The value to set `det_db_thresh` must be type of float."
|
float), "The value to set `det_db_thresh` must be type of float."
|
||||||
self._model.det_db_thresh = value
|
self._model.det_db_thresh = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def det_db_box_thresh(self):
|
||||||
|
return self._model.det_db_box_thresh
|
||||||
|
|
||||||
@det_db_box_thresh.setter
|
@det_db_box_thresh.setter
|
||||||
def det_db_box_thresh(self, value):
|
def det_db_box_thresh(self, value):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
@@ -89,6 +64,10 @@ class DBDetector(FastDeployModel):
|
|||||||
), "The value to set `det_db_box_thresh` must be type of float."
|
), "The value to set `det_db_box_thresh` must be type of float."
|
||||||
self._model.det_db_box_thresh = value
|
self._model.det_db_box_thresh = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def det_db_unclip_ratio(self):
|
||||||
|
return self._model.det_db_unclip_ratio
|
||||||
|
|
||||||
@det_db_unclip_ratio.setter
|
@det_db_unclip_ratio.setter
|
||||||
def det_db_unclip_ratio(self, value):
|
def det_db_unclip_ratio(self, value):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
@@ -96,6 +75,10 @@ class DBDetector(FastDeployModel):
|
|||||||
), "The value to set `det_db_unclip_ratio` must be type of float."
|
), "The value to set `det_db_unclip_ratio` must be type of float."
|
||||||
self._model.det_db_unclip_ratio = value
|
self._model.det_db_unclip_ratio = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def det_db_score_mode(self):
|
||||||
|
return self._model.det_db_score_mode
|
||||||
|
|
||||||
@det_db_score_mode.setter
|
@det_db_score_mode.setter
|
||||||
def det_db_score_mode(self, value):
|
def det_db_score_mode(self, value):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
@@ -103,6 +86,10 @@ class DBDetector(FastDeployModel):
|
|||||||
str), "The value to set `det_db_score_mode` must be type of str."
|
str), "The value to set `det_db_score_mode` must be type of str."
|
||||||
self._model.det_db_score_mode = value
|
self._model.det_db_score_mode = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_dilation(self):
|
||||||
|
return self._model.use_dilation
|
||||||
|
|
||||||
@use_dilation.setter
|
@use_dilation.setter
|
||||||
def use_dilation(self, value):
|
def use_dilation(self, value):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
@@ -110,11 +97,26 @@ class DBDetector(FastDeployModel):
|
|||||||
bool), "The value to set `use_dilation` must be type of bool."
|
bool), "The value to set `use_dilation` must be type of bool."
|
||||||
self._model.use_dilation = value
|
self._model.use_dilation = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_side_len(self):
|
||||||
|
return self._model.max_side_len
|
||||||
|
|
||||||
|
@max_side_len.setter
|
||||||
|
def max_side_len(self, value):
|
||||||
|
assert isinstance(
|
||||||
|
value, int), "The value to set `max_side_len` must be type of int."
|
||||||
|
self._model.max_side_len = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_scale(self):
|
||||||
|
return self._model.max_wh
|
||||||
|
|
||||||
@is_scale.setter
|
@is_scale.setter
|
||||||
def is_scale(self, value):
|
def is_scale(self, value):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
value, bool), "The value to set `is_scale` must be type of bool."
|
value, bool), "The value to set `is_scale` must be type of bool."
|
||||||
self._model.is_scale = value
|
self._model.is_scale = value
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
class Classifier(FastDeployModel):
|
class Classifier(FastDeployModel):
|
||||||
@@ -139,6 +141,7 @@ class Classifier(FastDeployModel):
|
|||||||
model_file, params_file, self._runtime_option, model_format)
|
model_file, params_file, self._runtime_option, model_format)
|
||||||
assert self.initialized, "Classifier initialize failed."
|
assert self.initialized, "Classifier initialize failed."
|
||||||
|
|
||||||
|
'''
|
||||||
@property
|
@property
|
||||||
def cls_thresh(self):
|
def cls_thresh(self):
|
||||||
return self._model.cls_thresh
|
return self._model.cls_thresh
|
||||||
@@ -170,6 +173,7 @@ class Classifier(FastDeployModel):
|
|||||||
value,
|
value,
|
||||||
int), "The value to set `cls_batch_num` must be type of int."
|
int), "The value to set `cls_batch_num` must be type of int."
|
||||||
self._model.cls_batch_num = value
|
self._model.cls_batch_num = value
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
class Recognizer(FastDeployModel):
|
class Recognizer(FastDeployModel):
|
||||||
@@ -197,6 +201,7 @@ class Recognizer(FastDeployModel):
|
|||||||
model_format)
|
model_format)
|
||||||
assert self.initialized, "Recognizer initialize failed."
|
assert self.initialized, "Recognizer initialize failed."
|
||||||
|
|
||||||
|
'''
|
||||||
@property
|
@property
|
||||||
def rec_img_h(self):
|
def rec_img_h(self):
|
||||||
return self._model.rec_img_h
|
return self._model.rec_img_h
|
||||||
@@ -227,6 +232,7 @@ class Recognizer(FastDeployModel):
|
|||||||
value,
|
value,
|
||||||
int), "The value to set `rec_batch_num` must be type of int."
|
int), "The value to set `rec_batch_num` must be type of int."
|
||||||
self._model.rec_batch_num = value
|
self._model.rec_batch_num = value
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
class PPOCRv3(FastDeployModel):
|
class PPOCRv3(FastDeployModel):
|
||||||
@@ -253,6 +259,14 @@ class PPOCRv3(FastDeployModel):
|
|||||||
"""
|
"""
|
||||||
return self.system.predict(input_image)
|
return self.system.predict(input_image)
|
||||||
|
|
||||||
|
def batch_predict(self, images):
|
||||||
|
"""Predict a batch of input image
|
||||||
|
:param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
|
||||||
|
:return: OCRBatchResult
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self.system.batch_predict(images)
|
||||||
|
|
||||||
|
|
||||||
class PPOCRSystemv3(PPOCRv3):
|
class PPOCRSystemv3(PPOCRv3):
|
||||||
def __init__(self, det_model=None, cls_model=None, rec_model=None):
|
def __init__(self, det_model=None, cls_model=None, rec_model=None):
|
||||||
@@ -289,6 +303,14 @@ class PPOCRv2(FastDeployModel):
|
|||||||
"""
|
"""
|
||||||
return self.system.predict(input_image)
|
return self.system.predict(input_image)
|
||||||
|
|
||||||
|
def batch_predict(self, images):
|
||||||
|
"""Predict a batch of input image
|
||||||
|
:param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
|
||||||
|
:return: OCRBatchResult
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self.system.batch_predict(images)
|
||||||
|
|
||||||
|
|
||||||
class PPOCRSystemv2(PPOCRv2):
|
class PPOCRSystemv2(PPOCRv2):
|
||||||
def __init__(self, det_model=None, cls_model=None, rec_model=None):
|
def __init__(self, det_model=None, cls_model=None, rec_model=None):
|
||||||
@@ -299,3 +321,93 @@ class PPOCRSystemv2(PPOCRv2):
|
|||||||
|
|
||||||
def predict(self, input_image):
|
def predict(self, input_image):
|
||||||
return super(PPOCRSystemv2, self).predict(input_image)
|
return super(PPOCRSystemv2, self).predict(input_image)
|
||||||
|
|
||||||
|
|
||||||
|
class DBDetectorPreprocessor:
|
||||||
|
def __init__(self):
|
||||||
|
"""Create a preprocessor for DBDetectorModel
|
||||||
|
"""
|
||||||
|
self._preprocessor = C.vision.ocr.DBDetectorPreprocessor()
|
||||||
|
|
||||||
|
def run(self, input_ims):
|
||||||
|
"""Preprocess input images for DBDetectorModel
|
||||||
|
:param: input_ims: (list of numpy.ndarray) The input image
|
||||||
|
:return: pair(list of FDTensor, list of std::array<int, 4>)
|
||||||
|
"""
|
||||||
|
return self._preprocessor.run(input_ims)
|
||||||
|
|
||||||
|
|
||||||
|
class DBDetectorPostprocessor:
|
||||||
|
def __init__(self):
|
||||||
|
"""Create a postprocessor for DBDetectorModel
|
||||||
|
"""
|
||||||
|
self._postprocessor = C.vision.ocr.DBDetectorPostprocessor()
|
||||||
|
|
||||||
|
def run(self, runtime_results, batch_det_img_info):
|
||||||
|
"""Postprocess the runtime results for DBDetectorModel
|
||||||
|
:param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime
|
||||||
|
:param: batch_det_img_info: (list of std::array<int, 4>)The output of det_preprocessor
|
||||||
|
:return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
|
||||||
|
"""
|
||||||
|
return self._postprocessor.run(runtime_results, batch_det_img_info)
|
||||||
|
|
||||||
|
|
||||||
|
class RecognizerPreprocessor:
|
||||||
|
def __init__(self):
|
||||||
|
"""Create a preprocessor for RecognizerModel
|
||||||
|
"""
|
||||||
|
self._preprocessor = C.vision.ocr.RecognizerPreprocessor()
|
||||||
|
|
||||||
|
def run(self, input_ims):
|
||||||
|
"""Preprocess input images for RecognizerModel
|
||||||
|
:param: input_ims: (list of numpy.ndarray)The input image
|
||||||
|
:return: list of FDTensor
|
||||||
|
"""
|
||||||
|
return self._preprocessor.run(input_ims)
|
||||||
|
|
||||||
|
|
||||||
|
class RecognizerPostprocessor:
|
||||||
|
def __init__(self, label_path):
|
||||||
|
"""Create a postprocessor for RecognizerModel
|
||||||
|
:param label_path: (str)Path of label file
|
||||||
|
"""
|
||||||
|
self._postprocessor = C.vision.ocr.RecognizerPostprocessor(label_path)
|
||||||
|
|
||||||
|
def run(self, runtime_results):
|
||||||
|
"""Postprocess the runtime results for RecognizerModel
|
||||||
|
:param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime
|
||||||
|
:return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
|
||||||
|
"""
|
||||||
|
return self._postprocessor.run(runtime_results)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifierPreprocessor:
|
||||||
|
def __init__(self):
|
||||||
|
"""Create a preprocessor for ClassifierModel
|
||||||
|
"""
|
||||||
|
self._preprocessor = C.vision.ocr.ClassifierPreprocessor()
|
||||||
|
|
||||||
|
def run(self, input_ims):
|
||||||
|
"""Preprocess input images for ClassifierModel
|
||||||
|
:param: input_ims: (list of numpy.ndarray)The input image
|
||||||
|
:return: list of FDTensor
|
||||||
|
"""
|
||||||
|
return self._preprocessor.run(input_ims)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifierPostprocessor:
|
||||||
|
def __init__(self):
|
||||||
|
"""Create a postprocessor for ClassifierModel
|
||||||
|
"""
|
||||||
|
self._postprocessor = C.vision.ocr.ClassifierPostprocessor()
|
||||||
|
|
||||||
|
def run(self, runtime_results):
|
||||||
|
"""Postprocess the runtime results for ClassifierModel
|
||||||
|
:param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime
|
||||||
|
:return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
|
||||||
|
"""
|
||||||
|
return self._postprocessor.run(runtime_results)
|
||||||
|
|
||||||
|
|
||||||
|
def sort_boxes(boxes):
|
||||||
|
return C.vision.ocr.sort_boxes(boxes)
|
||||||
|
Reference in New Issue
Block a user