diff --git a/fastdeploy/utils/utils.h b/fastdeploy/utils/utils.h index 2c7269763..994ea9baa 100644 --- a/fastdeploy/utils/utils.h +++ b/fastdeploy/utils/utils.h @@ -22,6 +22,7 @@ #include #include #include +#include #if defined(_WIN32) #ifdef FASTDEPLOY_LIB diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h index 82f06e003..581d3b91a 100755 --- a/fastdeploy/vision.h +++ b/fastdeploy/vision.h @@ -48,6 +48,7 @@ #include "fastdeploy/vision/matting/ppmatting/ppmatting.h" #include "fastdeploy/vision/ocr/ppocr/classifier.h" #include "fastdeploy/vision/ocr/ppocr/dbdetector.h" +#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h" #include "fastdeploy/vision/ocr/ppocr/ppocr_v2.h" #include "fastdeploy/vision/ocr/ppocr/ppocr_v3.h" #include "fastdeploy/vision/ocr/ppocr/recognizer.h" diff --git a/fastdeploy/vision/ocr/ppocr/classifier.cc b/fastdeploy/vision/ocr/ppocr/classifier.cc old mode 100644 new mode 100755 index 1fbd4cc36..130329735 --- a/fastdeploy/vision/ocr/ppocr/classifier.cc +++ b/fastdeploy/vision/ocr/ppocr/classifier.cc @@ -41,16 +41,7 @@ Classifier::Classifier(const std::string& model_file, initialized = Initialize(); } -// Init bool Classifier::Initialize() { - // pre&post process parameters - cls_thresh = 0.9; - cls_image_shape = {3, 48, 192}; - cls_batch_num = 1; - mean = {0.5f, 0.5f, 0.5f}; - scale = {0.5f, 0.5f, 0.5f}; - is_scale = true; - if (!InitRuntime()) { FDERROR << "Failed to initialize fastdeploy backend." << std::endl; return false; @@ -59,85 +50,23 @@ bool Classifier::Initialize() { return true; } -void OcrClassifierResizeImage(Mat* mat, - const std::vector& rec_image_shape) { - int imgC = rec_image_shape[0]; - int imgH = rec_image_shape[1]; - int imgW = rec_image_shape[2]; - - float ratio = float(mat->Width()) / float(mat->Height()); - - int resize_w; - if (ceilf(imgH * ratio) > imgW) - resize_w = imgW; - else - resize_w = int(ceilf(imgH * ratio)); - - Resize::Run(mat, resize_w, imgH); - - std::vector value = {0, 0, 0}; - if (resize_w < imgW) { - Pad::Run(mat, 0, 0, 0, imgW - resize_w, value); +bool Classifier::BatchPredict(const std::vector& images, + std::vector* cls_labels, std::vector* cls_scores) { + std::vector fd_images = WrapMat(images); + if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) { + FDERROR << "Failed to preprocess the input image." << std::endl; + return false; } -} - -bool Classifier::Preprocess(Mat* mat, FDTensor* output) { - // 1. cls resizes - // 2. normalize - // 3. batch_permute - OcrClassifierResizeImage(mat, cls_image_shape); - - Normalize::Run(mat, mean, scale, true); - - HWC2CHW::Run(mat); - Cast::Run(mat, "float"); - - mat->ShareWithTensor(output); - output->shape.insert(output->shape.begin(), 1); - - return true; -} - -bool Classifier::Postprocess(FDTensor& infer_result, - std::tuple* cls_result) { - std::vector output_shape = infer_result.shape; - FDASSERT(output_shape[0] == 1, "Only support batch =1 now."); - - float* out_data = static_cast(infer_result.Data()); - - int label = std::distance( - &out_data[0], std::max_element(&out_data[0], &out_data[output_shape[1]])); - - float score = - float(*std::max_element(&out_data[0], &out_data[output_shape[1]])); - - std::get<0>(*cls_result) = label; - std::get<1>(*cls_result) = score; - - return true; -} - -bool Classifier::Predict(cv::Mat* img, std::tuple* cls_result) { - Mat mat(*img); - std::vector input_tensors(1); - - if (!Preprocess(&mat, &input_tensors[0])) { - FDERROR << "Failed to preprocess input image." << std::endl; + reused_input_tensors_[0].name = InputInfoOfRuntime(0).name; + if (!Infer(reused_input_tensors_, &reused_output_tensors_)) { + FDERROR << "Failed to inference by runtime." << std::endl; return false; } - input_tensors[0].name = InputInfoOfRuntime(0).name; - std::vector output_tensors; - if (!Infer(input_tensors, &output_tensors)) { - FDERROR << "Failed to inference." << std::endl; + if (!postprocessor_.Run(reused_output_tensors_, cls_labels, cls_scores)) { + FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl; return false; } - - if (!Postprocess(output_tensors[0], cls_result)) { - FDERROR << "Failed to post process." << std::endl; - return false; - } - return true; } diff --git a/fastdeploy/vision/ocr/ppocr/classifier.h b/fastdeploy/vision/ocr/ppocr/classifier.h old mode 100644 new mode 100755 index d87fec6fa..d3430e4e0 --- a/fastdeploy/vision/ocr/ppocr/classifier.h +++ b/fastdeploy/vision/ocr/ppocr/classifier.h @@ -17,6 +17,8 @@ #include "fastdeploy/vision/common/processors/transform.h" #include "fastdeploy/vision/common/result.h" #include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h" +#include "fastdeploy/vision/ocr/ppocr/cls_postprocessor.h" +#include "fastdeploy/vision/ocr/ppocr/cls_preprocessor.h" namespace fastdeploy { namespace vision { @@ -41,29 +43,22 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel { const ModelFormat& model_format = ModelFormat::PADDLE); /// Get model's name std::string ModelName() const { return "ppocr/ocr_cls"; } - /** \brief Predict the input image and get OCR classification model result. + + /** \brief BatchPredict the input image and get OCR classification model cls_result. * - * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. - * \param[in] result The output of OCR classification model result will be writen to this structure. + * \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. + * \param[in] cls_results The output of OCR classification model cls_result will be writen to this structure. * \return true if the prediction is successed, otherwise false. */ - virtual bool Predict(cv::Mat* img, std::tuple* result); + virtual bool BatchPredict(const std::vector& images, + std::vector* cls_labels, + std::vector* cls_scores); - // Pre & Post parameters - float cls_thresh; - std::vector cls_image_shape; - int cls_batch_num; - - std::vector mean; - std::vector scale; - bool is_scale; + ClassifierPreprocessor preprocessor_; + ClassifierPostprocessor postprocessor_; private: bool Initialize(); - /// Preprocess the input data, and set the preprocessed results to `outputs` - bool Preprocess(Mat* img, FDTensor* output); - /// Postprocess the inferenced results, and set the final result to `result` - bool Postprocess(FDTensor& infer_result, std::tuple* result); }; } // namespace ocr diff --git a/fastdeploy/vision/ocr/ppocr/cls_postprocessor.cc b/fastdeploy/vision/ocr/ppocr/cls_postprocessor.cc new file mode 100644 index 000000000..5eb6b5d69 --- /dev/null +++ b/fastdeploy/vision/ocr/ppocr/cls_postprocessor.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/ocr/ppocr/cls_postprocessor.h" +#include "fastdeploy/utils/perf.h" +#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h" + +namespace fastdeploy { +namespace vision { +namespace ocr { + +ClassifierPostprocessor::ClassifierPostprocessor() { + initialized_ = true; +} + +bool SingleBatchPostprocessor(const float* out_data, const size_t& length, int* cls_label, float* cls_score) { + + *cls_label = std::distance( + &out_data[0], std::max_element(&out_data[0], &out_data[length])); + + *cls_score = + float(*std::max_element(&out_data[0], &out_data[length])); + return true; +} + +bool ClassifierPostprocessor::Run(const std::vector& tensors, + std::vector* cls_labels, + std::vector* cls_scores) { + if (!initialized_) { + FDERROR << "Postprocessor is not initialized." << std::endl; + return false; + } + // Classifier have only 1 output tensor. + const FDTensor& tensor = tensors[0]; + + // For Classifier, the output tensor shape = [batch,2] + size_t batch = tensor.shape[0]; + size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies()); + + cls_labels->resize(batch); + cls_scores->resize(batch); + const float* tensor_data = reinterpret_cast(tensor.Data()); + + for (int i_batch = 0; i_batch < batch; ++i_batch) { + if(!SingleBatchPostprocessor(tensor_data, length, &cls_labels->at(i_batch),&cls_scores->at(i_batch))) return false; + tensor_data = tensor_data + length; + } + + return true; +} + +} // namespace classification +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/cls_postprocessor.h b/fastdeploy/vision/ocr/ppocr/cls_postprocessor.h new file mode 100644 index 000000000..15bf098c7 --- /dev/null +++ b/fastdeploy/vision/ocr/ppocr/cls_postprocessor.h @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" +#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h" + +namespace fastdeploy { +namespace vision { + +namespace ocr { +/*! @brief Postprocessor object for Classifier serials model. + */ +class FASTDEPLOY_DECL ClassifierPostprocessor { + public: + /** \brief Create a postprocessor instance for Classifier serials model + * + */ + ClassifierPostprocessor(); + + /** \brief Process the result of runtime and fill to ClassifyResult structure + * + * \param[in] tensors The inference result from runtime + * \param[in] cls_labels The output result of classification + * \param[in] cls_scores The output result of classification + * \return true if the postprocess successed, otherwise false + */ + bool Run(const std::vector& tensors, + std::vector* cls_labels, std::vector* cls_scores); + + float cls_thresh_ = 0.9; + + private: + bool initialized_ = false; +}; + +} // namespace ocr +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/cls_preprocessor.cc b/fastdeploy/vision/ocr/ppocr/cls_preprocessor.cc new file mode 100644 index 000000000..1f0993690 --- /dev/null +++ b/fastdeploy/vision/ocr/ppocr/cls_preprocessor.cc @@ -0,0 +1,88 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/ocr/ppocr/cls_preprocessor.h" +#include "fastdeploy/utils/perf.h" +#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h" +#include "fastdeploy/function/concat.h" + +namespace fastdeploy { +namespace vision { +namespace ocr { + +ClassifierPreprocessor::ClassifierPreprocessor() { + initialized_ = true; +} + +void OcrClassifierResizeImage(FDMat* mat, + const std::vector& cls_image_shape) { + int imgC = cls_image_shape[0]; + int imgH = cls_image_shape[1]; + int imgW = cls_image_shape[2]; + + float ratio = float(mat->Width()) / float(mat->Height()); + + int resize_w; + if (ceilf(imgH * ratio) > imgW) + resize_w = imgW; + else + resize_w = int(ceilf(imgH * ratio)); + + Resize::Run(mat, resize_w, imgH); + + std::vector value = {0, 0, 0}; + if (resize_w < imgW) { + Pad::Run(mat, 0, 0, 0, imgW - resize_w, value); + } +} + +bool ClassifierPreprocessor::Run(std::vector* images, std::vector* outputs) { + if (!initialized_) { + FDERROR << "The preprocessor is not initialized." << std::endl; + return false; + } + if (images->size() == 0) { + FDERROR << "The size of input images should be greater than 0." << std::endl; + return false; + } + + for (size_t i = 0; i < images->size(); ++i) { + FDMat* mat = &(images->at(i)); + OcrClassifierResizeImage(mat, cls_image_shape_); + NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_); + /* + Normalize::Run(mat, mean_, scale_, is_scale_); + HWC2CHW::Run(mat); + Cast::Run(mat, "float"); + */ + } + // Only have 1 output Tensor. + outputs->resize(1); + // Concat all the preprocessed data to a batch tensor + std::vector tensors(images->size()); + for (size_t i = 0; i < images->size(); ++i) { + (*images)[i].ShareWithTensor(&(tensors[i])); + tensors[i].ExpandDim(0); + } + if (tensors.size() == 1) { + (*outputs)[0] = std::move(tensors[0]); + } else { + function::Concat(tensors, &((*outputs)[0]), 0); + } + return true; +} + +} // namespace ocr +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h b/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h new file mode 100644 index 000000000..a701e7e3a --- /dev/null +++ b/fastdeploy/vision/ocr/ppocr/cls_preprocessor.h @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" + +namespace fastdeploy { +namespace vision { + +namespace ocr { +/*! @brief Preprocessor object for Classifier serials model. + */ +class FASTDEPLOY_DECL ClassifierPreprocessor { + public: + /** \brief Create a preprocessor instance for Classifier serials model + * + */ + ClassifierPreprocessor(); + + /** \brief Process the input image and prepare input tensors for runtime + * + * \param[in] images The input image data list, all the elements are returned by cv::imread() + * \param[in] outputs The output tensors which will feed in runtime + * \return true if the preprocess successed, otherwise false + */ + bool Run(std::vector* images, std::vector* outputs); + + std::vector mean_ = {0.5f, 0.5f, 0.5f}; + std::vector scale_ = {0.5f, 0.5f, 0.5f}; + bool is_scale_ = true; + std::vector cls_image_shape_ = {3, 48, 192}; + + private: + bool initialized_ = false; +}; + +} // namespace ocr +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/dbdetector.cc b/fastdeploy/vision/ocr/ppocr/dbdetector.cc old mode 100644 new mode 100755 index 8ee44fddf..68a994afc --- a/fastdeploy/vision/ocr/ppocr/dbdetector.cc +++ b/fastdeploy/vision/ocr/ppocr/dbdetector.cc @@ -43,158 +43,50 @@ DBDetector::DBDetector(const std::string& model_file, // Init bool DBDetector::Initialize() { - // pre&post process parameters - max_side_len = 960; - - det_db_thresh = 0.3; - det_db_box_thresh = 0.6; - det_db_unclip_ratio = 1.5; - det_db_score_mode = "slow"; - use_dilation = false; - - mean = {0.485f, 0.456f, 0.406f}; - scale = {0.229f, 0.224f, 0.225f}; - is_scale = true; - if (!InitRuntime()) { FDERROR << "Failed to initialize fastdeploy backend." << std::endl; return false; } - - return true; -} - -void OcrDetectorResizeImage(Mat* img, int max_size_len, float* ratio_h, - float* ratio_w) { - int w = img->Width(); - int h = img->Height(); - - float ratio = 1.f; - int max_wh = w >= h ? w : h; - if (max_wh > max_size_len) { - if (h > w) { - ratio = float(max_size_len) / float(h); - } else { - ratio = float(max_size_len) / float(w); - } - } - - int resize_h = int(float(h) * ratio); - int resize_w = int(float(w) * ratio); - - resize_h = std::max(int(std::round(float(resize_h) / 32) * 32), 32); - resize_w = std::max(int(std::round(float(resize_w) / 32) * 32), 32); - - Resize::Run(img, resize_w, resize_h); - - *ratio_h = float(resize_h) / float(h); - *ratio_w = float(resize_w) / float(w); -} - -bool DBDetector::Preprocess( - Mat* mat, FDTensor* output, - std::map>* im_info) { - // Resize - OcrDetectorResizeImage(mat, max_side_len, &ratio_h, &ratio_w); - // Normalize - Normalize::Run(mat, mean, scale, true); - - (*im_info)["output_shape"] = {static_cast(mat->Height()), - static_cast(mat->Width())}; - //-CHW - HWC2CHW::Run(mat); - Cast::Run(mat, "float"); - - mat->ShareWithTensor(output); - output->shape.insert(output->shape.begin(), 1); - return true; -} - -bool DBDetector::Postprocess( - FDTensor& infer_result, std::vector>* boxes_result, - const std::map>& im_info) { - std::vector output_shape = infer_result.shape; - FDASSERT(output_shape[0] == 1, "Only support batch =1 now."); - int n2 = output_shape[2]; - int n3 = output_shape[3]; - int n = n2 * n3; - - float* out_data = static_cast(infer_result.Data()); - // prepare bitmap - std::vector pred(n, 0.0); - std::vector cbuf(n, ' '); - - for (int i = 0; i < n; i++) { - pred[i] = float(out_data[i]); - cbuf[i] = (unsigned char)((out_data[i]) * 255); - } - cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char*)cbuf.data()); - cv::Mat pred_map(n2, n3, CV_32F, (float*)pred.data()); - - const double threshold = det_db_thresh * 255; - const double maxvalue = 255; - cv::Mat bit_map; - cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY); - if (use_dilation) { - cv::Mat dila_ele = - cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2)); - cv::dilate(bit_map, bit_map, dila_ele); - } - - std::vector>> boxes; - - boxes = - post_processor_.BoxesFromBitmap(pred_map, bit_map, det_db_box_thresh, - det_db_unclip_ratio, det_db_score_mode); - - boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, im_info); - - // boxes to boxes_result - for (int i = 0; i < boxes.size(); i++) { - std::array new_box; - int k = 0; - for (auto& vec : boxes[i]) { - for (auto& e : vec) { - new_box[k++] = e; - } - } - boxes_result->push_back(new_box); - } - return true; } bool DBDetector::Predict(cv::Mat* img, std::vector>* boxes_result) { - Mat mat(*img); + if (!Predict(*img, boxes_result)) { + return false; + } + return true; +} - std::vector input_tensors(1); +bool DBDetector::Predict(const cv::Mat& img, + std::vector>* boxes_result) { + std::vector>> det_results; + if (!BatchPredict({img}, &det_results)) { + return false; + } + *boxes_result = std::move(det_results[0]); + return true; +} - std::map> im_info; - - // Record the shape of image and the shape of preprocessed image - im_info["input_shape"] = {static_cast(mat.Height()), - static_cast(mat.Width())}; - im_info["output_shape"] = {static_cast(mat.Height()), - static_cast(mat.Width())}; - - if (!Preprocess(&mat, &input_tensors[0], &im_info)) { +bool DBDetector::BatchPredict(const std::vector& images, + std::vector>>* det_results) { + std::vector fd_images = WrapMat(images); + std::vector> batch_det_img_info; + if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &batch_det_img_info)) { FDERROR << "Failed to preprocess input image." << std::endl; return false; } - input_tensors[0].name = InputInfoOfRuntime(0).name; - std::vector output_tensors; - if (!Infer(input_tensors, &output_tensors)) { - FDERROR << "Failed to inference." << std::endl; + reused_input_tensors_[0].name = InputInfoOfRuntime(0).name; + if (!Infer(reused_input_tensors_, &reused_output_tensors_)) { + FDERROR << "Failed to inference by runtime." << std::endl; return false; } - if (!Postprocess(output_tensors[0], boxes_result, im_info)) { - FDERROR << "Failed to post process." << std::endl; + if (!postprocessor_.Run(reused_output_tensors_, det_results, batch_det_img_info)) { + FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl; return false; } - return true; } diff --git a/fastdeploy/vision/ocr/ppocr/dbdetector.h b/fastdeploy/vision/ocr/ppocr/dbdetector.h old mode 100644 new mode 100755 index e0baf319c..d3b99d598 --- a/fastdeploy/vision/ocr/ppocr/dbdetector.h +++ b/fastdeploy/vision/ocr/ppocr/dbdetector.h @@ -17,6 +17,8 @@ #include "fastdeploy/vision/common/processors/transform.h" #include "fastdeploy/vision/common/result.h" #include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h" +#include "fastdeploy/vision/ocr/ppocr/det_postprocessor.h" +#include "fastdeploy/vision/ocr/ppocr/det_preprocessor.h" namespace fastdeploy { namespace vision { @@ -44,40 +46,34 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel { std::string ModelName() const { return "ppocr/ocr_det"; } /** \brief Predict the input image and get OCR detection model result. * - * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. + * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. * \param[in] boxes_result The output of OCR detection model result will be writen to this structure. * \return true if the prediction is successed, otherwise false. */ - virtual bool Predict(cv::Mat* im, + virtual bool Predict(cv::Mat* img, std::vector>* boxes_result); + /** \brief Predict the input image and get OCR detection model result. + * + * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. + * \param[in] boxes_result The output of OCR detection model result will be writen to this structure. + * \return true if the prediction is successed, otherwise false. + */ + virtual bool Predict(const cv::Mat& img, + std::vector>* boxes_result); + /** \brief BatchPredict the input image and get OCR detection model result. + * + * \param[in] images The list input of image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. + * \param[in] det_results The output of OCR detection model result will be writen to this structure. + * \return true if the prediction is successed, otherwise false. + */ + virtual bool BatchPredict(const std::vector& images, + std::vector>>* det_results); - // Pre & Post process parameters - int max_side_len; - - float ratio_h{}; - float ratio_w{}; - - double det_db_thresh; - double det_db_box_thresh; - double det_db_unclip_ratio; - std::string det_db_score_mode; - bool use_dilation; - - std::vector mean; - std::vector scale; - bool is_scale; + DBDetectorPreprocessor preprocessor_; + DBDetectorPostprocessor postprocessor_; private: bool Initialize(); - /// Preprocess the input data, and set the preprocessed results to `outputs` - bool Preprocess(Mat* mat, FDTensor* outputs, - std::map>* im_info); - /*! @brief Postprocess the inferenced results, and set the final result to `boxes_result` - */ - bool Postprocess(FDTensor& infer_result, - std::vector>* boxes_result, - const std::map>& im_info); - PostProcessor post_processor_; }; } // namespace ocr diff --git a/fastdeploy/vision/ocr/ppocr/det_postprocessor.cc b/fastdeploy/vision/ocr/ppocr/det_postprocessor.cc new file mode 100644 index 000000000..34a88c011 --- /dev/null +++ b/fastdeploy/vision/ocr/ppocr/det_postprocessor.cc @@ -0,0 +1,110 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/ocr/ppocr/det_postprocessor.h" +#include "fastdeploy/utils/perf.h" +#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h" + +namespace fastdeploy { +namespace vision { +namespace ocr { + +DBDetectorPostprocessor::DBDetectorPostprocessor() { + initialized_ = true; +} + +bool DBDetectorPostprocessor::SingleBatchPostprocessor( + const float* out_data, + int n2, + int n3, + const std::array& det_img_info, + std::vector>* boxes_result + ) { + int n = n2 * n3; + + // prepare bitmap + std::vector pred(n, 0.0); + std::vector cbuf(n, ' '); + + for (int i = 0; i < n; i++) { + pred[i] = float(out_data[i]); + cbuf[i] = (unsigned char)((out_data[i]) * 255); + } + cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char*)cbuf.data()); + cv::Mat pred_map(n2, n3, CV_32F, (float*)pred.data()); + + const double threshold = det_db_thresh_ * 255; + const double maxvalue = 255; + cv::Mat bit_map; + cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY); + if (use_dilation_) { + cv::Mat dila_ele = + cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2)); + cv::dilate(bit_map, bit_map, dila_ele); + } + + std::vector>> boxes; + + boxes = + post_processor_.BoxesFromBitmap(pred_map, bit_map, det_db_box_thresh_, + det_db_unclip_ratio_, det_db_score_mode_); + + boxes = post_processor_.FilterTagDetRes(boxes, det_img_info); + + // boxes to boxes_result + for (int i = 0; i < boxes.size(); i++) { + std::array new_box; + int k = 0; + for (auto& vec : boxes[i]) { + for (auto& e : vec) { + new_box[k++] = e; + } + } + boxes_result->push_back(new_box); + } + + return true; +} + +bool DBDetectorPostprocessor::Run(const std::vector& tensors, + std::vector>>* results, + const std::vector>& batch_det_img_info) { + if (!initialized_) { + FDERROR << "Postprocessor is not initialized." << std::endl; + return false; + } + // DBDetector have only 1 output tensor. + const FDTensor& tensor = tensors[0]; + + // For DBDetector, the output tensor shape = [batch, 1, ?, ?] + size_t batch = tensor.shape[0]; + size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies()); + const float* tensor_data = reinterpret_cast(tensor.Data()); + + results->resize(batch); + for (int i_batch = 0; i_batch < batch; ++i_batch) { + if(!SingleBatchPostprocessor(tensor_data, + tensor.shape[2], + tensor.shape[3], + batch_det_img_info[i_batch], + &results->at(i_batch) + ))return false; + tensor_data = tensor_data + length; + } + return true; +} + +} // namespace ocr +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/det_postprocessor.h b/fastdeploy/vision/ocr/ppocr/det_postprocessor.h new file mode 100644 index 000000000..f98b89b02 --- /dev/null +++ b/fastdeploy/vision/ocr/ppocr/det_postprocessor.h @@ -0,0 +1,62 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" +#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h" + +namespace fastdeploy { +namespace vision { + +namespace ocr { +/*! @brief Postprocessor object for DBDetector serials model. + */ +class FASTDEPLOY_DECL DBDetectorPostprocessor { + public: + /** \brief Create a postprocessor instance for DBDetector serials model + * + */ + DBDetectorPostprocessor(); + + /** \brief Process the result of runtime and fill to results structure + * + * \param[in] tensors The inference result from runtime + * \param[in] results The output result of detector + * \param[in] batch_det_img_info The detector_preprocess result + * \return true if the postprocess successed, otherwise false + */ + bool Run(const std::vector& tensors, + std::vector>>* results, + const std::vector>& batch_det_img_info); + + double det_db_thresh_ = 0.3; + double det_db_box_thresh_ = 0.6; + double det_db_unclip_ratio_ = 1.5; + std::string det_db_score_mode_ = "slow"; + bool use_dilation_ = false; + + private: + bool initialized_ = false; + PostProcessor post_processor_; + bool SingleBatchPostprocessor(const float* out_data, + int n2, + int n3, + const std::array& det_img_info, + std::vector>* boxes_result); +}; + +} // namespace ocr +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc b/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc new file mode 100644 index 000000000..89a8d6d39 --- /dev/null +++ b/fastdeploy/vision/ocr/ppocr/det_preprocessor.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/ocr/ppocr/det_preprocessor.h" +#include "fastdeploy/utils/perf.h" +#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h" +#include "fastdeploy/function/concat.h" + +namespace fastdeploy { +namespace vision { +namespace ocr { + +DBDetectorPreprocessor::DBDetectorPreprocessor() { + initialized_ = true; +} + +std::array OcrDetectorGetInfo(FDMat* img, int max_size_len) { + int w = img->Width(); + int h = img->Height(); + + float ratio = 1.f; + int max_wh = w >= h ? w : h; + if (max_wh > max_size_len) { + if (h > w) { + ratio = float(max_size_len) / float(h); + } else { + ratio = float(max_size_len) / float(w); + } + } + int resize_h = int(float(h) * ratio); + int resize_w = int(float(w) * ratio); + resize_h = std::max(int(std::round(float(resize_h) / 32) * 32), 32); + resize_w = std::max(int(std::round(float(resize_w) / 32) * 32), 32); + + return {w,h,resize_w,resize_h}; + /* + *ratio_h = float(resize_h) / float(h); + *ratio_w = float(resize_w) / float(w); + */ +} +bool OcrDetectorResizeImage(FDMat* img, + int resize_w, + int resize_h, + int max_resize_w, + int max_resize_h) { + Resize::Run(img, resize_w, resize_h); + std::vector value = {0, 0, 0}; + Pad::Run(img, 0, max_resize_h-resize_h, 0, max_resize_w - resize_w, value); + return true; +} + +bool DBDetectorPreprocessor::Run(std::vector* images, + std::vector* outputs, + std::vector>* batch_det_img_info_ptr) { + if (!initialized_) { + FDERROR << "The preprocessor is not initialized." << std::endl; + return false; + } + if (images->size() == 0) { + FDERROR << "The size of input images should be greater than 0." << std::endl; + return false; + } + int max_resize_w = 0; + int max_resize_h = 0; + std::vector>& batch_det_img_info = *batch_det_img_info_ptr; + batch_det_img_info.clear(); + batch_det_img_info.resize(images->size()); + for (size_t i = 0; i < images->size(); ++i) { + FDMat* mat = &(images->at(i)); + batch_det_img_info[i] = OcrDetectorGetInfo(mat,max_side_len_); + max_resize_w = std::max(max_resize_w,batch_det_img_info[i][2]); + max_resize_h = std::max(max_resize_h,batch_det_img_info[i][3]); + } + for (size_t i = 0; i < images->size(); ++i) { + FDMat* mat = &(images->at(i)); + OcrDetectorResizeImage(mat, batch_det_img_info[i][2],batch_det_img_info[i][3],max_resize_w,max_resize_h); + NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_); + /* + Normalize::Run(mat, mean_, scale_, is_scale_); + HWC2CHW::Run(mat); + Cast::Run(mat, "float"); + */ + } + // Only have 1 output Tensor. + outputs->resize(1); + // Concat all the preprocessed data to a batch tensor + std::vector tensors(images->size()); + for (size_t i = 0; i < images->size(); ++i) { + (*images)[i].ShareWithTensor(&(tensors[i])); + tensors[i].ExpandDim(0); + } + if (tensors.size() == 1) { + (*outputs)[0] = std::move(tensors[0]); + } else { + function::Concat(tensors, &((*outputs)[0]), 0); + } + return true; +} + +} // namespace ocr +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/det_preprocessor.h b/fastdeploy/vision/ocr/ppocr/det_preprocessor.h new file mode 100644 index 000000000..39c48691d --- /dev/null +++ b/fastdeploy/vision/ocr/ppocr/det_preprocessor.h @@ -0,0 +1,54 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" + +namespace fastdeploy { +namespace vision { + +namespace ocr { +/*! @brief Preprocessor object for DBDetector serials model. + */ +class FASTDEPLOY_DECL DBDetectorPreprocessor { + public: + /** \brief Create a preprocessor instance for DBDetector serials model + * + */ + DBDetectorPreprocessor(); + + /** \brief Process the input image and prepare input tensors for runtime + * + * \param[in] images The input image data list, all the elements are returned by cv::imread() + * \param[in] outputs The output tensors which will feed in runtime + * \param[in] batch_det_img_info_ptr The output of preprocess + * \return true if the preprocess successed, otherwise false + */ + bool Run(std::vector* images, + std::vector* outputs, + std::vector>* batch_det_img_info_ptr); + + int max_side_len_ = 960; + std::vector mean_ = {0.485f, 0.456f, 0.406f}; + std::vector scale_ = {0.229f, 0.224f, 0.225f}; + bool is_scale_ = true; + + private: + bool initialized_ = false; +}; + +} // namespace ocr +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc old mode 100644 new mode 100755 index 2e4e3e85b..96c3e177e --- a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc +++ b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc @@ -16,45 +16,171 @@ namespace fastdeploy { void BindPPOCRModel(pybind11::module& m) { + m.def("sort_boxes", [](std::vector>& boxes) { + vision::ocr::SortBoxes(&boxes); + return boxes; + }); // DBDetector pybind11::class_(m, "DBDetector") .def(pybind11::init()) .def(pybind11::init<>()) + .def_readwrite("preprocessor", &vision::ocr::DBDetector::preprocessor_) + .def_readwrite("postprocessor", &vision::ocr::DBDetector::postprocessor_); - .def_readwrite("max_side_len", &vision::ocr::DBDetector::max_side_len) - .def_readwrite("det_db_thresh", &vision::ocr::DBDetector::det_db_thresh) - .def_readwrite("det_db_box_thresh", - &vision::ocr::DBDetector::det_db_box_thresh) - .def_readwrite("det_db_unclip_ratio", - &vision::ocr::DBDetector::det_db_unclip_ratio) - .def_readwrite("det_db_score_mode", - &vision::ocr::DBDetector::det_db_score_mode) - .def_readwrite("use_dilation", &vision::ocr::DBDetector::use_dilation) - .def_readwrite("mean", &vision::ocr::DBDetector::mean) - .def_readwrite("scale", &vision::ocr::DBDetector::scale) - .def_readwrite("is_scale", &vision::ocr::DBDetector::is_scale); + pybind11::class_(m, "DBDetectorPreprocessor") + .def(pybind11::init<>()) + .def_readwrite("max_side_len", &vision::ocr::DBDetectorPreprocessor::max_side_len_) + .def_readwrite("mean", &vision::ocr::DBDetectorPreprocessor::mean_) + .def_readwrite("scale", &vision::ocr::DBDetectorPreprocessor::scale_) + .def_readwrite("is_scale", &vision::ocr::DBDetectorPreprocessor::is_scale_) + .def("run", [](vision::ocr::DBDetectorPreprocessor& self, std::vector& im_list) { + std::vector images; + for (size_t i = 0; i < im_list.size(); ++i) { + images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); + } + std::vector outputs; + std::vector> batch_det_img_info; + self.Run(&images, &outputs, &batch_det_img_info); + for(size_t i = 0; i< outputs.size(); ++i){ + outputs[i].StopSharing(); + } + return make_pair(outputs, batch_det_img_info); + }); + + pybind11::class_(m, "DBDetectorPostprocessor") + .def(pybind11::init<>()) + .def_readwrite("det_db_thresh", &vision::ocr::DBDetectorPostprocessor::det_db_thresh_) + .def_readwrite("det_db_box_thresh", &vision::ocr::DBDetectorPostprocessor::det_db_box_thresh_) + .def_readwrite("det_db_unclip_ratio", &vision::ocr::DBDetectorPostprocessor::det_db_unclip_ratio_) + .def_readwrite("det_db_score_mode", &vision::ocr::DBDetectorPostprocessor::det_db_score_mode_) + .def_readwrite("use_dilation", &vision::ocr::DBDetectorPostprocessor::use_dilation_) + .def("run", [](vision::ocr::DBDetectorPostprocessor& self, + std::vector& inputs, + const std::vector>& batch_det_img_info) { + std::vector>> results; + + if (!self.Run(inputs, &results, batch_det_img_info)) { + pybind11::eval("raise Exception('Failed to preprocess the input data in DBDetectorPostprocessor.')"); + } + return results; + }) + .def("run", [](vision::ocr::DBDetectorPostprocessor& self, + std::vector& input_array, + const std::vector>& batch_det_img_info) { + std::vector>> results; + std::vector inputs; + PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); + if (!self.Run(inputs, &results, batch_det_img_info)) { + pybind11::eval("raise Exception('Failed to preprocess the input data in DBDetectorPostprocessor.')"); + } + return results; + }); // Classifier pybind11::class_(m, "Classifier") .def(pybind11::init()) .def(pybind11::init<>()) + .def_readwrite("preprocessor", &vision::ocr::Classifier::preprocessor_) + .def_readwrite("postprocessor", &vision::ocr::Classifier::postprocessor_); + + pybind11::class_(m, "ClassifierPreprocessor") + .def(pybind11::init<>()) + .def_readwrite("cls_image_shape", &vision::ocr::ClassifierPreprocessor::cls_image_shape_) + .def_readwrite("mean", &vision::ocr::ClassifierPreprocessor::mean_) + .def_readwrite("scale", &vision::ocr::ClassifierPreprocessor::scale_) + .def_readwrite("is_scale", &vision::ocr::ClassifierPreprocessor::is_scale_) + .def("run", [](vision::ocr::ClassifierPreprocessor& self, std::vector& im_list) { + std::vector images; + for (size_t i = 0; i < im_list.size(); ++i) { + images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); + } + std::vector outputs; + if (!self.Run(&images, &outputs)) { + pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPreprocessor.')"); + } + for(size_t i = 0; i< outputs.size(); ++i){ + outputs[i].StopSharing(); + } + return outputs; + }); + + pybind11::class_(m, "ClassifierPostprocessor") + .def(pybind11::init<>()) + .def_readwrite("cls_thresh", &vision::ocr::ClassifierPostprocessor::cls_thresh_) + .def("run", [](vision::ocr::ClassifierPostprocessor& self, + std::vector& inputs) { + std::vector cls_labels; + std::vector cls_scores; + if (!self.Run(inputs, &cls_labels, &cls_scores)) { + pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPostprocessor.')"); + } + return make_pair(cls_labels,cls_scores); + }) + .def("run", [](vision::ocr::ClassifierPostprocessor& self, + std::vector& input_array) { + std::vector inputs; + PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); + std::vector cls_labels; + std::vector cls_scores; + if (!self.Run(inputs, &cls_labels, &cls_scores)) { + pybind11::eval("raise Exception('Failed to preprocess the input data in ClassifierPostprocessor.')"); + } + return make_pair(cls_labels,cls_scores); + }); - .def_readwrite("cls_thresh", &vision::ocr::Classifier::cls_thresh) - .def_readwrite("cls_image_shape", - &vision::ocr::Classifier::cls_image_shape) - .def_readwrite("cls_batch_num", &vision::ocr::Classifier::cls_batch_num); // Recognizer pybind11::class_(m, "Recognizer") - .def(pybind11::init()) .def(pybind11::init<>()) + .def_readwrite("preprocessor", &vision::ocr::Recognizer::preprocessor_) + .def_readwrite("postprocessor", &vision::ocr::Recognizer::postprocessor_); - .def_readwrite("rec_img_h", &vision::ocr::Recognizer::rec_img_h) - .def_readwrite("rec_img_w", &vision::ocr::Recognizer::rec_img_w) - .def_readwrite("rec_batch_num", &vision::ocr::Recognizer::rec_batch_num); + pybind11::class_(m, "RecognizerPreprocessor") + .def(pybind11::init<>()) + .def_readwrite("rec_image_shape", &vision::ocr::RecognizerPreprocessor::rec_image_shape_) + .def_readwrite("mean", &vision::ocr::RecognizerPreprocessor::mean_) + .def_readwrite("scale", &vision::ocr::RecognizerPreprocessor::scale_) + .def_readwrite("is_scale", &vision::ocr::RecognizerPreprocessor::is_scale_) + .def("run", [](vision::ocr::RecognizerPreprocessor& self, std::vector& im_list) { + std::vector images; + for (size_t i = 0; i < im_list.size(); ++i) { + images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); + } + std::vector outputs; + if (!self.Run(&images, &outputs)) { + pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPreprocessor.')"); + } + for(size_t i = 0; i< outputs.size(); ++i){ + outputs[i].StopSharing(); + } + return outputs; + }); + + pybind11::class_(m, "RecognizerPostprocessor") + .def(pybind11::init()) + .def("run", [](vision::ocr::RecognizerPostprocessor& self, + std::vector& inputs) { + std::vector texts; + std::vector rec_scores; + if (!self.Run(inputs, &texts, &rec_scores)) { + pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPostprocessor.')"); + } + return make_pair(texts, rec_scores); + }) + .def("run", [](vision::ocr::RecognizerPostprocessor& self, + std::vector& input_array) { + std::vector inputs; + PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true); + std::vector texts; + std::vector rec_scores; + if (!self.Run(inputs, &texts, &rec_scores)) { + pybind11::eval("raise Exception('Failed to preprocess the input data in RecognizerPostprocessor.')"); + } + return make_pair(texts, rec_scores); + }); } } // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/ppocr_pybind.cc b/fastdeploy/vision/ocr/ppocr/ppocr_pybind.cc old mode 100644 new mode 100755 index a88ae2fc7..fcbbe0224 --- a/fastdeploy/vision/ocr/ppocr/ppocr_pybind.cc +++ b/fastdeploy/vision/ocr/ppocr/ppocr_pybind.cc @@ -31,6 +31,15 @@ void BindPPOCRv3(pybind11::module& m) { vision::OCRResult res; self.Predict(&mat, &res); return res; + }) + .def("batch_predict", [](pipeline::PPOCRv3& self, std::vector& data) { + std::vector images; + for (size_t i = 0; i < data.size(); ++i) { + images.push_back(PyArrayToCvMat(data[i])); + } + std::vector results; + self.BatchPredict(images, &results); + return results; }); } @@ -49,6 +58,15 @@ void BindPPOCRv2(pybind11::module& m) { vision::OCRResult res; self.Predict(&mat, &res); return res; + }) + .def("batch_predict", [](pipeline::PPOCRv2& self, std::vector& data) { + std::vector images; + for (size_t i = 0; i < data.size(); ++i) { + images.push_back(PyArrayToCvMat(data[i])); + } + std::vector results; + self.BatchPredict(images, &results); + return results; }); } diff --git a/fastdeploy/vision/ocr/ppocr/ppocr_v2.cc b/fastdeploy/vision/ocr/ppocr/ppocr_v2.cc old mode 100644 new mode 100755 index 7ad1f105c..e6e89299b --- a/fastdeploy/vision/ocr/ppocr/ppocr_v2.cc +++ b/fastdeploy/vision/ocr/ppocr/ppocr_v2.cc @@ -22,101 +22,95 @@ PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model, fastdeploy::vision::ocr::Classifier* cls_model, fastdeploy::vision::ocr::Recognizer* rec_model) : detector_(det_model), classifier_(cls_model), recognizer_(rec_model) { - recognizer_->rec_image_shape[1] = 32; + Initialized(); + recognizer_->preprocessor_.rec_image_shape_[1] = 32; } PPOCRv2::PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model, fastdeploy::vision::ocr::Recognizer* rec_model) : detector_(det_model), recognizer_(rec_model) { - recognizer_->rec_image_shape[1] = 32; + Initialized(); + recognizer_->preprocessor_.rec_image_shape_[1] = 32; } bool PPOCRv2::Initialized() const { - if (detector_ != nullptr && !detector_->Initialized()){ + if (detector_ != nullptr && !detector_->Initialized()) { return false; } - if (classifier_ != nullptr && !classifier_->Initialized()){ + if (classifier_ != nullptr && !classifier_->Initialized()) { return false; } - if (recognizer_ != nullptr && !recognizer_->Initialized()){ + if (recognizer_ != nullptr && !recognizer_->Initialized()) { return false; } return true; } -bool PPOCRv2::Detect(cv::Mat* img, - fastdeploy::vision::OCRResult* result) { - if (!detector_->Predict(img, &(result->boxes))) { - FDERROR << "There's error while detecting image in PPOCR." << std::endl; - return false; - } - vision::ocr::SortBoxes(result); - return true; -} - -bool PPOCRv2::Recognize(cv::Mat* img, - fastdeploy::vision::OCRResult* result) { - std::tuple rec_result; - if (!recognizer_->Predict(img, &rec_result)) { - FDERROR << "There's error while recognizing image in PPOCR." << std::endl; - return false; - } - - result->text.push_back(std::get<0>(rec_result)); - result->rec_scores.push_back(std::get<1>(rec_result)); - return true; -} - -bool PPOCRv2::Classify(cv::Mat* img, - fastdeploy::vision::OCRResult* result) { - std::tuple cls_result; - - if (!classifier_->Predict(img, &cls_result)) { - FDERROR << "There's error while classifying image in PPOCR." << std::endl; - return false; - } - - result->cls_labels.push_back(std::get<0>(cls_result)); - result->cls_scores.push_back(std::get<1>(cls_result)); - return true; -} - bool PPOCRv2::Predict(cv::Mat* img, fastdeploy::vision::OCRResult* result) { - result->Clear(); - if (nullptr != detector_ && !Detect(img, result)) { - FDERROR << "Failed to detect image." << std::endl; - return false; - } - - // Get croped images by detection result - std::vector image_list; - for (size_t i = 0; i < result->boxes.size(); ++i) { - auto crop_im = vision::ocr::GetRotateCropImage(*img, (result->boxes)[i]); - image_list.push_back(crop_im); - } - if (result->boxes.size() == 0) { - image_list.push_back(*img); - } - - for (size_t i = 0; i < image_list.size(); ++i) { - if (nullptr != classifier_ && !Classify(&(image_list[i]), result)) { - FDERROR << "Failed to classify croped image of index " << i << "." << std::endl; - return false; - } - if (nullptr != classifier_ && result->cls_labels[i] % 2 == 1 && result->cls_scores[i] > classifier_->cls_thresh) { - cv::rotate(image_list[i], image_list[i], 1); - } - if (nullptr != recognizer_ && !Recognize(&(image_list[i]), result)) { - FDERROR << "Failed to recgnize croped image of index " << i << "." << std::endl; - return false; - } - } + std::vector batch_result(1); + BatchPredict({*img},&batch_result); + *result = std::move(batch_result[0]); return true; }; +bool PPOCRv2::BatchPredict(const std::vector& images, + std::vector* batch_result) { + batch_result->clear(); + batch_result->resize(images.size()); + std::vector>> batch_boxes(images.size()); + + if (!detector_->BatchPredict(images, &batch_boxes)) { + FDERROR << "There's error while detecting image in PPOCR." << std::endl; + return false; + } + for(int i_batch = 0; i_batch < batch_boxes.size(); ++i_batch) { + vision::ocr::SortBoxes(&(batch_boxes[i_batch])); + (*batch_result)[i_batch].boxes = batch_boxes[i_batch]; + } + + + for(int i_batch = 0; i_batch < images.size(); ++i_batch) { + fastdeploy::vision::OCRResult& ocr_result = (*batch_result)[i_batch]; + // Get croped images by detection result + const std::vector>& boxes = ocr_result.boxes; + const cv::Mat& img = images[i_batch]; + std::vector image_list; + if (boxes.size() == 0) { + image_list.emplace_back(img); + }else{ + image_list.resize(boxes.size()); + for (size_t i_box = 0; i_box < boxes.size(); ++i_box) { + image_list[i_box] = vision::ocr::GetRotateCropImage(img, boxes[i_box]); + } + } + std::vector* cls_labels_ptr = &ocr_result.cls_labels; + std::vector* cls_scores_ptr = &ocr_result.cls_scores; + + std::vector* text_ptr = &ocr_result.text; + std::vector* rec_scores_ptr = &ocr_result.rec_scores; + + if (!classifier_->BatchPredict(image_list, cls_labels_ptr, cls_scores_ptr)) { + FDERROR << "There's error while recognizing image in PPOCR." << std::endl; + return false; + }else{ + for (size_t i_img = 0; i_img < image_list.size(); ++i_img) { + if(cls_labels_ptr->at(i_img) % 2 == 1 && cls_scores_ptr->at(i_img) > classifier_->postprocessor_.cls_thresh_) { + cv::rotate(image_list[i_img], image_list[i_img], 1); + } + } + } + + if (!recognizer_->BatchPredict(image_list, text_ptr, rec_scores_ptr)) { + FDERROR << "There's error while recognizing image in PPOCR." << std::endl; + return false; + } + } + return true; +} + } // namesapce pipeline } // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/ppocr_v2.h b/fastdeploy/vision/ocr/ppocr/ppocr_v2.h old mode 100644 new mode 100755 index bf5300020..d021d6c32 --- a/fastdeploy/vision/ocr/ppocr/ppocr_v2.h +++ b/fastdeploy/vision/ocr/ppocr/ppocr_v2.h @@ -59,6 +59,14 @@ class FASTDEPLOY_DECL PPOCRv2 : public FastDeployModel { * \return true if the prediction successed, otherwise false. */ virtual bool Predict(cv::Mat* img, fastdeploy::vision::OCRResult* result); + /** \brief BatchPredict the input image and get OCR result. + * + * \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. + * \param[in] batch_result The output list of OCR result will be writen to this structure. + * \return true if the prediction successed, otherwise false. + */ + virtual bool BatchPredict(const std::vector& images, + std::vector* batch_result); bool Initialized() const override; protected: @@ -66,11 +74,6 @@ class FASTDEPLOY_DECL PPOCRv2 : public FastDeployModel { fastdeploy::vision::ocr::Classifier* classifier_ = nullptr; fastdeploy::vision::ocr::Recognizer* recognizer_ = nullptr; /// Launch the detection process in OCR. - virtual bool Detect(cv::Mat* img, fastdeploy::vision::OCRResult* result); - /// Launch the recognition process in OCR. - virtual bool Recognize(cv::Mat* img, fastdeploy::vision::OCRResult* result); - /// Launch the classification process in OCR. - virtual bool Classify(cv::Mat* img, fastdeploy::vision::OCRResult* result); }; namespace application { diff --git a/fastdeploy/vision/ocr/ppocr/ppocr_v3.h b/fastdeploy/vision/ocr/ppocr/ppocr_v3.h old mode 100644 new mode 100755 index e248eca75..ed9177d92 --- a/fastdeploy/vision/ocr/ppocr/ppocr_v3.h +++ b/fastdeploy/vision/ocr/ppocr/ppocr_v3.h @@ -36,7 +36,7 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 { fastdeploy::vision::ocr::Recognizer* rec_model) : PPOCRv2(det_model, cls_model, rec_model) { // The only difference between v2 and v3 - recognizer_->rec_image_shape[1] = 48; + recognizer_->preprocessor_.rec_image_shape_[1] = 48; } /** \brief Classification model is optional, so this function is set up the detection model path and recognition model path respectively. * @@ -47,7 +47,7 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 { fastdeploy::vision::ocr::Recognizer* rec_model) : PPOCRv2(det_model, rec_model) { // The only difference between v2 and v3 - recognizer_->rec_image_shape[1] = 48; + recognizer_->preprocessor_.rec_image_shape_[1] = 48; } }; diff --git a/fastdeploy/vision/ocr/ppocr/rec_postprocessor.cc b/fastdeploy/vision/ocr/ppocr/rec_postprocessor.cc new file mode 100644 index 000000000..cdc302e28 --- /dev/null +++ b/fastdeploy/vision/ocr/ppocr/rec_postprocessor.cc @@ -0,0 +1,112 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/ocr/ppocr/rec_postprocessor.h" +#include "fastdeploy/utils/perf.h" +#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h" + +namespace fastdeploy { +namespace vision { +namespace ocr { + +std::vector ReadDict(const std::string& path) { + std::ifstream in(path); + FDASSERT(in, "Cannot open file %s to read.", path.c_str()); + std::string line; + std::vector m_vec; + while (getline(in, line)) { + m_vec.push_back(line); + } + m_vec.insert(m_vec.begin(), "#"); // blank char for ctc + m_vec.push_back(" "); + return m_vec; +} + +RecognizerPostprocessor::RecognizerPostprocessor(){ + initialized_ = true; +} + +RecognizerPostprocessor::RecognizerPostprocessor(const std::string& label_path) { + // init label_lsit + label_list_ = ReadDict(label_path); + initialized_ = true; +} + +bool RecognizerPostprocessor::SingleBatchPostprocessor(const float* out_data, + const std::vector& output_shape, + std::string* text, float* rec_score) { + std::string& str_res = *text; + float& score = *rec_score; + score = 0.f; + int argmax_idx; + int last_index = 0; + int count = 0; + float max_value = 0.0f; + + for (int n = 0; n < output_shape[1]; n++) { + argmax_idx = int( + std::distance(&out_data[n * output_shape[2]], + std::max_element(&out_data[n * output_shape[2]], + &out_data[(n + 1) * output_shape[2]]))); + + max_value = float(*std::max_element(&out_data[n * output_shape[2]], + &out_data[(n + 1) * output_shape[2]])); + + if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) { + score += max_value; + count += 1; + if(argmax_idx > label_list_.size()) { + FDERROR << "The output index: " << argmax_idx << " is larger than the size of label_list: " + << label_list_.size() << ". Please check the label file!" << std::endl; + return false; + } + str_res += label_list_[argmax_idx]; + } + last_index = argmax_idx; + } + score /= (count + 1e-6); + if (count == 0 || std::isnan(score)) { + score = 0.f; + } + return true; +} + +bool RecognizerPostprocessor::Run(const std::vector& tensors, + std::vector* texts, std::vector* rec_scores) { + if (!initialized_) { + FDERROR << "Postprocessor is not initialized." << std::endl; + return false; + } + // Recognizer have only 1 output tensor. + const FDTensor& tensor = tensors[0]; + // For Recognizer, the output tensor shape = [batch, ?, 6625] + size_t batch = tensor.shape[0]; + size_t length = accumulate(tensor.shape.begin()+1, tensor.shape.end(), 1, std::multiplies()); + + texts->resize(batch); + rec_scores->resize(batch); + const float* tensor_data = reinterpret_cast(tensor.Data()); + for (int i_batch = 0; i_batch < batch; ++i_batch) { + if(!SingleBatchPostprocessor(tensor_data, tensor.shape, &texts->at(i_batch), &rec_scores->at(i_batch))) { + return false; + } + tensor_data = tensor_data + length; + } + + return true; +} + +} // namespace ocr +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/rec_postprocessor.h b/fastdeploy/vision/ocr/ppocr/rec_postprocessor.h new file mode 100644 index 000000000..d1aa0124b --- /dev/null +++ b/fastdeploy/vision/ocr/ppocr/rec_postprocessor.h @@ -0,0 +1,55 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" +#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h" + +namespace fastdeploy { +namespace vision { + +namespace ocr { +/*! @brief Postprocessor object for Recognizer serials model. + */ +class FASTDEPLOY_DECL RecognizerPostprocessor { + public: + RecognizerPostprocessor(); + /** \brief Create a postprocessor instance for Recognizer serials model + * + * \param[in] label_path The path of label_dict + */ + explicit RecognizerPostprocessor(const std::string& label_path); + + /** \brief Process the result of runtime and fill to ClassifyResult structure + * + * \param[in] tensors The inference result from runtime + * \param[in] texts The output result of recognizer + * \param[in] rec_scores The output result of recognizer + * \return true if the postprocess successed, otherwise false + */ + bool Run(const std::vector& tensors, + std::vector* texts, std::vector* rec_scores); + + private: + bool SingleBatchPostprocessor(const float* out_data, + const std::vector& output_shape, + std::string* text, float* rec_score); + bool initialized_ = false; + std::vector label_list_; +}; + +} // namespace ocr +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc new file mode 100644 index 000000000..858578d69 --- /dev/null +++ b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/ocr/ppocr/rec_preprocessor.h" +#include "fastdeploy/utils/perf.h" +#include "fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h" +#include "fastdeploy/function/concat.h" + +namespace fastdeploy { +namespace vision { +namespace ocr { + +RecognizerPreprocessor::RecognizerPreprocessor() { + initialized_ = true; +} + +void OcrRecognizerResizeImage(FDMat* mat, float max_wh_ratio, + const std::vector& rec_image_shape) { + int imgC, imgH, imgW; + imgC = rec_image_shape[0]; + imgH = rec_image_shape[1]; + imgW = rec_image_shape[2]; + + imgW = int(imgH * max_wh_ratio); + + float ratio = float(mat->Width()) / float(mat->Height()); + int resize_w; + if (ceilf(imgH * ratio) > imgW) { + resize_w = imgW; + }else{ + resize_w = int(ceilf(imgH * ratio)); + } + Resize::Run(mat, resize_w, imgH); + + std::vector value = {0, 0, 0}; + Pad::Run(mat, 0, 0, 0, int(imgW - mat->Width()), value); +} + +bool RecognizerPreprocessor::Run(std::vector* images, std::vector* outputs) { + if (!initialized_) { + FDERROR << "The preprocessor is not initialized." << std::endl; + return false; + } + if (images->size() == 0) { + FDERROR << "The size of input images should be greater than 0." << std::endl; + return false; + } + + int imgH = rec_image_shape_[1]; + int imgW = rec_image_shape_[2]; + float max_wh_ratio = imgW * 1.0 / imgH; + float ori_wh_ratio; + + for (size_t i = 0; i < images->size(); ++i) { + FDMat* mat = &(images->at(i)); + ori_wh_ratio = mat->Width() * 1.0 / mat->Height(); + max_wh_ratio = std::max(max_wh_ratio, ori_wh_ratio); + } + + for (size_t i = 0; i < images->size(); ++i) { + FDMat* mat = &(images->at(i)); + OcrRecognizerResizeImage(mat, max_wh_ratio, rec_image_shape_); + NormalizeAndPermute::Run(mat, mean_, scale_, is_scale_); + /* + Normalize::Run(mat, mean_, scale_, is_scale_); + HWC2CHW::Run(mat); + Cast::Run(mat, "float"); + */ + } + // Only have 1 output Tensor. + outputs->resize(1); + // Concat all the preprocessed data to a batch tensor + std::vector tensors(images->size()); + for (size_t i = 0; i < images->size(); ++i) { + (*images)[i].ShareWithTensor(&(tensors[i])); + tensors[i].ExpandDim(0); + } + if (tensors.size() == 1) { + (*outputs)[0] = std::move(tensors[0]); + } else { + function::Concat(tensors, &((*outputs)[0]), 0); + } + return true; +} + +} // namespace ocr +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h new file mode 100644 index 000000000..3e5c7de82 --- /dev/null +++ b/fastdeploy/vision/ocr/ppocr/rec_preprocessor.h @@ -0,0 +1,52 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" + +namespace fastdeploy { +namespace vision { + +namespace ocr { +/*! @brief Preprocessor object for PaddleClas serials model. + */ +class FASTDEPLOY_DECL RecognizerPreprocessor { + public: + /** \brief Create a preprocessor instance for PaddleClas serials model + * + * \param[in] config_file Path of configuration file for deployment, e.g resnet/infer_cfg.yml + */ + RecognizerPreprocessor(); + + /** \brief Process the input image and prepare input tensors for runtime + * + * \param[in] images The input image data list, all the elements are returned by cv::imread() + * \param[in] outputs The output tensors which will feed in runtime + * \return true if the preprocess successed, otherwise false + */ + bool Run(std::vector* images, std::vector* outputs); + + std::vector rec_image_shape_ = {3, 48, 320}; + std::vector mean_ = {0.5f, 0.5f, 0.5f}; + std::vector scale_ = {0.5f, 0.5f, 0.5f}; + bool is_scale_ = true; + + private: + bool initialized_ = false; +}; + +} // namespace ocr +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ocr/ppocr/recognizer.cc b/fastdeploy/vision/ocr/ppocr/recognizer.cc old mode 100644 new mode 100755 index f0564ce33..59d64a6c4 --- a/fastdeploy/vision/ocr/ppocr/recognizer.cc +++ b/fastdeploy/vision/ocr/ppocr/recognizer.cc @@ -20,29 +20,13 @@ namespace fastdeploy { namespace vision { namespace ocr { -std::vector ReadDict(const std::string& path) { - std::ifstream in(path); - std::string line; - std::vector m_vec; - if (in) { - while (getline(in, line)) { - m_vec.push_back(line); - } - } else { - std::cout << "no such label file: " << path << ", exit the program..." - << std::endl; - exit(1); - } - return m_vec; -} - Recognizer::Recognizer() {} Recognizer::Recognizer(const std::string& model_file, const std::string& params_file, const std::string& label_path, const RuntimeOption& custom_option, - const ModelFormat& model_format) { + const ModelFormat& model_format):postprocessor_(label_path) { if (model_format == ModelFormat::ONNX) { valid_cpu_backends = {Backend::ORT, Backend::OPENVINO}; @@ -56,27 +40,11 @@ Recognizer::Recognizer(const std::string& model_file, runtime_option.model_format = model_format; runtime_option.model_file = model_file; runtime_option.params_file = params_file; - initialized = Initialize(); - - // init label_lsit - label_list = ReadDict(label_path); - label_list.insert(label_list.begin(), "#"); // blank char for ctc - label_list.push_back(" "); } // Init bool Recognizer::Initialize() { - // pre&post process parameters - rec_batch_num = 1; - rec_img_h = 48; - rec_img_w = 320; - rec_image_shape = {3, rec_img_h, rec_img_w}; - - mean = {0.5f, 0.5f, 0.5f}; - scale = {0.5f, 0.5f, 0.5f}; - is_scale = true; - if (!InitRuntime()) { FDERROR << "Failed to initialize fastdeploy backend." << std::endl; return false; @@ -85,119 +53,23 @@ bool Recognizer::Initialize() { return true; } -void OcrRecognizerResizeImage(Mat* mat, const float& wh_ratio, - const std::vector& rec_image_shape) { - int imgC, imgH, imgW; - imgC = rec_image_shape[0]; - imgH = rec_image_shape[1]; - imgW = rec_image_shape[2]; - - imgW = int(imgH * wh_ratio); - - float ratio = float(mat->Width()) / float(mat->Height()); - int resize_w; - if (ceilf(imgH * ratio) > imgW) - resize_w = imgW; - else - resize_w = int(ceilf(imgH * ratio)); - - Resize::Run(mat, resize_w, imgH); - - std::vector value = {127, 127, 127}; - Pad::Run(mat, 0, 0, 0, int(imgW - mat->Width()), value); -} - -bool Recognizer::Preprocess(Mat* mat, FDTensor* output, - const std::vector& rec_image_shape) { - int imgH = rec_image_shape[1]; - int imgW = rec_image_shape[2]; - float wh_ratio = imgW * 1.0 / imgH; - - float ori_wh_ratio = mat->Width() * 1.0 / mat->Height(); - wh_ratio = std::max(wh_ratio, ori_wh_ratio); - - OcrRecognizerResizeImage(mat, wh_ratio, rec_image_shape); - - Normalize::Run(mat, mean, scale, true); - - HWC2CHW::Run(mat); - Cast::Run(mat, "float"); - - mat->ShareWithTensor(output); - output->shape.insert(output->shape.begin(), 1); - - return true; -} - -bool Recognizer::Postprocess(FDTensor& infer_result, - std::tuple* rec_result) { - std::vector output_shape = infer_result.shape; - FDASSERT(output_shape[0] == 1, "Only support batch =1 now."); - - float* out_data = static_cast(infer_result.Data()); - - std::string str_res; - int argmax_idx; - int last_index = 0; - float score = 0.f; - int count = 0; - float max_value = 0.0f; - - for (int n = 0; n < output_shape[1]; n++) { - argmax_idx = int( - std::distance(&out_data[n * output_shape[2]], - std::max_element(&out_data[n * output_shape[2]], - &out_data[(n + 1) * output_shape[2]]))); - - max_value = float(*std::max_element(&out_data[n * output_shape[2]], - &out_data[(n + 1) * output_shape[2]])); - - if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) { - score += max_value; - count += 1; - if(argmax_idx > label_list.size()){ - FDERROR << "The output index: " << argmax_idx << " is larger than the size of label_list: " - << label_list.size() << ". Please check the label file!" << std::endl; - return false; - } - str_res += label_list[argmax_idx]; - } - last_index = argmax_idx; +bool Recognizer::BatchPredict(const std::vector& images, + std::vector* texts, std::vector* rec_scores) { + std::vector fd_images = WrapMat(images); + if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) { + FDERROR << "Failed to preprocess the input image." << std::endl; + return false; } - score /= (count + 1e-6); - if (count == 0 || std::isnan(score)) { - score = 0.f; - } - std::get<0>(*rec_result) = str_res; - std::get<1>(*rec_result) = score; - - return true; -} - -bool Recognizer::Predict(cv::Mat* img, - std::tuple* rec_result) { - Mat mat(*img); - - std::vector input_tensors(1); - - if (!Preprocess(&mat, &input_tensors[0], rec_image_shape)) { - FDERROR << "Failed to preprocess input image." << std::endl; + reused_input_tensors_[0].name = InputInfoOfRuntime(0).name; + if (!Infer(reused_input_tensors_, &reused_output_tensors_)) { + FDERROR << "Failed to inference by runtime." << std::endl; return false; } - input_tensors[0].name = InputInfoOfRuntime(0).name; - std::vector output_tensors; - - if (!Infer(input_tensors, &output_tensors)) { - FDERROR << "Failed to inference." << std::endl; + if (!postprocessor_.Run(reused_output_tensors_, texts, rec_scores)) { + FDERROR << "Failed to postprocess the inference cls_results by runtime." << std::endl; return false; } - - if (!Postprocess(output_tensors[0], rec_result)) { - FDERROR << "Failed to post process." << std::endl; - return false; - } - return true; } diff --git a/fastdeploy/vision/ocr/ppocr/recognizer.h b/fastdeploy/vision/ocr/ppocr/recognizer.h old mode 100644 new mode 100755 index d3c5fcc9d..1cd841eb4 --- a/fastdeploy/vision/ocr/ppocr/recognizer.h +++ b/fastdeploy/vision/ocr/ppocr/recognizer.h @@ -17,6 +17,8 @@ #include "fastdeploy/vision/common/processors/transform.h" #include "fastdeploy/vision/common/result.h" #include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h" +#include "fastdeploy/vision/ocr/ppocr/rec_preprocessor.h" +#include "fastdeploy/vision/ocr/ppocr/rec_postprocessor.h" namespace fastdeploy { namespace vision { @@ -43,35 +45,20 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel { const ModelFormat& model_format = ModelFormat::PADDLE); /// Get model's name std::string ModelName() const { return "ppocr/ocr_rec"; } - /** \brief Predict the input image and get OCR recognition model result. + /** \brief BatchPredict the input image and get OCR recognition model result. * - * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. - * \param[in] rec_result The output of OCR recognition model result will be writen to this structure. + * \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. + * \param[in] rec_results The output of OCR recognition model result will be writen to this structure. * \return true if the prediction is successed, otherwise false. */ - virtual bool Predict(cv::Mat* img, - std::tuple* rec_result); + virtual bool BatchPredict(const std::vector& images, + std::vector* texts, std::vector* rec_scores); - // Pre & Post parameters - std::vector label_list; - int rec_batch_num; - int rec_img_h; - int rec_img_w; - std::vector rec_image_shape; - - std::vector mean; - std::vector scale; - bool is_scale; + RecognizerPreprocessor preprocessor_; + RecognizerPostprocessor postprocessor_; private: bool Initialize(); - /// Preprocess the input data, and set the preprocessed results to `outputs` - bool Preprocess(Mat* img, FDTensor* outputs, - const std::vector& rec_image_shape); - /*! @brief Postprocess the inferenced results, and set the final result to `rec_result` - */ - bool Postprocess(FDTensor& infer_result, - std::tuple* rec_result); }; } // namespace ocr diff --git a/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.cc b/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.cc old mode 100644 new mode 100755 index 02e435f76..7a8f387e2 --- a/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.cc +++ b/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.cc @@ -318,10 +318,12 @@ std::vector>> PostProcessor::BoxesFromBitmap( } std::vector>> PostProcessor::FilterTagDetRes( - std::vector>> boxes, float ratio_h, - float ratio_w, const std::map> &im_info) { - int oriimg_h = im_info.at("input_shape")[0]; - int oriimg_w = im_info.at("input_shape")[1]; + std::vector>> boxes, + const std::array& det_img_info) { + int oriimg_w = det_img_info[0]; + int oriimg_h = det_img_info[1]; + float ratio_w = float(det_img_info[2])/float(oriimg_w); + float ratio_h = float(det_img_info[3])/float(oriimg_h); std::vector>> root_points; for (int n = 0; n < boxes.size(); n++) { diff --git a/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h b/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h index 8422704fd..5900daea2 100644 --- a/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h +++ b/fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h @@ -57,9 +57,8 @@ class PostProcessor { const float &det_db_unclip_ratio, const std::string &det_db_score_mode); std::vector>> FilterTagDetRes( - std::vector>> boxes, float ratio_h, - float ratio_w, - const std::map> &im_info); + std::vector>> boxes, + const std::array& det_img_info); private: static bool XsortInt(std::vector a, std::vector b); diff --git a/fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h b/fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h index 36994389b..0e5c040eb 100644 --- a/fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h +++ b/fastdeploy/vision/ocr/ppocr/utils/ocr_utils.h @@ -28,10 +28,10 @@ namespace fastdeploy { namespace vision { namespace ocr { -cv::Mat GetRotateCropImage(const cv::Mat& srcimage, +FASTDEPLOY_DECL cv::Mat GetRotateCropImage(const cv::Mat& srcimage, const std::array& box); -void SortBoxes(OCRResult* result); +FASTDEPLOY_DECL void SortBoxes(std::vector>* boxes); } // namespace ocr } // namespace vision diff --git a/fastdeploy/vision/ocr/ppocr/utils/sorted_boxes.cc b/fastdeploy/vision/ocr/ppocr/utils/sorted_boxes.cc index a6b723d94..4705d9a5a 100644 --- a/fastdeploy/vision/ocr/ppocr/utils/sorted_boxes.cc +++ b/fastdeploy/vision/ocr/ppocr/utils/sorted_boxes.cc @@ -29,17 +29,17 @@ bool CompareBox(const std::array& result1, } } -void SortBoxes(OCRResult* result) { - std::sort(result->boxes.begin(), result->boxes.end(), CompareBox); +void SortBoxes(std::vector>* boxes) { + std::sort(boxes->begin(), boxes->end(), CompareBox); - if (result->boxes.size() == 0) { + if (boxes->size() == 0) { return; } - for (int i = 0; i < result->boxes.size() - 1; i++) { - if (abs(result->boxes[i + 1][1] - result->boxes[i][1]) < 10 && - (result->boxes[i + 1][0] < result->boxes[i][0])) { - std::swap(result->boxes[i], result->boxes[i + 1]); + for (int i = 0; i < boxes->size() - 1; i++) { + if (abs((*boxes)[i + 1][1] - (*boxes)[i][1]) < 10 && + ((*boxes)[i + 1][0] < (*boxes)[i][0])) { + std::swap((*boxes)[i], (*boxes)[i + 1]); } } } diff --git a/python/fastdeploy/vision/ocr/__init__.py b/python/fastdeploy/vision/ocr/__init__.py old mode 100644 new mode 100755 index 98e210d3b..c83c6d678 --- a/python/fastdeploy/vision/ocr/__init__.py +++ b/python/fastdeploy/vision/ocr/__init__.py @@ -13,10 +13,4 @@ # limitations under the License. from __future__ import absolute_import -from .ppocr import PPOCRv3 -from .ppocr import PPOCRv2 -from .ppocr import PPOCRSystemv3 -from .ppocr import PPOCRSystemv2 -from .ppocr import DBDetector -from .ppocr import Classifier -from .ppocr import Recognizer +from .ppocr import * diff --git a/python/fastdeploy/vision/ocr/ppocr/__init__.py b/python/fastdeploy/vision/ocr/ppocr/__init__.py old mode 100644 new mode 100755 index e361a3a8a..b8f5c81d1 --- a/python/fastdeploy/vision/ocr/ppocr/__init__.py +++ b/python/fastdeploy/vision/ocr/ppocr/__init__.py @@ -41,40 +41,11 @@ class DBDetector(FastDeployModel): assert self.initialized, "DBDetector initialize failed." # 一些跟DBDetector模型有关的属性封装 - @property - def max_side_len(self): - return self._model.max_side_len - + ''' @property def det_db_thresh(self): return self._model.det_db_thresh - @property - def det_db_box_thresh(self): - return self._model.det_db_box_thresh - - @property - def det_db_unclip_ratio(self): - return self._model.det_db_unclip_ratio - - @property - def det_db_score_mode(self): - return self._model.det_db_score_mode - - @property - def use_dilation(self): - return self._model.use_dilation - - @property - def is_scale(self): - return self._model.max_wh - - @max_side_len.setter - def max_side_len(self, value): - assert isinstance( - value, int), "The value to set `max_side_len` must be type of int." - self._model.max_side_len = value - @det_db_thresh.setter def det_db_thresh(self, value): assert isinstance( @@ -82,6 +53,10 @@ class DBDetector(FastDeployModel): float), "The value to set `det_db_thresh` must be type of float." self._model.det_db_thresh = value + @property + def det_db_box_thresh(self): + return self._model.det_db_box_thresh + @det_db_box_thresh.setter def det_db_box_thresh(self, value): assert isinstance( @@ -89,6 +64,10 @@ class DBDetector(FastDeployModel): ), "The value to set `det_db_box_thresh` must be type of float." self._model.det_db_box_thresh = value + @property + def det_db_unclip_ratio(self): + return self._model.det_db_unclip_ratio + @det_db_unclip_ratio.setter def det_db_unclip_ratio(self, value): assert isinstance( @@ -96,6 +75,10 @@ class DBDetector(FastDeployModel): ), "The value to set `det_db_unclip_ratio` must be type of float." self._model.det_db_unclip_ratio = value + @property + def det_db_score_mode(self): + return self._model.det_db_score_mode + @det_db_score_mode.setter def det_db_score_mode(self, value): assert isinstance( @@ -103,6 +86,10 @@ class DBDetector(FastDeployModel): str), "The value to set `det_db_score_mode` must be type of str." self._model.det_db_score_mode = value + @property + def use_dilation(self): + return self._model.use_dilation + @use_dilation.setter def use_dilation(self, value): assert isinstance( @@ -110,11 +97,26 @@ class DBDetector(FastDeployModel): bool), "The value to set `use_dilation` must be type of bool." self._model.use_dilation = value + @property + def max_side_len(self): + return self._model.max_side_len + + @max_side_len.setter + def max_side_len(self, value): + assert isinstance( + value, int), "The value to set `max_side_len` must be type of int." + self._model.max_side_len = value + + @property + def is_scale(self): + return self._model.max_wh + @is_scale.setter def is_scale(self, value): assert isinstance( value, bool), "The value to set `is_scale` must be type of bool." self._model.is_scale = value + ''' class Classifier(FastDeployModel): @@ -139,6 +141,7 @@ class Classifier(FastDeployModel): model_file, params_file, self._runtime_option, model_format) assert self.initialized, "Classifier initialize failed." + ''' @property def cls_thresh(self): return self._model.cls_thresh @@ -170,6 +173,7 @@ class Classifier(FastDeployModel): value, int), "The value to set `cls_batch_num` must be type of int." self._model.cls_batch_num = value + ''' class Recognizer(FastDeployModel): @@ -197,6 +201,7 @@ class Recognizer(FastDeployModel): model_format) assert self.initialized, "Recognizer initialize failed." + ''' @property def rec_img_h(self): return self._model.rec_img_h @@ -227,6 +232,7 @@ class Recognizer(FastDeployModel): value, int), "The value to set `rec_batch_num` must be type of int." self._model.rec_batch_num = value + ''' class PPOCRv3(FastDeployModel): @@ -253,6 +259,14 @@ class PPOCRv3(FastDeployModel): """ return self.system.predict(input_image) + def batch_predict(self, images): + """Predict a batch of input image + :param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format + :return: OCRBatchResult + """ + + return self.system.batch_predict(images) + class PPOCRSystemv3(PPOCRv3): def __init__(self, det_model=None, cls_model=None, rec_model=None): @@ -289,6 +303,14 @@ class PPOCRv2(FastDeployModel): """ return self.system.predict(input_image) + def batch_predict(self, images): + """Predict a batch of input image + :param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format + :return: OCRBatchResult + """ + + return self.system.batch_predict(images) + class PPOCRSystemv2(PPOCRv2): def __init__(self, det_model=None, cls_model=None, rec_model=None): @@ -299,3 +321,93 @@ class PPOCRSystemv2(PPOCRv2): def predict(self, input_image): return super(PPOCRSystemv2, self).predict(input_image) + + +class DBDetectorPreprocessor: + def __init__(self): + """Create a preprocessor for DBDetectorModel + """ + self._preprocessor = C.vision.ocr.DBDetectorPreprocessor() + + def run(self, input_ims): + """Preprocess input images for DBDetectorModel + :param: input_ims: (list of numpy.ndarray) The input image + :return: pair(list of FDTensor, list of std::array) + """ + return self._preprocessor.run(input_ims) + + +class DBDetectorPostprocessor: + def __init__(self): + """Create a postprocessor for DBDetectorModel + """ + self._postprocessor = C.vision.ocr.DBDetectorPostprocessor() + + def run(self, runtime_results, batch_det_img_info): + """Postprocess the runtime results for DBDetectorModel + :param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime + :param: batch_det_img_info: (list of std::array)The output of det_preprocessor + :return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size) + """ + return self._postprocessor.run(runtime_results, batch_det_img_info) + + +class RecognizerPreprocessor: + def __init__(self): + """Create a preprocessor for RecognizerModel + """ + self._preprocessor = C.vision.ocr.RecognizerPreprocessor() + + def run(self, input_ims): + """Preprocess input images for RecognizerModel + :param: input_ims: (list of numpy.ndarray)The input image + :return: list of FDTensor + """ + return self._preprocessor.run(input_ims) + + +class RecognizerPostprocessor: + def __init__(self, label_path): + """Create a postprocessor for RecognizerModel + :param label_path: (str)Path of label file + """ + self._postprocessor = C.vision.ocr.RecognizerPostprocessor(label_path) + + def run(self, runtime_results): + """Postprocess the runtime results for RecognizerModel + :param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime + :return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size) + """ + return self._postprocessor.run(runtime_results) + + +class ClassifierPreprocessor: + def __init__(self): + """Create a preprocessor for ClassifierModel + """ + self._preprocessor = C.vision.ocr.ClassifierPreprocessor() + + def run(self, input_ims): + """Preprocess input images for ClassifierModel + :param: input_ims: (list of numpy.ndarray)The input image + :return: list of FDTensor + """ + return self._preprocessor.run(input_ims) + + +class ClassifierPostprocessor: + def __init__(self): + """Create a postprocessor for ClassifierModel + """ + self._postprocessor = C.vision.ocr.ClassifierPostprocessor() + + def run(self, runtime_results): + """Postprocess the runtime results for ClassifierModel + :param: runtime_results: (list of FDTensor or list of pyArray)The output FDTensor results from runtime + :return: list of Result(If the runtime_results is predict by batched samples, the length of this list equals to the batch size) + """ + return self._postprocessor.run(runtime_results) + + +def sort_boxes(boxes): + return C.vision.ocr.sort_boxes(boxes)