mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Model] Support PP-OCRv4 pipeline (#1913)
* 添加paddleclas模型 * 更新README_CN * 更新README_CN * 更新README * update get_model.sh * update get_models.sh * update paddleseg models * update paddle_seg models * update paddle_seg models * modified test resources * update benchmark_gpu_trt.sh * add paddle detection * add paddledetection to benchmark * modified benchmark cmakelists * update benchmark scripts * modified benchmark function calling * modified paddledetection documents * upadte getmodels.sh * add PaddleDetectonModel * reset examples/paddledetection * resolve conflict * update pybind * resolve conflict * fix bug * delete debug mode * update checkarch log * update trt inputs example * Update README.md * add ppocr_v4 * update ppocr_v4 * update ocr_v4 * update ocr_v4 * update ocr_v4 * update ocr_v4 --------- Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
This commit is contained in:
@@ -317,6 +317,30 @@ if __name__ == '__main__':
|
||||
runtime_option=rec_option)
|
||||
model = fd.vision.ocr.PPOCRv3(
|
||||
det_model=det_model, cls_model=cls_model, rec_model=rec_model)
|
||||
elif "OCRv4" in args.model_dir:
|
||||
det_option = option
|
||||
if args.backend in ["trt", "paddle_trt"]:
|
||||
det_option.trt_option.set_shape(
|
||||
"x", [1, 3, 64, 64], [1, 3, 640, 640], [1, 3, 960, 960])
|
||||
det_model = fd.vision.ocr.DBDetector(
|
||||
det_model_file, det_params_file, runtime_option=det_option)
|
||||
cls_option = option
|
||||
if args.backend in ["trt", "paddle_trt"]:
|
||||
cls_option.trt_option.set_shape(
|
||||
"x", [1, 3, 48, 10], [10, 3, 48, 320], [64, 3, 48, 1024])
|
||||
cls_model = fd.vision.ocr.Classifier(
|
||||
cls_model_file, cls_params_file, runtime_option=cls_option)
|
||||
rec_option = option
|
||||
if args.backend in ["trt", "paddle_trt"]:
|
||||
rec_option.trt_option.set_shape(
|
||||
"x", [1, 3, 48, 10], [10, 3, 48, 320], [64, 3, 48, 2304])
|
||||
rec_model = fd.vision.ocr.Recognizer(
|
||||
rec_model_file,
|
||||
rec_params_file,
|
||||
rec_label_file,
|
||||
runtime_option=rec_option)
|
||||
model = fd.vision.ocr.PPOCRv4(
|
||||
det_model=det_model, cls_model=cls_model, rec_model=rec_model)
|
||||
else:
|
||||
raise Exception("model {} not support now in ppocr series".format(
|
||||
args.model_dir))
|
||||
|
7
c_api/fastdeploy_capi/vision/types_internal.h
Normal file → Executable file
7
c_api/fastdeploy_capi/vision/types_internal.h
Normal file → Executable file
@@ -32,6 +32,7 @@
|
||||
#include "fastdeploy/vision/ocr/ppocr/structurev2_table.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/ppocr_v2.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/ppocr_v3.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/ppocr_v4.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/ppstructurev2_table.h"
|
||||
#include "fastdeploy/vision/segmentation/ppseg/model.h"
|
||||
|
||||
@@ -187,6 +188,9 @@ DEFINE_PIPELINE_MODEL_WRAPPER_STRUCT(PPOCRv2, ppocrv2_model);
|
||||
// PPOCRv3
|
||||
DEFINE_PIPELINE_MODEL_WRAPPER_STRUCT(PPOCRv3, ppocrv3_model);
|
||||
|
||||
// PPOCRv4
|
||||
DEFINE_PIPELINE_MODEL_WRAPPER_STRUCT(PPOCRv4, ppocrv4_model);
|
||||
|
||||
// PPStructureV2Table
|
||||
DEFINE_PIPELINE_MODEL_WRAPPER_STRUCT(PPStructureV2Table, ppstructurev2table_model);
|
||||
|
||||
@@ -400,6 +404,9 @@ DECLARE_PIPELINE_MODEL_FUNC_FOR_GET_PTR_FROM_WRAPPER(PPOCRv2, fd_ppocrv2_wrapper
|
||||
// PPOCRv3
|
||||
DECLARE_PIPELINE_MODEL_FUNC_FOR_GET_PTR_FROM_WRAPPER(PPOCRv3, fd_ppocrv3_wrapper);
|
||||
|
||||
// PPOCRv4
|
||||
DECLARE_PIPELINE_MODEL_FUNC_FOR_GET_PTR_FROM_WRAPPER(PPOCRv4, fd_ppocrv4_wrapper);
|
||||
|
||||
// PPStructureV2Table
|
||||
DECLARE_PIPELINE_MODEL_FUNC_FOR_GET_PTR_FROM_WRAPPER(PPStructureV2Table, fd_ppstructurev2_table_wrapper);
|
||||
|
||||
|
@@ -59,6 +59,7 @@
|
||||
#include "fastdeploy/vision/ocr/ppocr/structurev2_layout.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/ppocr_v2.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/ppocr_v3.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/ppocr_v4.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/ppstructurev2_table.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/ppstructurev2_layout.h"
|
||||
#include "fastdeploy/vision/ocr/ppocr/recognizer.h"
|
||||
|
@@ -672,8 +672,8 @@ std::string OCRResult::Str() {
|
||||
out = out + "]";
|
||||
|
||||
if (rec_scores.size() > 0) {
|
||||
out = out + "rec text: " + text[n] + " rec score:" +
|
||||
std::to_string(rec_scores[n]) + " ";
|
||||
out = out + "rec text: " + text[n] +
|
||||
" rec score:" + std::to_string(rec_scores[n]) + " ";
|
||||
}
|
||||
if (cls_labels.size() > 0) {
|
||||
out = out + "cls label: " + std::to_string(cls_labels[n]) +
|
||||
@@ -713,8 +713,8 @@ std::string OCRResult::Str() {
|
||||
cls_scores.size() > 0) {
|
||||
std::string out;
|
||||
for (int i = 0; i < rec_scores.size(); i++) {
|
||||
out = out + "rec text: " + text[i] + " rec score:" +
|
||||
std::to_string(rec_scores[i]) + " ";
|
||||
out = out + "rec text: " + text[i] +
|
||||
" rec score:" + std::to_string(rec_scores[i]) + " ";
|
||||
out = out + "cls label: " + std::to_string(cls_labels[i]) +
|
||||
" cls score: " + std::to_string(cls_scores[i]);
|
||||
out = out + "\n";
|
||||
@@ -733,8 +733,8 @@ std::string OCRResult::Str() {
|
||||
cls_scores.size() == 0) {
|
||||
std::string out;
|
||||
for (int i = 0; i < rec_scores.size(); i++) {
|
||||
out = out + "rec text: " + text[i] + " rec score:" +
|
||||
std::to_string(rec_scores[i]) + " ";
|
||||
out = out + "rec text: " + text[i] +
|
||||
" rec score:" + std::to_string(rec_scores[i]) + " ";
|
||||
out = out + "\n";
|
||||
}
|
||||
return out;
|
||||
@@ -781,9 +781,9 @@ std::string HeadPoseResult::Str() {
|
||||
std::string out;
|
||||
|
||||
out = "HeadPoseResult: [yaw, pitch, roll]\n";
|
||||
out = out + "yaw: " + std::to_string(euler_angles[0]) + "\n" + "pitch: " +
|
||||
std::to_string(euler_angles[1]) + "\n" + "roll: " +
|
||||
std::to_string(euler_angles[2]) + "\n";
|
||||
out = out + "yaw: " + std::to_string(euler_angles[0]) + "\n" +
|
||||
"pitch: " + std::to_string(euler_angles[1]) + "\n" +
|
||||
"roll: " + std::to_string(euler_angles[2]) + "\n";
|
||||
return out;
|
||||
}
|
||||
|
||||
|
2
fastdeploy/vision/ocr/ocr_pybind.cc
Normal file → Executable file
2
fastdeploy/vision/ocr/ocr_pybind.cc
Normal file → Executable file
@@ -17,6 +17,7 @@
|
||||
namespace fastdeploy {
|
||||
|
||||
void BindPPOCRModel(pybind11::module& m);
|
||||
void BindPPOCRv4(pybind11::module& m);
|
||||
void BindPPOCRv3(pybind11::module& m);
|
||||
void BindPPOCRv2(pybind11::module& m);
|
||||
void BindPPStructureV2Table(pybind11::module& m);
|
||||
@@ -24,6 +25,7 @@ void BindPPStructureV2Table(pybind11::module& m);
|
||||
void BindOcr(pybind11::module& m) {
|
||||
auto ocr_module = m.def_submodule("ocr", "Module to deploy OCR models");
|
||||
BindPPOCRModel(ocr_module);
|
||||
BindPPOCRv4(ocr_module);
|
||||
BindPPOCRv3(ocr_module);
|
||||
BindPPOCRv2(ocr_module);
|
||||
BindPPStructureV2Table(ocr_module);
|
||||
|
32
fastdeploy/vision/ocr/ppocr/ppocr_pybind.cc
Normal file → Executable file
32
fastdeploy/vision/ocr/ppocr/ppocr_pybind.cc
Normal file → Executable file
@@ -16,6 +16,38 @@
|
||||
#include "fastdeploy/pybind/main.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
/// Registers the Python class "PPOCRv4" on module `m`, backed by
/// pipeline::PPOCRv4 with FastDeployModel as its declared base.
///
/// Exposed members: two constructors (with and without a classifier),
/// the `cls_batch_size` / `rec_batch_size` properties, `clone`, `predict`
/// and `batch_predict`.
void BindPPOCRv4(pybind11::module& m) {
  using Pipeline = pipeline::PPOCRv4;
  using Detector = fastdeploy::vision::ocr::DBDetector;
  using DirectionClassifier = fastdeploy::vision::ocr::Classifier;
  using TextRecognizer = fastdeploy::vision::ocr::Recognizer;

  pybind11::class_<Pipeline, FastDeployModel> binding(m, "PPOCRv4");

  // The three-model overload is registered first, then the two-model one.
  binding.def(
      pybind11::init<Detector*, DirectionClassifier*, TextRecognizer*>());
  binding.def(pybind11::init<Detector*, TextRecognizer*>());

  // Batch sizes are plain getter/setter pairs on the pipeline.
  binding.def_property("cls_batch_size", &Pipeline::GetClsBatchSize,
                       &Pipeline::SetClsBatchSize);
  binding.def_property("rec_batch_size", &Pipeline::GetRecBatchSize,
                       &Pipeline::SetRecBatchSize);

  binding.def("clone", [](Pipeline& self) { return self.Clone(); });

  // Single image: convert the array to cv::Mat and return the filled result.
  // Whatever Predict itself returns is not inspected.
  binding.def("predict", [](Pipeline& self, pybind11::array& data) {
    auto image = PyArrayToCvMat(data);
    vision::OCRResult result;
    self.Predict(&image, &result);
    return result;
  });

  // Batch of images: convert each array, run once, return every result.
  // Whatever BatchPredict itself returns is not inspected.
  binding.def("batch_predict",
              [](Pipeline& self, std::vector<pybind11::array>& data) {
                std::vector<cv::Mat> images;
                images.reserve(data.size());
                for (auto& array : data) {
                  images.push_back(PyArrayToCvMat(array));
                }
                std::vector<vision::OCRResult> batch_results;
                self.BatchPredict(images, &batch_results);
                return batch_results;
              });
}
|
||||
void BindPPOCRv3(pybind11::module& m) {
|
||||
// PPOCRv3
|
||||
pybind11::class_<pipeline::PPOCRv3, FastDeployModel>(m, "PPOCRv3")
|
||||
|
80
fastdeploy/vision/ocr/ppocr/ppocr_v4.h
Executable file
80
fastdeploy/vision/ocr/ppocr/ppocr_v4.h
Executable file
@@ -0,0 +1,80 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fastdeploy/vision/ocr/ppocr/ppocr_v3.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
/** \brief This pipeline can launch detection model, classification model and recognition model sequentially. All OCR pipeline APIs are defined inside this namespace.
|
||||
*
|
||||
*/
|
||||
namespace pipeline {
|
||||
/*! @brief PPOCRv4 is used to load PP-OCRv4 series models provided by PaddleOCR.
|
||||
*/
|
||||
class FASTDEPLOY_DECL PPOCRv4 : public PPOCRv3 {
|
||||
public:
|
||||
/** \brief Set up the detection model path, classification model path and recognition model path respectively.
|
||||
*
|
||||
* \param[in] det_model Path of detection model, e.g ./ch_PP-OCRv4_det_infer
|
||||
* \param[in] cls_model Path of classification model, e.g ./ch_ppocr_mobile_v2.0_cls_infer
|
||||
* \param[in] rec_model Path of recognition model, e.g ./ch_PP-OCRv4_rec_infer
|
||||
*/
|
||||
PPOCRv4(fastdeploy::vision::ocr::DBDetector* det_model,
|
||||
fastdeploy::vision::ocr::Classifier* cls_model,
|
||||
fastdeploy::vision::ocr::Recognizer* rec_model)
|
||||
: PPOCRv3(det_model, cls_model, rec_model) {
|
||||
// The only difference between v2 and v3
|
||||
auto preprocess_shape = recognizer_->GetPreprocessor().GetRecImageShape();
|
||||
preprocess_shape[1] = 48;
|
||||
recognizer_->GetPreprocessor().SetRecImageShape(preprocess_shape);
|
||||
}
|
||||
/** \brief Classification model is optional, so this function is set up the detection model path and recognition model path respectively.
|
||||
*
|
||||
* \param[in] det_model Path of detection model, e.g ./ch_PP-OCRv4_det_infer
|
||||
* \param[in] rec_model Path of recognition model, e.g ./ch_PP-OCRv4_rec_infer
|
||||
*/
|
||||
PPOCRv4(fastdeploy::vision::ocr::DBDetector* det_model,
|
||||
fastdeploy::vision::ocr::Recognizer* rec_model)
|
||||
: PPOCRv3(det_model, rec_model) {
|
||||
// The only difference between v2 and v4
|
||||
auto preprocess_shape = recognizer_->GetPreprocessor().GetRecImageShape();
|
||||
preprocess_shape[1] = 48;
|
||||
recognizer_->GetPreprocessor().SetRecImageShape(preprocess_shape);
|
||||
}
|
||||
|
||||
/** \brief Clone a new PPOCRv4 with less memory usage when multiple instances of the same model are created
|
||||
*
|
||||
* \return new PPOCRv4* type unique pointer
|
||||
*/
|
||||
std::unique_ptr<PPOCRv4> Clone() const {
|
||||
std::unique_ptr<PPOCRv4> clone_model = utils::make_unique<PPOCRv4>(PPOCRv4(*this));
|
||||
clone_model->detector_ = detector_->Clone().release();
|
||||
if (classifier_ != nullptr) {
|
||||
clone_model->classifier_ = classifier_->Clone().release();
|
||||
}
|
||||
clone_model->recognizer_ = recognizer_->Clone().release();
|
||||
return clone_model;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace pipeline
|
||||
|
||||
namespace application {
|
||||
namespace ocrsystem {
|
||||
typedef pipeline::PPOCRv4 PPOCRSystemv4;
|
||||
} // namespace ocrsystem
|
||||
} // namespace application
|
||||
|
||||
} // namespace fastdeploy
|
@@ -92,6 +92,7 @@ FASTDEPLOY_DECL cv::Mat VisDetection(const cv::Mat& im,
|
||||
int line_size = 1, float font_size = 0.5f,
|
||||
std::vector<int> font_color = {255, 255, 255},
|
||||
int font_thickness = 1);
|
||||
|
||||
/** \brief Show the visualized results with custom labels for detection models
|
||||
*
|
||||
* \param[in] im the input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
|
||||
|
@@ -848,6 +848,79 @@ class StructureV2Layout(FastDeployModel):
|
||||
def postprocessor(self, value):
|
||||
self._model.postprocessor = value
|
||||
|
||||
class PPOCRv4(FastDeployModel):
    def __init__(self, det_model=None, cls_model=None, rec_model=None):
        """Construct a pipeline from a text detector, an optional direction classifier and a text recognizer

        :param det_model: (FastDeployModel) The detection model object created by fastdeploy.vision.ocr.DBDetector. Must not be None.
        :param cls_model: (FastDeployModel) The classification model object created by fastdeploy.vision.ocr.Classifier. When None, the pipeline is built from the detector and recognizer only.
        :param rec_model: (FastDeployModel) The recognition model object created by fastdeploy.vision.ocr.Recognizer. Must not be None.
        """
        assert not (det_model is None or rec_model is None
                    ), "The det_model and rec_model cannot be None."
        if cls_model is not None:
            # Three-model pipeline: detector, classifier, recognizer.
            self.system_ = C.vision.ocr.PPOCRv4(
                det_model._model, cls_model._model, rec_model._model)
        else:
            # Two-model pipeline: the classifier is left out.
            self.system_ = C.vision.ocr.PPOCRv4(det_model._model,
                                                rec_model._model)

    def clone(self):
        """Clone this PPOCRv4 pipeline

        :return: a new PPOCRv4 pipeline object wrapping a clone of the underlying pipeline
        """

        # Local subclass whose constructor only stores an already-built
        # pipeline object instead of creating one from model objects.
        class PPOCRv4Clone(PPOCRv4):
            def __init__(self, system):
                self.system_ = system

        return PPOCRv4Clone(self.system_.clone())

    def predict(self, input_image):
        """Run the pipeline on one image

        :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
        :return: OCRResult
        """
        return self.system_.predict(input_image)

    def batch_predict(self, images):
        """Run the pipeline on a list of images

        :param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
        :return: OCRBatchResult
        """
        return self.system_.batch_predict(images)

    @property
    def cls_batch_size(self):
        """The `cls_batch_size` value of the underlying pipeline object."""
        return self.system_.cls_batch_size

    @cls_batch_size.setter
    def cls_batch_size(self, value):
        assert isinstance(
            value, int
        ), "The value to set `cls_batch_size` must be type of int."
        self.system_.cls_batch_size = value

    @property
    def rec_batch_size(self):
        """The `rec_batch_size` value of the underlying pipeline object."""
        return self.system_.rec_batch_size

    @rec_batch_size.setter
    def rec_batch_size(self, value):
        assert isinstance(
            value, int
        ), "The value to set `rec_batch_size` must be type of int."
        self.system_.rec_batch_size = value
|
||||
|
||||
class PPOCRSystemv4(PPOCRv4):
    """Deprecated name for PPOCRv4; logs a deprecation warning when constructed."""

    def __init__(self, det_model=None, cls_model=None, rec_model=None):
        deprecation_notice = (
            "DEPRECATED: fd.vision.ocr.PPOCRSystemv4 is deprecated, "
            "please use fd.vision.ocr.PPOCRv4 instead.")
        logging.warning(deprecation_notice)
        super().__init__(det_model, cls_model, rec_model)

    def predict(self, input_image):
        # Delegates unchanged to the parent class implementation.
        return super().predict(input_image)
|
||||
|
||||
class PPOCRv3(FastDeployModel):
|
||||
def __init__(self, det_model=None, cls_model=None, rec_model=None):
|
||||
|
@@ -63,6 +63,8 @@ def vis_perception(im_data,
|
||||
score_threshold, line_size, font_size)
|
||||
|
||||
|
||||
|
||||
|
||||
def vis_keypoint_detection(im_data, keypoint_det_result, conf_threshold=0.5):
|
||||
"""Show the visualized results for keypoint detection models
|
||||
|
||||
|
Reference in New Issue
Block a user