[Other] PPOCR models support model clone function (#1072)

* Refactor PaddleSeg with preprocessor && postprocessor

* Fix bugs

* Delete redundant code

* Modify by comments

* Refactor according to comments

* Add batch evaluation

* Add single test script

* Add ppliteseg single test script && fix eval(raise) error

* fix bug

* Fix evaluation segmentation.py batch predict

* Fix segmentation evaluation bug

* Fix evaluation segmentation bugs

* Update segmentation result docs

* Update old predict api and DisableNormalizeAndPermute

* Update resize segmentation label map with cv::INTER_NEAREST

* Add Model Clone function for PaddleClas && PaddleDet && PaddleSeg

* Add multi thread demo

* Add python model clone function

* Add multi thread python && C++ example

* Fix bug

* Update python && cpp multi_thread examples

* Add cpp && python directory

* Add README.md for examples

* Delete redundant code

* Create README_CN.md

* Rename README_CN.md to README.md

* Update README.md

* Update README.md

* Update VERSION_NUMBER

* Update requirements.txt

* Update README.md

* Update version in docs

* [Serving]Update Dockerfile (#1037)

Update Dockerfile

* Add license notice for RVM onnx model file (#1060)

* [Model] Add GPL-3.0 license (#1065)

Add GPL-3.0 license

* PPOCR models support model clone

* Update README.md

* Update PPOCRv2 && PPOCRv3 clone code

* Update PPOCR python __init__

* Add multi thread ocr example code

* Update README.md

* Update README.md

* Update ResNet50_vd_infer multi process code

* Add PPOCR multi process && thread example

* Update README.md

* Update README.md

* Update multi-thread docs

Co-authored-by: Jason <jiangjiajun@baidu.com>
Co-authored-by: leiqing <54695910+leiqing1@users.noreply.github.com>
Co-authored-by: heliqi <1101791222@qq.com>
Co-authored-by: WJJ1995 <wjjisloser@163.com>
This commit is contained in:
huangjianhui
2023-01-17 15:16:41 +08:00
committed by GitHub
parent abba2afd74
commit 6c4a08e416
28 changed files with 1201 additions and 96 deletions

examples/vision/matting/rvm/README.md (Executable file → Normal file)
View File

@@ -17,10 +17,10 @@ For developers' testing, models exported by RobustVideoMatting are provided below

| Model | Parameter Size | Accuracy | Note |
|:---------------------------------------------------------------- |:----- |:----- | :------ |
| [rvm_mobilenetv3_fp32.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_fp32.onnx) | 15MB | - | exported from [RobustVideoMatting](https://github.com/PeterL1n/RobustVideoMatting/commit/81a1093), GPL-3.0 License |
| [rvm_resnet50_fp32.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_resnet50_fp32.onnx) | 103MB | - | exported from [RobustVideoMatting](https://github.com/PeterL1n/RobustVideoMatting/commit/81a1093), GPL-3.0 License |
| [rvm_mobilenetv3_trt.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_trt.onnx) | 15MB | - | exported from [RobustVideoMatting](https://github.com/PeterL1n/RobustVideoMatting/commit/81a1093), GPL-3.0 License |
| [rvm_resnet50_trt.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_resnet50_trt.onnx) | 103MB | - | exported from [RobustVideoMatting](https://github.com/PeterL1n/RobustVideoMatting/commit/81a1093), GPL-3.0 License |

**Note**
- If you want to use TensorRT for inference, downloading the onnx model file with the trt suffix is necessary.

View File

@@ -53,6 +53,12 @@ bool Classifier::Initialize() {
  return true;
}

std::unique_ptr<Classifier> Classifier::Clone() const {
  std::unique_ptr<Classifier> clone_model = utils::make_unique<Classifier>(Classifier(*this));
  clone_model->SetRuntime(clone_model->CloneRuntime());
  return clone_model;
}

bool Classifier::Predict(const cv::Mat& img, int32_t* cls_label, float* cls_score) {
  std::vector<int32_t> cls_labels(1);
  std::vector<float> cls_scores(1);

View File

@@ -19,6 +19,7 @@
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h" #include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
#include "fastdeploy/vision/ocr/ppocr/cls_postprocessor.h" #include "fastdeploy/vision/ocr/ppocr/cls_postprocessor.h"
#include "fastdeploy/vision/ocr/ppocr/cls_preprocessor.h" #include "fastdeploy/vision/ocr/ppocr/cls_preprocessor.h"
#include "fastdeploy/utils/unique_ptr.h"
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
@@ -41,6 +42,13 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel {
Classifier(const std::string& model_file, const std::string& params_file = "", Classifier(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(), const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE); const ModelFormat& model_format = ModelFormat::PADDLE);
/** \brief Clone a new Classifier with less memory usage when multiple instances of the same model are created
*
* \return new Classifier* type unique pointer
*/
virtual std::unique_ptr<Classifier> Clone() const;
/// Get model's name /// Get model's name
std::string ModelName() const { return "ppocr/ocr_cls"; } std::string ModelName() const { return "ppocr/ocr_cls"; }
@@ -53,6 +61,7 @@ class FASTDEPLOY_DECL Classifier : public FastDeployModel {
*/ */
virtual bool Predict(const cv::Mat& img, virtual bool Predict(const cv::Mat& img,
int32_t* cls_label, float* cls_score); int32_t* cls_label, float* cls_score);
/** \brief BatchPredict the input image and get OCR classification model cls_result. /** \brief BatchPredict the input image and get OCR classification model cls_result.
* *
* \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. * \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
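For orientation, a minimal, hedged sketch of how this new `Classifier::Clone()` is intended to be used from worker threads. The model paths below are placeholders, and the `Predict` signature follows the declaration above; this is a sketch, not the PR's own example code.

```cpp
// Sketch only: each thread captures its own Clone(); preprocess/postprocess
// state stays thread-private while the cloned runtimes share model weights.
#include <thread>
#include <vector>
#include "fastdeploy/vision.h"

int main() {
  fastdeploy::RuntimeOption option;
  fastdeploy::vision::ocr::Classifier cls(
      "ch_ppocr_mobile_v2.0_cls_infer/inference.pdmodel",    // placeholder path
      "ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams",  // placeholder path
      option);
  auto im = cv::imread("12.jpg");

  std::vector<std::thread> workers;
  for (int i = 0; i < 2; ++i) {
    workers.emplace_back([clone = cls.Clone(), im]() {
      int32_t label = 0;
      float score = 0.0f;
      clone->Predict(im, &label, &score);  // thread-private clone
    });
  }
  for (auto& w : workers) w.join();
  return 0;
}
```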

View File

@@ -53,6 +53,12 @@ bool DBDetector::Initialize() {
  return true;
}

std::unique_ptr<DBDetector> DBDetector::Clone() const {
  std::unique_ptr<DBDetector> clone_model = utils::make_unique<DBDetector>(DBDetector(*this));
  clone_model->SetRuntime(clone_model->CloneRuntime());
  return clone_model;
}

bool DBDetector::Predict(const cv::Mat& img,
                         std::vector<std::array<int, 8>>* boxes_result) {
  std::vector<std::vector<std::array<int, 8>>> det_results;

View File

@@ -19,6 +19,7 @@
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h" #include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
#include "fastdeploy/vision/ocr/ppocr/det_postprocessor.h" #include "fastdeploy/vision/ocr/ppocr/det_postprocessor.h"
#include "fastdeploy/vision/ocr/ppocr/det_preprocessor.h" #include "fastdeploy/vision/ocr/ppocr/det_preprocessor.h"
#include "fastdeploy/utils/unique_ptr.h"
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
@@ -42,8 +43,16 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel {
DBDetector(const std::string& model_file, const std::string& params_file = "", DBDetector(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(), const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE); const ModelFormat& model_format = ModelFormat::PADDLE);
/** \brief Clone a new DBDetector with less memory usage when multiple instances of the same model are created
*
* \return new DBDetector* type unique pointer
*/
virtual std::unique_ptr<DBDetector> Clone() const;
/// Get model's name /// Get model's name
std::string ModelName() const { return "ppocr/ocr_det"; } std::string ModelName() const { return "ppocr/ocr_det"; }
/** \brief Predict the input image and get OCR detection model result. /** \brief Predict the input image and get OCR detection model result.
* *
* \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
@@ -52,6 +61,7 @@ class FASTDEPLOY_DECL DBDetector : public FastDeployModel {
*/ */
virtual bool Predict(const cv::Mat& img, virtual bool Predict(const cv::Mat& img,
std::vector<std::array<int, 8>>* boxes_result); std::vector<std::array<int, 8>>* boxes_result);
/** \brief BatchPredict the input image and get OCR detection model result. /** \brief BatchPredict the input image and get OCR detection model result.
* *
* \param[in] images The list input of image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. * \param[in] images The list input of image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.

View File

@@ -27,6 +27,9 @@ void BindPPOCRv3(pybind11::module& m) {
                    fastdeploy::vision::ocr::Recognizer*>())
      .def_property("cls_batch_size", &pipeline::PPOCRv3::GetClsBatchSize, &pipeline::PPOCRv3::SetClsBatchSize)
      .def_property("rec_batch_size", &pipeline::PPOCRv3::GetRecBatchSize, &pipeline::PPOCRv3::SetRecBatchSize)
      .def("clone", [](pipeline::PPOCRv3& self) {
        return self.Clone();
      })
      .def("predict", [](pipeline::PPOCRv3& self,
                         pybind11::array& data) {
        auto mat = PyArrayToCvMat(data);

@@ -56,6 +59,9 @@ void BindPPOCRv2(pybind11::module& m) {
                    fastdeploy::vision::ocr::Recognizer*>())
      .def_property("cls_batch_size", &pipeline::PPOCRv2::GetClsBatchSize, &pipeline::PPOCRv2::SetClsBatchSize)
      .def_property("rec_batch_size", &pipeline::PPOCRv2::GetRecBatchSize, &pipeline::PPOCRv2::SetRecBatchSize)
      .def("clone", [](pipeline::PPOCRv2& self) {
        return self.Clone();
      })
      .def("predict", [](pipeline::PPOCRv2& self,
                         pybind11::array& data) {
        auto mat = PyArrayToCvMat(data);
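As background on the `.def("clone", ...)` lambdas above: pybind11 automatically takes ownership of a `std::unique_ptr` returned to Python, which is what lets the C++ `Clone()` surface as a plain `.clone()` method. A self-contained toy sketch of that pattern (the `Model` type here is illustrative, not part of this PR):

```cpp
#include <pybind11/pybind11.h>
#include <memory>

struct Model {
  int weights = 42;  // stands in for real model state
  std::unique_ptr<Model> Clone() const { return std::make_unique<Model>(*this); }
};

PYBIND11_MODULE(toy, m) {
  pybind11::class_<Model>(m, "Model")
      .def(pybind11::init<>())
      // Ownership of the returned unique_ptr transfers to the Python object.
      .def("clone", [](Model& self) { return self.Clone(); });
}
```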

View File

@@ -74,6 +74,17 @@ bool PPOCRv2::Initialized() const {
  }
  return true;
}

std::unique_ptr<PPOCRv2> PPOCRv2::Clone() const {
  std::unique_ptr<PPOCRv2> clone_model = utils::make_unique<PPOCRv2>(PPOCRv2(*this));
  clone_model->detector_ = detector_->Clone().release();
  if (classifier_ != nullptr) {
    clone_model->classifier_ = classifier_->Clone().release();
  }
  clone_model->recognizer_ = recognizer_->Clone().release();
  return clone_model;
}

bool PPOCRv2::Predict(cv::Mat* img,
                      fastdeploy::vision::OCRResult* result) {
  return Predict(*img, result);
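A minimal, hedged usage sketch of the pipeline-level `Clone()` implemented above (model paths are placeholders): the clone re-clones its detector, optional classifier, and recognizer, so each copy predicts independently.

```cpp
#include "fastdeploy/vision.h"

int main() {
  fastdeploy::RuntimeOption option;
  fastdeploy::vision::ocr::DBDetector det(
      "det/inference.pdmodel", "det/inference.pdiparams", option);  // placeholder
  fastdeploy::vision::ocr::Recognizer rec(
      "rec/inference.pdmodel", "rec/inference.pdiparams",
      "ppocr_keys_v1.txt", option);                                 // placeholder
  fastdeploy::pipeline::PPOCRv2 ocr(&det, &rec);  // classifier is optional

  auto ocr2 = ocr.Clone();  // std::unique_ptr<PPOCRv2>; submodels re-cloned
  auto im = cv::imread("12.jpg");
  fastdeploy::vision::OCRResult res;
  ocr2->Predict(im, &res);
  return 0;
}
```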

View File

@@ -24,6 +24,7 @@
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h" #include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
#include "fastdeploy/vision/ocr/ppocr/recognizer.h" #include "fastdeploy/vision/ocr/ppocr/recognizer.h"
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h" #include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
#include "fastdeploy/utils/unique_ptr.h"
namespace fastdeploy { namespace fastdeploy {
/** \brief This pipeline can launch detection model, classification model and recognition model sequentially. All OCR pipeline APIs are defined inside this namespace. /** \brief This pipeline can launch detection model, classification model and recognition model sequentially. All OCR pipeline APIs are defined inside this namespace.
@@ -52,6 +53,12 @@ class FASTDEPLOY_DECL PPOCRv2 : public FastDeployModel {
PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model, PPOCRv2(fastdeploy::vision::ocr::DBDetector* det_model,
fastdeploy::vision::ocr::Recognizer* rec_model); fastdeploy::vision::ocr::Recognizer* rec_model);
/** \brief Clone a new PPOCRv2 with less memory usage when multiple instances of the same model are created
*
* \return new PPOCRv2* type unique pointer
*/
std::unique_ptr<PPOCRv2> Clone() const;
/** \brief Predict the input image and get OCR result. /** \brief Predict the input image and get OCR result.
* *
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
@@ -69,6 +76,7 @@ class FASTDEPLOY_DECL PPOCRv2 : public FastDeployModel {
*/ */
virtual bool BatchPredict(const std::vector<cv::Mat>& images, virtual bool BatchPredict(const std::vector<cv::Mat>& images,
std::vector<fastdeploy::vision::OCRResult>* batch_result); std::vector<fastdeploy::vision::OCRResult>* batch_result);
bool Initialized() const override; bool Initialized() const override;
bool SetClsBatchSize(int cls_batch_size); bool SetClsBatchSize(int cls_batch_size);
int GetClsBatchSize(); int GetClsBatchSize();
@@ -83,7 +91,6 @@ class FASTDEPLOY_DECL PPOCRv2 : public FastDeployModel {
private: private:
int cls_batch_size_ = 1; int cls_batch_size_ = 1;
int rec_batch_size_ = 6; int rec_batch_size_ = 6;
/// Launch the detection process in OCR.
}; };
namespace application { namespace application {

View File

@@ -49,6 +49,20 @@ class FASTDEPLOY_DECL PPOCRv3 : public PPOCRv2 {
    // The only difference between v2 and v3
    recognizer_->GetPreprocessor().rec_image_shape_[1] = 48;
  }

  /** \brief Clone a new PPOCRv3 with less memory usage when multiple instances of the same model are created
   *
   * \return a unique_ptr to the new PPOCRv3
   */
  std::unique_ptr<PPOCRv3> Clone() const {
    std::unique_ptr<PPOCRv3> clone_model = utils::make_unique<PPOCRv3>(PPOCRv3(*this));
    clone_model->detector_ = detector_->Clone().release();
    if (classifier_ != nullptr) {
      clone_model->classifier_ = classifier_->Clone().release();
    }
    clone_model->recognizer_ = recognizer_->Clone().release();
    return clone_model;
  }
};

}  // namespace pipeline

View File

@@ -56,6 +56,12 @@ bool Recognizer::Initialize() {
  return true;
}

std::unique_ptr<Recognizer> Recognizer::Clone() const {
  std::unique_ptr<Recognizer> clone_model = utils::make_unique<Recognizer>(Recognizer(*this));
  clone_model->SetRuntime(clone_model->CloneRuntime());
  return clone_model;
}

bool Recognizer::Predict(const cv::Mat& img, std::string* text, float* rec_score) {
  std::vector<std::string> texts(1);
  std::vector<float> rec_scores(1);

View File

@@ -19,6 +19,7 @@
#include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h" #include "fastdeploy/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
#include "fastdeploy/vision/ocr/ppocr/rec_preprocessor.h" #include "fastdeploy/vision/ocr/ppocr/rec_preprocessor.h"
#include "fastdeploy/vision/ocr/ppocr/rec_postprocessor.h" #include "fastdeploy/vision/ocr/ppocr/rec_postprocessor.h"
#include "fastdeploy/utils/unique_ptr.h"
namespace fastdeploy { namespace fastdeploy {
namespace vision { namespace vision {
@@ -43,8 +44,16 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
const std::string& label_path = "", const std::string& label_path = "",
const RuntimeOption& custom_option = RuntimeOption(), const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE); const ModelFormat& model_format = ModelFormat::PADDLE);
/// Get model's name /// Get model's name
std::string ModelName() const { return "ppocr/ocr_rec"; } std::string ModelName() const { return "ppocr/ocr_rec"; }
/** \brief Clone a new Recognizer with less memory usage when multiple instances of the same model are created
*
* \return new Recognizer* type unique pointer
*/
virtual std::unique_ptr<Recognizer> Clone() const;
/** \brief Predict the input image and get OCR recognition model result. /** \brief Predict the input image and get OCR recognition model result.
* *
* \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
@@ -53,6 +62,7 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
* \return true if the prediction is successed, otherwise false. * \return true if the prediction is successed, otherwise false.
*/ */
virtual bool Predict(const cv::Mat& img, std::string* text, float* rec_score); virtual bool Predict(const cv::Mat& img, std::string* text, float* rec_score);
/** \brief BatchPredict the input image and get OCR recognition model result. /** \brief BatchPredict the input image and get OCR recognition model result.
* *
* \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format. * \param[in] images The list of input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format.
@@ -62,6 +72,7 @@ class FASTDEPLOY_DECL Recognizer : public FastDeployModel {
*/ */
virtual bool BatchPredict(const std::vector<cv::Mat>& images, virtual bool BatchPredict(const std::vector<cv::Mat>& images,
std::vector<std::string>* texts, std::vector<float>* rec_scores); std::vector<std::string>* texts, std::vector<float>* rec_scores);
virtual bool BatchPredict(const std::vector<cv::Mat>& images, virtual bool BatchPredict(const std::vector<cv::Mat>& images,
std::vector<std::string>* texts, std::vector<float>* rec_scores, std::vector<std::string>* texts, std::vector<float>* rec_scores,
size_t start_index, size_t end_index, size_t start_index, size_t end_index,

View File

@@ -170,6 +170,17 @@ class DBDetector(FastDeployModel):
        assert self.initialized, "DBDetector initialize failed."
        self._runnable = True

    def clone(self):
        """Clone OCR detection model object
        :return: a new OCR detection model object
        """

        class DBDetectorClone(DBDetector):
            def __init__(self, model):
                self._model = model

        clone_model = DBDetectorClone(self._model.clone())
        return clone_model

    def predict(self, input_image):
        """Predict an input image
        :param input_image: (numpy.ndarray) The input image data, 3-D array with layout HWC, BGR format

@@ -406,6 +417,17 @@ class Classifier(FastDeployModel):
        assert self.initialized, "Classifier initialize failed."
        self._runnable = True

    def clone(self):
        """Clone OCR classification model object
        :return: a new OCR classification model object
        """

        class ClassifierClone(Classifier):
            def __init__(self, model):
                self._model = model

        clone_model = ClassifierClone(self._model.clone())
        return clone_model

    def predict(self, input_image):
        """Predict an input image
        :param input_image: (numpy.ndarray) The input image data, 3-D array with layout HWC, BGR format

@@ -603,6 +625,17 @@ class Recognizer(FastDeployModel):
        assert self.initialized, "Recognizer initialize failed."
        self._runnable = True

    def clone(self):
        """Clone OCR recognition model object
        :return: a new OCR recognition model object
        """

        class RecognizerClone(Recognizer):
            def __init__(self, model):
                self._model = model

        clone_model = RecognizerClone(self._model.clone())
        return clone_model

    def predict(self, input_image):
        """Predict an input image
        :param input_image: (numpy.ndarray) The input image data, 3-D array with layout HWC, BGR format

@@ -700,47 +733,58 @@ class PPOCRv3(FastDeployModel):
        """
        assert det_model is not None and rec_model is not None, "The det_model and rec_model cannot be None."
        if cls_model is None:
            self.system_ = C.vision.ocr.PPOCRv3(det_model._model,
                                                rec_model._model)
        else:
            self.system_ = C.vision.ocr.PPOCRv3(
                det_model._model, cls_model._model, rec_model._model)

    def clone(self):
        """Clone PPOCRv3 pipeline object
        :return: a new PPOCRv3 pipeline object
        """

        class PPOCRv3Clone(PPOCRv3):
            def __init__(self, system):
                self.system_ = system

        clone_model = PPOCRv3Clone(self.system_.clone())
        return clone_model

    def predict(self, input_image):
        """Predict an input image
        :param input_image: (numpy.ndarray) The input image data, 3-D array with layout HWC, BGR format
        :return: OCRResult
        """
        return self.system_.predict(input_image)

    def batch_predict(self, images):
        """Predict a batch of input images
        :param images: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
        :return: OCRBatchResult
        """
        return self.system_.batch_predict(images)

    @property
    def cls_batch_size(self):
        return self.system_.cls_batch_size

    @cls_batch_size.setter
    def cls_batch_size(self, value):
        assert isinstance(
            value,
            int), "The value to set `cls_batch_size` must be type of int."
        self.system_.cls_batch_size = value

    @property
    def rec_batch_size(self):
        return self.system_.rec_batch_size

    @rec_batch_size.setter
    def rec_batch_size(self, value):
        assert isinstance(
            value,
            int), "The value to set `rec_batch_size` must be type of int."
        self.system_.rec_batch_size = value


class PPOCRSystemv3(PPOCRv3):

@@ -764,19 +808,30 @@ class PPOCRv2(FastDeployModel):
        """
        assert det_model is not None and rec_model is not None, "The det_model and rec_model cannot be None."
        if cls_model is None:
            self.system_ = C.vision.ocr.PPOCRv2(det_model._model,
                                                rec_model._model)
        else:
            self.system_ = C.vision.ocr.PPOCRv2(
                det_model._model, cls_model._model, rec_model._model)

    def clone(self):
        """Clone PPOCRv2 pipeline object
        :return: a new PPOCRv2 pipeline object
        """

        class PPOCRv2Clone(PPOCRv2):
            def __init__(self, system):
                self.system_ = system

        clone_model = PPOCRv2Clone(self.system_.clone())
        return clone_model

    def predict(self, input_image):
        """Predict an input image
        :param input_image: (numpy.ndarray) The input image data, 3-D array with layout HWC, BGR format
        :return: OCRResult
        """
        return self.system_.predict(input_image)

    def batch_predict(self, images):
        """Predict a batch of input images

@@ -784,29 +839,29 @@ class PPOCRv2(FastDeployModel):
        :return: OCRBatchResult
        """
        return self.system_.batch_predict(images)

    @property
    def cls_batch_size(self):
        return self.system_.cls_batch_size

    @cls_batch_size.setter
    def cls_batch_size(self, value):
        assert isinstance(
            value,
            int), "The value to set `cls_batch_size` must be type of int."
        self.system_.cls_batch_size = value

    @property
    def rec_batch_size(self):
        return self.system_.rec_batch_size

    @rec_batch_size.setter
    def rec_batch_size(self, value):
        assert isinstance(
            value,
            int), "The value to set `rec_batch_size` must be type of int."
        self.system_.rec_batch_size = value


class PPOCRSystemv2(PPOCRv2):

View File

@@ -1,38 +1,42 @@
English | [中文](README_CN.md)

# Multi-thread or multi-process prediction with FastDeploy models

FastDeploy provides the following multi-thread and multi-process examples for Python and C++ developers
- [Example of Python multi-thread and multi-process prediction](python)
- [Example of C++ multi-thread prediction](cpp)

## Models that currently support multi-thread and multi-process prediction

| Task type | Description | Model download link |
|:-------------- |:---------------- |:------------------- |
| Detection | supports PaddleDetection series models | [PaddleDetection](../../examples/vision/detection/paddledetection) |
| Segmentation | supports PaddleSeg series models | [PaddleSeg](../../examples/vision/segmentation/paddleseg) |
| Classification | supports PaddleClas series models | [PaddleClas](../../examples/vision/classification/paddleclas) |
| OCR | supports PaddleOCR series models | [PaddleOCR](../../examples/vision/ocr/) |

>> **Notice**:
- Click a model download link above and download the model from its `Download pre-trained models` section
- OCR is a multi-model pipeline; for its multi-thread example, refer to the `pipeline` folder. The other single-model multi-thread examples are in the `single_model` folder.

## Cloning the model for multi-thread prediction

Inference with a vision model consists of three stages
- the input image is preprocessed into the Tensor fed to the model Runtime: the preprocess stage
- the model Runtime receives the Tensor, runs inference, and produces the Runtime output Tensor: the infer stage
- the Runtime output Tensor is postprocessed into the final structured information, such as DetectionResult, SegmentationResult, etc.: the postprocess stage

For these three stages, FastDeploy abstracts three corresponding classes: Preprocessor, Runtime, and PostProcessor

When calling FastDeploy models from multiple threads for parallel inference, two questions arise
- Can the Preprocessor, Runtime, and PostProcessor each support parallel processing?
- While supporting multi-thread concurrency, can memory or GPU memory usage be minimized?

FastDeploy performs multi-thread inference by giving each thread its own copies, i.e. every thread holds independent instances of the Preprocessor, Runtime, and PostProcessor. To reduce memory usage, the Runtime copies share the model weights, so although multiple objects are copied, only one copy of the model weights and parameters exists in memory or GPU memory. This keeps the footprint of the extra objects small.

FastDeploy provides the following interfaces to clone a model (taking PaddleClas as an example)
- Python: `PaddleClasModel.clone()`
- C++: `PaddleClasModel::Clone()`
@@ -63,43 +67,43 @@ fastdeploy::vision::ClassifyResult res;
model2->Predict(im, &res);
```

>> **Notice**: For the similar APIs of other models, see the [C++ API documentation](https://www.paddlepaddle.org.cn/fastdeploy-api-doc/cpp/html/index.html) and the [Python API documentation](https://www.paddlepaddle.org.cn/fastdeploy-api-doc/python/html/index.html)
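Putting the clone interface to work, here is a hedged sketch of the threaded pattern it enables, assuming `PaddleClasModel::Clone()` and the `Predict(const cv::Mat&, ClassifyResult*)` overload used elsewhere in this PR; paths are placeholders, and the complete OCR and PaddleClas examples later in this PR follow the same shape.

```cpp
#include <thread>
#include <vector>
#include "fastdeploy/vision.h"

int main() {
  fastdeploy::RuntimeOption option;
  auto model = fastdeploy::vision::classification::PaddleClasModel(
      "ResNet50_vd_infer/inference.pdmodel",    // placeholder path
      "ResNet50_vd_infer/inference.pdiparams",  // placeholder path
      "ResNet50_vd_infer/inference_cls.yaml",   // placeholder path
      option);
  auto im = cv::imread("ILSVRC2012_val_00000010.jpeg");

  std::vector<std::thread> threads;
  for (int i = 0; i < 3; ++i) {
    // One Clone() per thread: weights shared, runtime state private.
    threads.emplace_back([m = model.Clone(), im]() {
      fastdeploy::vision::ClassifyResult res;
      m->Predict(im, &res);
    });
  }
  for (auto& t : threads) t.join();
  return 0;
}
```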
## Python multi-thread and multi-process

Because of the GIL, Python multithreading cannot fully utilize the hardware in compute-intensive scenarios, so both multi-process and multi-thread examples are provided for Python. Their similarities and differences are as follows:

### Comparison of multi-process and multi-thread inference with FastDeploy models

| | Resource usage | Compute-intensive | I/O-intensive | Inter-process or inter-thread communication |
|:-------|:------|:----------|:----------|:----------|
| multi-process | large | fast | fast | slow |
| multi-thread | small | slow | relatively fast | fast |

>> **Notice**: The above analysis is theoretical. In practice Python also optimizes certain workloads, e.g. numpy computations already run in parallel; result aggregation across processes involves costly inter-process communication; and it is often hard to tell whether a task is compute-intensive or I/O-intensive. Everything should therefore be measured on the actual task.

## C++ multi-thread

C++ multithreading combines low resource usage with high speed, so it is the best choice for multi-threaded inference

### C++ memory usage: multi-thread with Clone vs. without Clone

Hardware: Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz
Model: ResNet50_vd_infer
Backend: OPENVINO backend on CPU

Memory usage when initializing multiple models in a single process

| Number of models | after model.Clone() | after model->Predict() with Clone | after model init without Clone | after model->Predict() without Clone |
|:--- |:----- |:----- |:----- |:----- |
|1|322M |325M |322M|325M|
|2|322M|325M|559M|560M|
|3|322M|325M|771M|771M|

Memory usage of multi-threaded model prediction

| Number of threads | after model.Clone() | after model->Predict() with Clone | after model init without Clone | after model->Predict() without Clone |
|:--- |:----- |:----- |:----- |:----- |
|1|322M |337M |322M|337M|
|2|322M|343M|548M|566M|
|3|322M|347M|752M|784M|

View File

@@ -0,0 +1,107 @@
[English](README.md) | 中文

# Multi-thread or multi-process prediction with FastDeploy models

FastDeploy provides the following multi-thread and multi-process examples for Python and C++ developers
- [Example of Python multi-thread and multi-process prediction](python)
- [Example of C++ multi-thread prediction](cpp)

## Models that currently support multi-thread and multi-process prediction

| Task type | Description | Model download link |
|:-------------- |:---------------- |:------------------- |
| Detection | supports PaddleDetection series models | [PaddleDetection](../../examples/vision/detection/paddledetection) |
| Segmentation | supports PaddleSeg series models | [PaddleSeg](../../examples/vision/segmentation/paddleseg) |
| Classification | supports PaddleClas series models | [PaddleClas](../../examples/vision/classification/paddleclas) |
| OCR | supports PaddleOCR series models | [PaddleOCR](../../examples/vision/ocr/) |

>> **Notice**:
- Click a model download link above and download the model from its `Download pre-trained models` section
- OCR is a multi-model pipeline; for its multi-thread example, refer to the `pipeline` folder. The other single-model multi-thread examples are in the `single_model` folder.

## Cloning the model for multi-thread prediction

Inference with a vision model consists of three stages
- the input image is preprocessed into the Tensor fed to the model Runtime: the preprocess stage
- the model Runtime receives the Tensor, runs inference, and produces the Runtime output Tensor: the infer stage
- the Runtime output Tensor is postprocessed into the final structured information, such as DetectionResult, SegmentationResult, etc.: the postprocess stage

For these three stages, FastDeploy abstracts three corresponding classes: Preprocessor, Runtime, and PostProcessor

When calling FastDeploy models from multiple threads for parallel inference, two questions arise
- Can the Preprocessor, Runtime, and PostProcessor each support parallel processing?
- While supporting multi-thread concurrency, can memory or GPU memory usage be minimized?

FastDeploy performs multi-thread inference by giving each thread its own copies, i.e. every thread holds independent instances of the Preprocessor, Runtime, and PostProcessor. To reduce memory usage, the Runtime copies share the model weights, so although multiple objects are copied, only one copy of the model weights and parameters exists in memory or GPU memory. This keeps the footprint of the extra objects small.

FastDeploy provides the following interfaces to clone a model (taking PaddleClas as an example)
- Python: `PaddleClasModel.clone()`
- C++: `PaddleClasModel::Clone()`

### Python
```
import cv2
import fastdeploy as fd

option = fd.RuntimeOption()
model = fd.vision.classification.PaddleClasModel(model_file,
                                                 params_file,
                                                 config_file,
                                                 runtime_option=option)
model2 = model.clone()
im = cv2.imread(image)
res = model.predict(im)
```

### C++
```
auto model = fastdeploy::vision::classification::PaddleClasModel(model_file,
                                                                 params_file,
                                                                 config_file,
                                                                 option);
auto model2 = model.Clone();
auto im = cv::imread(image_file);
fastdeploy::vision::ClassifyResult res;
model2->Predict(im, &res);
```

>> **Notice**: For the similar APIs of other models, see the [C++ API documentation](https://www.paddlepaddle.org.cn/fastdeploy-api-doc/cpp/html/index.html) and the [Python API documentation](https://www.paddlepaddle.org.cn/fastdeploy-api-doc/python/html/index.html)

## Python multi-thread and multi-process

Because of the GIL, Python multithreading cannot fully utilize the hardware in compute-intensive scenarios, so both multi-process and multi-thread examples are provided for Python. Their similarities and differences are as follows:

### Comparison of multi-process and multi-thread inference with FastDeploy models

| | Resource usage | Compute-intensive | I/O-intensive | Inter-process or inter-thread communication |
|:-------|:------|:----------|:----------|:----------|
| multi-process | large | fast | fast | slow |
| multi-thread | small | slow | relatively fast | fast |

>> **Notice**: The above analysis is theoretical. In practice Python also optimizes certain workloads, e.g. numpy computations already run in parallel; result aggregation across processes involves costly inter-process communication; and it is often hard to tell whether a task is compute-intensive or I/O-intensive. Everything should therefore be measured on the actual task.

## C++ multi-thread

C++ multithreading combines low resource usage with high speed, so it is the best choice for multi-threaded inference

### C++ memory usage: multi-thread with Clone vs. without Clone

Hardware: Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz
Model: ResNet50_vd_infer
Backend: OPENVINO backend on CPU

Memory usage when initializing multiple models in a single process

| Number of models | after model.Clone() | after model->Predict() with Clone | after model init without Clone | after model->Predict() without Clone |
|:--- |:----- |:----- |:----- |:----- |
|1|322M |325M |322M|325M|
|2|322M|325M|559M|560M|
|3|322M|325M|771M|771M|

Memory usage of multi-threaded model prediction

| Number of threads | after model.Clone() | after model->Predict() with Clone | after model init without Clone | after model->Predict() without Clone |
|:--- |:----- |:----- |:----- |:----- |
|1|322M |337M |322M|337M|
|2|322M|343M|548M|566M|
|3|322M|347M|752M|784M|

View File

@@ -0,0 +1,14 @@
PROJECT(multi_thread_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
# Path of the downloaded and extracted FastDeploy SDK
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# Add FastDeploy dependency headers
include_directories(${FASTDEPLOY_INCS})
add_executable(multi_thread_demo ${PROJECT_SOURCE_DIR}/multi_thread_ocr.cc)
# Link the FastDeploy libraries
target_link_libraries(multi_thread_demo ${FASTDEPLOY_LIBS} pthread)

View File

@@ -0,0 +1,59 @@
English | [简体中文](README_CN.md)
# PPOCRv3 C++ multi-thread Deployment Example
This directory provides the example file `multi_thread_ocr.cc` to quickly deploy PPOCRv3 with multiple threads on CPU/GPU, and on GPU accelerated by TensorRT.

Two steps before deployment
- 1. Software and hardware should meet the requirements. Please refer to [FastDeploy Environment Requirements](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. Download the precompiled deployment library and samples code according to your development environment. Refer to [FastDeploy Precompiled Library](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
Taking PPOCRv3 inference on Linux as an example, the compilation test can be completed by executing the following commands in this directory. FastDeploy version 0.7.0 or above (x.x.x>=0.7.0) is required to support this model.
```bash
mkdir build
cd build
# Download the FastDeploy precompiled library. Users can choose your appropriate version in the `FastDeploy Precompiled Library` mentioned above
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz
tar xvf fastdeploy-linux-x64-x.x.x.tgz
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x
make -j
# Download model, image, and dictionary files
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
tar -xvf ch_PP-OCRv3_det_infer.tar
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
tar -xvf ch_ppocr_mobile_v2.0_cls_infer.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar
tar -xvf ch_PP-OCRv3_rec_infer.tar
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
# CPU multi-thread inference
./multi_thread_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 0 1
# GPU multi-thread inference
./multi_thread_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 1 1
# TensorRT multi-thread inference on GPU
./multi_thread_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 2 1
# Paddle-TRT multi-thread inference on GPU
./multi_thread_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 3 1
# KunlunXin XPU multi-thread inference
./multi_thread_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 4 1
```

>> **Notice**: the last number in the above command is the thread number
The above commands work on Linux or MacOS. For using the SDK on Windows, refer to:
- [How to use FastDeploy C++ SDK on Windows](../../../docs/cn/faq/use_sdk_on_windows.md)
The result returned after running is as follows
```
Thread Id: 0
det boxes: [[42,413],[483,391],[484,428],[43,450]]rec text: 上海斯格威铂尔大酒店 rec score:0.980085 cls label: 0 cls score: 1.000000
det boxes: [[187,456],[399,448],[400,480],[188,488]]rec text: 打浦路15号 rec score:0.964993 cls label: 0 cls score: 1.000000
det boxes: [[23,507],[513,488],[515,529],[24,548]]rec text: 绿洲仕格维花园公寓 rec score:0.993727 cls label: 0 cls score: 1.000000
det boxes: [[74,553],[427,542],[428,571],[75,582]]rec text: 打浦路252935号 rec score:0.947723 cls label: 0 cls score: 1.000000
```

View File

@@ -0,0 +1,59 @@
[English](README.md) | 中文
# PPOCRv3 C++ multi-thread deployment example

This directory provides `multi_thread_ocr.cc` to quickly deploy PPOCRv3 series models with multiple threads on CPU/GPU, and on GPU accelerated by TensorRT.

Two steps before deployment
- 1. The software and hardware environment meets the requirements; see [FastDeploy Environment Requirements](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. Download the precompiled deployment library and sample code for your development environment; see [FastDeploy Precompiled Libraries](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)

Taking PPOCRv3 inference on Linux as an example, run the following commands in this directory to complete the compilation test. FastDeploy version 0.7.0 or above (x.x.x>=0.7.0) is required to support this model.

```bash
mkdir build
cd build
# Download the FastDeploy precompiled library; choose an appropriate version from the `FastDeploy Precompiled Libraries` mentioned above
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz
tar xvf fastdeploy-linux-x64-x.x.x.tgz
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x
make -j

# Download model, image, and dictionary files
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
tar -xvf ch_PP-OCRv3_det_infer.tar
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
tar -xvf ch_ppocr_mobile_v2.0_cls_infer.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar
tar -xvf ch_PP-OCRv3_rec_infer.tar
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt

# CPU inference
./multi_thread_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 0 1
# GPU inference
./multi_thread_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 1 1
# TensorRT inference on GPU
./multi_thread_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 2 1
# Paddle-TRT inference on GPU
./multi_thread_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 3 1
# KunlunXin XPU inference
./multi_thread_demo ./ch_PP-OCRv3_det_infer ./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer ./ppocr_keys_v1.txt ./12.jpg 4 1
```

>> **Notice**: the last number in the above command is the thread number

The above commands only work on Linux or MacOS. For using the SDK on Windows, refer to:
- [How to use FastDeploy C++ SDK on Windows](../../../docs/cn/faq/use_sdk_on_windows.md)

The result returned after running is as follows
```
Thread Id: 0
det boxes: [[42,413],[483,391],[484,428],[43,450]]rec text: 上海斯格威铂尔大酒店 rec score:0.980085 cls label: 0 cls score: 1.000000
det boxes: [[187,456],[399,448],[400,480],[188,488]]rec text: 打浦路15号 rec score:0.964993 cls label: 0 cls score: 1.000000
det boxes: [[23,507],[513,488],[515,529],[24,548]]rec text: 绿洲仕格维花园公寓 rec score:0.993727 cls label: 0 cls score: 1.000000
det boxes: [[74,553],[427,542],[428,571],[75,582]]rec text: 打浦路252935号 rec score:0.947723 cls label: 0 cls score: 1.000000
```

View File

@@ -0,0 +1,177 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thread>
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void Predict(fastdeploy::pipeline::PPOCRv3 *model, int thread_id, const std::vector<std::string>& images) {
for (auto const &image_file : images) {
auto im = cv::imread(image_file);
fastdeploy::vision::OCRResult res;
if (!model->Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
// print res
std::cout << "Thread Id: " << thread_id << std::endl;
std::cout << res.Str() << std::endl;
}
}
void GetImageList(std::vector<std::vector<std::string>>* image_list, const std::string& image_file_path, int thread_num){
std::vector<cv::String> images;
cv::glob(image_file_path, images, false);
// number of image files in images folder
size_t count = images.size();
size_t num = count / thread_num;
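// Each thread gets `num` images; the last thread also takes the remainder.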
for (int i = 0; i < thread_num; i++) {
std::vector<std::string> temp_list;
if (i == thread_num - 1) {
for (size_t j = i*num; j < count; j++){
temp_list.push_back(images[j]);
}
} else {
for (size_t j = 0; j < num; j++){
temp_list.push_back(images[i * num + j]);
}
}
(*image_list)[i] = temp_list;
}
}
void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model_dir, const std::string& rec_model_dir, const std::string& rec_label_file, const std::string& image_file_path, const fastdeploy::RuntimeOption& option, int thread_num) {
auto det_model_file = det_model_dir + sep + "inference.pdmodel";
auto det_params_file = det_model_dir + sep + "inference.pdiparams";
auto cls_model_file = cls_model_dir + sep + "inference.pdmodel";
auto cls_params_file = cls_model_dir + sep + "inference.pdiparams";
auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";
auto det_option = option;
auto cls_option = option;
auto rec_option = option;
// The cls and rec model can inference a batch of images now.
// User could initialize the inference batch size and set them after create PP-OCR model.
int cls_batch_size = 1;
int rec_batch_size = 6;
// If use TRT backend, the dynamic shape will be set as follow.
// We recommend that users set the length and height of the detection model to a multiple of 32.
// We also recommend that users set the Trt input shape as follow.
det_option.SetTrtInputShape("x", {1, 3, 64,64}, {1, 3, 640, 640},
{1, 3, 960, 960});
cls_option.SetTrtInputShape("x", {1, 3, 48, 10}, {cls_batch_size, 3, 48, 320}, {cls_batch_size, 3, 48, 1024});
rec_option.SetTrtInputShape("x", {1, 3, 48, 10}, {rec_batch_size, 3, 48, 320},
{rec_batch_size, 3, 48, 2304});
// Users could save TRT cache file to disk as follow.
// det_option.SetTrtCacheFile(det_model_dir + sep + "det_trt_cache.trt");
// cls_option.SetTrtCacheFile(cls_model_dir + sep + "cls_trt_cache.trt");
// rec_option.SetTrtCacheFile(rec_model_dir + sep + "rec_trt_cache.trt");
auto det_model = fastdeploy::vision::ocr::DBDetector(det_model_file, det_params_file, det_option);
auto cls_model = fastdeploy::vision::ocr::Classifier(cls_model_file, cls_params_file, cls_option);
auto rec_model = fastdeploy::vision::ocr::Recognizer(rec_model_file, rec_params_file, rec_label_file, rec_option);
assert(det_model.Initialized());
assert(cls_model.Initialized());
assert(rec_model.Initialized());
// The classification model is optional, so the PP-OCR can also be connected in series as follows
// auto ppocr_v3 = fastdeploy::pipeline::PPOCRv3(&det_model, &rec_model);
auto ppocr_v3 = fastdeploy::pipeline::PPOCRv3(&det_model, &cls_model, &rec_model);
// Set inference batch size for cls model and rec model, the value could be -1 and 1 to positive infinity.
// When inference batch size is set to -1, it means that the inference batch size
// of the cls and rec models will be the same as the number of boxes detected by the det model.
ppocr_v3.SetClsBatchSize(cls_batch_size);
ppocr_v3.SetRecBatchSize(rec_batch_size);
if(!ppocr_v3.Initialized()){
std::cerr << "Failed to initialize PP-OCR." << std::endl;
return;
}
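// Note: decltype(ppocr_v3.Clone()) is std::unique_ptr<fastdeploy::pipeline::PPOCRv3>;
// one clone per worker thread keeps runtime state private while sharing weights.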
std::vector<decltype(ppocr_v3.Clone())> models;
for (int i = 0; i < thread_num; ++i) {
models.emplace_back(ppocr_v3.Clone());
}
std::vector<std::vector<std::string>> image_list(thread_num);
GetImageList(&image_list, image_file_path, thread_num);
std::vector<std::thread> threads;
for (int i = 0; i < thread_num; ++i) {
threads.emplace_back(Predict, models[i].get(), i, image_list[i]);
}
for (int i = 0; i < thread_num; ++i) {
threads[i].join();
}
}
int main(int argc, char* argv[]) {
if (argc < 8) {
std::cout << "Usage: infer_demo path/to/det_model path/to/cls_model "
"path/to/rec_model path/to/rec_label_file path/to/image "
"run_option thread_num,"
"e.g ./infer_demo ./ch_PP-OCRv3_det_infer "
"./ch_ppocr_mobile_v2.0_cls_infer ./ch_PP-OCRv3_rec_infer "
"./ppocr_keys_v1.txt ./12.jpg 0 3"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu; 2: run with gpu and use tensorrt backend; 3: run with gpu and use Paddle-TRT; 4: run with kunlunxin."
<< std::endl;
return -1;
}
fastdeploy::RuntimeOption option;
int flag = std::atoi(argv[6]);
if (flag == 0) {
option.UseCpu();
} else if (flag == 1) {
option.UseGpu();
} else if (flag == 2) {
option.UseGpu();
option.UseTrtBackend();
} else if (flag == 3) {
option.UseGpu();
option.UseTrtBackend();
option.EnablePaddleTrtCollectShape();
option.EnablePaddleToTrt();
} else if (flag == 4) {
option.UseKunlunXin();
}
std::string det_model_dir = argv[1];
std::string cls_model_dir = argv[2];
std::string rec_model_dir = argv[3];
std::string rec_label_file = argv[4];
std::string image_file_path = argv[5];
int thread_num = std::atoi(argv[7]);
InitAndInfer(det_model_dir, cls_model_dir, rec_model_dir, rec_label_file, image_file_path, option, thread_num);
return 0;
}

View File

@@ -0,0 +1,48 @@
English | [中文](README_CN.md)
# PaddleClas C++ Multi-thread Deployment Example
This directory provides the example file `multi_thread.cc` to quickly deploy PaddleClas models with multiple threads on CPU/GPU, and on GPU accelerated by TensorRT.
Before deployment, two steps require confirmation.
- 1. Software and hardware should meet the requirements. Please refer to [FastDeploy Environment Requirements](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. Install the FastDeploy Python whl package. Please refer to [FastDeploy Python Installation](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
Taking ResNet50_vd inference on Linux as an example, the compilation test can be completed by executing the following command in this directory. FastDeploy version 0.7.0 or above (x.x.x>=0.7.0) is required to support this model.
```bash
mkdir build
cd build
# Download the FastDeploy precompiled library; users can choose an appropriate version from the `FastDeploy Precompiled Libraries` mentioned above
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz
tar xvf fastdeploy-linux-x64-x.x.x.tgz
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x
make -j
# Download the ResNet50_vd model file and test images
wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz
tar -xvf ResNet50_vd_infer.tgz
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
# CPU multi-thread inference
./multi_thread_demo ResNet50_vd_infer ILSVRC2012_val_00000010.jpeg 0 1
# GPU multi-thread inference
./multi_thread_demo ResNet50_vd_infer ILSVRC2012_val_00000010.jpeg 1 1
# TensorRT multi-thread inference on GPU
./multi_thread_demo ResNet50_vd_infer ILSVRC2012_val_00000010.jpeg 2 1
```
>> **Notice**: the last number in the above command is the thread number
The above commands work on Linux or MacOS. For using the SDK on Windows, refer to:
- [How to use FastDeploy C++ SDK on Windows](../../../docs/cn/faq/use_sdk_on_windows.md)
The result returned after running is as follows
```
Thread Id: 0
ClassifyResult(
label_ids: 153,
scores: 0.686229,
)
```

View File

@@ -1,11 +1,13 @@
[English](README.md) | 中文

# PaddleClas C++ multi-thread deployment example

This directory provides `multi_thread.cc` to quickly deploy PaddleClas series models with multiple threads on CPU/GPU, and on GPU accelerated by TensorRT.

Two steps before deployment
- 1. The software and hardware environment meets the requirements; see [FastDeploy Environment Requirements](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. Download the precompiled deployment library and sample code for your development environment; see [FastDeploy Precompiled Libraries](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)

Taking ResNet50_vd inference on Linux as an example, run the following commands in this directory to complete the compilation test. FastDeploy version 0.7.0 or above (x.x.x>=0.7.0) is required to support this model.

@@ -25,13 +27,22 @@ wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg

# CPU multi-thread inference
./multi_thread_demo ResNet50_vd_infer ILSVRC2012_val_00000010.jpeg 0 1
# GPU multi-thread inference
./multi_thread_demo ResNet50_vd_infer ILSVRC2012_val_00000010.jpeg 1 1
# TensorRT multi-thread inference on GPU
./multi_thread_demo ResNet50_vd_infer ILSVRC2012_val_00000010.jpeg 2 1
```

>> **Notice**: the last number in the above command is the thread number

The above commands only work on Linux or MacOS. For using the SDK on Windows, refer to:
- [How to use FastDeploy C++ SDK on Windows](../../../docs/cn/faq/use_sdk_on_windows.md)

The result returned after running is as follows
```
Thread Id: 0
ClassifyResult(
label_ids: 153,
scores: 0.686229,
)
```

View File

@@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thread>
#include "fastdeploy/vision.h"

#ifdef WIN32

View File

@@ -0,0 +1,60 @@
English | [简体中文](README_CN.md)
# PPOCRv3 Python multi-thread/multi-process Deployment Example
Two steps before deployment
- 1. Software and hardware should meet the requirements. Please refer to [FastDeploy Environment Requirements](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. Install FastDeploy Python whl package. Refer to [FastDeploy Python Installation](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
This directory provides the example file `multi_thread_process_ocr.py` to quickly deploy PPOCRv3 with multiple threads or processes on CPU/GPU, and on GPU accelerated by TensorRT. Run it as follows
```bash
# Download deployment example code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/tutorials/multi_thread/python/pipeline
# Download model, image, and dictionary files
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
tar xvf ch_PP-OCRv3_det_infer.tar
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
tar -xvf ch_ppocr_mobile_v2.0_cls_infer.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar
tar xvf ch_PP-OCRv3_rec_infer.tar
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
# CPU multi-thread inference
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device cpu --thread_num 1
# CPU multi-process inference
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device cpu --use_multi_process True --process_num 1
# GPU multi-thread inference
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device gpu --thread_num 1
# GPU multi-process inference
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device gpu --use_multi_process True --process_num 1
# TensorRT multi-thread inference on GPU (Note: the first TensorRT run serializes the model, which takes some time; please be patient)
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device gpu --backend trt --thread_num 1
# TensorRT multi-process inference on GPU (Note: the first TensorRT run serializes the model, which takes some time; please be patient)
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device gpu --backend trt --use_multi_process True --process_num 1
# KunlunXin XPU multi-thread inference
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device kunlunxin --thread_num 1
# KunlunXin XPU multi-process inference
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device kunlunxin --use_multi_process True --process_num 1
```
>> **Notice**: `--image_path` can also be the path of a directory of images
The result returned after running is as follows
```
thread: 0 , result: det boxes: [[42,413],[483,391],[484,428],[43,450]]rec text: 上海斯格威铂尔大酒店 rec score:0.949773 cls label: 0 cls score: 1.000000
det boxes: [[187,456],[399,448],[400,480],[188,488]]rec text: 打浦路15号 rec score:0.910265 cls label: 0 cls score: 1.000000
det boxes: [[23,507],[513,488],[515,529],[24,548]]rec text: 绿洲仕格维花园公寓 rec score:0.934239 cls label: 0 cls score: 1.000000
det boxes: [[74,553],[427,542],[428,571],[75,582]]rec text: 打浦路252935号 rec score:0.872207 cls label: 0 cls score: 1.000000
```
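The core pattern the example script relies on is the model clone function this commit adds: the PP-OCRv3 pipeline is built once, then cloned once per thread so every thread predicts on its own copy. Below is a minimal sketch of that pattern, assuming the model directories and image downloaded above; it is an illustration, not part of the repository.
```python
# Minimal sketch: one PP-OCRv3 pipeline, one clone per thread.
from threading import Thread

import cv2
import fastdeploy as fd

det = fd.vision.ocr.DBDetector(
    "ch_PP-OCRv3_det_infer/inference.pdmodel",
    "ch_PP-OCRv3_det_infer/inference.pdiparams")
cls = fd.vision.ocr.Classifier(
    "ch_ppocr_mobile_v2.0_cls_infer/inference.pdmodel",
    "ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams")
rec = fd.vision.ocr.Recognizer(
    "ch_PP-OCRv3_rec_infer/inference.pdmodel",
    "ch_PP-OCRv3_rec_infer/inference.pdiparams",
    "ppocr_keys_v1.txt")
pipeline = fd.vision.ocr.PPOCRv3(det_model=det, cls_model=cls, rec_model=rec)


def worker(model, image_path):
    # Each thread predicts on its own clone of the pipeline.
    print(model.predict(cv2.imread(image_path)))


threads = [Thread(target=worker, args=(pipeline.clone(), "12.jpg"))
           for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```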

View File

@@ -0,0 +1,60 @@
[English](README.md) | Simplified Chinese
# PPOCR Model Python Multi-thread/Multi-process Deployment Example
Confirm the following two steps before deployment
- 1. The software and hardware environment meets the requirements. Refer to [FastDeploy Environment Requirements](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. Install the FastDeploy Python whl package. Refer to [FastDeploy Python Installation](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
This directory provides `multi_thread_process_ocr.py` to quickly finish multi-thread/multi-process deployment of PPOCRv3 on CPU/GPU, or on GPU accelerated by TensorRT. Run the following script to finish
```bash
# Download the deployment example code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/tutorials/multi_thread/python/pipeline
# Download the model, image, and dictionary files
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
tar xvf ch_PP-OCRv3_det_infer.tar
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
tar -xvf ch_ppocr_mobile_v2.0_cls_infer.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar
tar xvf ch_PP-OCRv3_rec_infer.tar
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
# CPU multi-thread inference
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device cpu --thread_num 1
# CPU multi-process inference
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device cpu --use_multi_process True --process_num 1
# GPU multi-thread inference
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device gpu --thread_num 1
# GPU multi-process inference
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device gpu --use_multi_process True --process_num 1
# TensorRT multi-thread inference on GPU (Note: the first TensorRT run serializes the model, which takes some time; please be patient)
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device gpu --backend trt --thread_num 1
# TensorRT multi-process inference on GPU (Note: the first TensorRT run serializes the model, which takes some time; please be patient)
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device gpu --backend trt --use_multi_process True --process_num 1
# KunlunXin XPU multi-thread inference
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device kunlunxin --thread_num 1
# KunlunXin XPU multi-process inference
python multi_thread_process_ocr.py --det_model ch_PP-OCRv3_det_infer --cls_model ch_ppocr_mobile_v2.0_cls_infer --rec_model ch_PP-OCRv3_rec_infer --rec_label_file ppocr_keys_v1.txt --image_path 12.jpg --device kunlunxin --use_multi_process True --process_num 1
```
>> **Notice**: `--image_path` can also be the path of a directory of images
The result returned after running is as follows
```
thread: 0 , result: det boxes: [[42,413],[483,391],[484,428],[43,450]]rec text: 上海斯格威铂尔大酒店 rec score:0.949773 cls label: 0 cls score: 1.000000
det boxes: [[187,456],[399,448],[400,480],[188,488]]rec text: 打浦路15号 rec score:0.910265 cls label: 0 cls score: 1.000000
det boxes: [[23,507],[513,488],[515,529],[24,548]]rec text: 绿洲仕格维花园公寓 rec score:0.934239 cls label: 0 cls score: 1.000000
det boxes: [[74,553],[427,542],[428,571],[75,582]]rec text: 打浦路252935号 rec score:0.872207 cls label: 0 cls score: 1.000000
```

View File

@@ -0,0 +1,279 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from threading import Thread
import fastdeploy as fd
import cv2
import os
from multiprocessing import Pool


def parse_arguments():
    import argparse
    import ast
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--det_model", required=True, help="Path of Detection model of PPOCR.")
    parser.add_argument(
        "--cls_model",
        required=True,
        help="Path of Classification model of PPOCR.")
    parser.add_argument(
        "--rec_model",
        required=True,
        help="Path of Recognition model of PPOCR.")
    parser.add_argument(
        "--rec_label_file",
        required=True,
        help="Path of Recognition model label file of PPOCR.")
    parser.add_argument(
        "--image_path",
        type=str,
        required=True,
        help="The directory or path or file list of the images to be predicted."
    )
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Type of inference device, support 'cpu', 'kunlunxin' or 'gpu'.")
    parser.add_argument(
        "--backend",
        type=str,
        default="default",
        help="Type of inference backend, support ort/trt/paddle/openvino, default 'openvino' for cpu, 'tensorrt' for gpu"
    )
    parser.add_argument(
        "--device_id",
        type=int,
        default=0,
        help="Define which GPU card used to run model.")
    parser.add_argument(
        "--cpu_thread_num",
        type=int,
        default=9,
        help="Number of threads while inference on CPU.")
    parser.add_argument(
        "--cls_bs",
        type=int,
        default=1,
        help="Classification model inference batch size.")
    parser.add_argument(
        "--rec_bs",
        type=int,
        default=6,
        help="Recognition model inference batch size.")
    parser.add_argument("--thread_num", type=int, default=1, help="thread num")
    parser.add_argument(
        "--use_multi_process",
        type=ast.literal_eval,
        default=False,
        help="Whether to use multi process.")
    parser.add_argument(
        "--process_num", type=int, default=1, help="process num")
    return parser.parse_args()


def get_image_list(image_path):
    image_list = []
    if os.path.isfile(image_path):
        image_list.append(image_path)
    # load images in a directory
    elif os.path.isdir(image_path):
        for root, dirs, files in os.walk(image_path):
            for f in files:
                image_list.append(os.path.join(root, f))
    else:
        raise FileNotFoundError(
            '{} is not found. It should be a path of an image, or a directory including images.'.
            format(image_path))

    if len(image_list) == 0:
        raise RuntimeError(
            'There are no image files in `--image_path`={}'.format(image_path))

    return image_list


def build_option(args):
    option = fd.RuntimeOption()
    if args.device.lower() == "gpu":
        option.use_gpu(args.device_id)

    option.set_cpu_thread_num(args.cpu_thread_num)

    if args.device.lower() == "kunlunxin":
        option.use_kunlunxin()
        return option

    if args.backend.lower() == "trt":
        assert args.device.lower(
        ) == "gpu", "TensorRT backend require inference on device GPU."
        option.use_trt_backend()
    elif args.backend.lower() == "pptrt":
        assert args.device.lower(
        ) == "gpu", "Paddle-TensorRT backend require inference on device GPU."
        option.use_trt_backend()
        option.enable_paddle_trt_collect_shape()
        option.enable_paddle_to_trt()
    elif args.backend.lower() == "ort":
        option.use_ort_backend()
    elif args.backend.lower() == "paddle":
        option.use_paddle_infer_backend()
    elif args.backend.lower() == "openvino":
        assert args.device.lower(
        ) == "cpu", "OpenVINO backend require inference on device CPU."
        option.use_openvino_backend()
    return option


def load_model(args, runtime_option):
    # Detection model: detects text boxes
    det_model_file = os.path.join(args.det_model, "inference.pdmodel")
    det_params_file = os.path.join(args.det_model, "inference.pdiparams")
    # Classification model: text direction classification (optional)
    cls_model_file = os.path.join(args.cls_model, "inference.pdmodel")
    cls_params_file = os.path.join(args.cls_model, "inference.pdiparams")
    # Recognition model: recognizes the text
    rec_model_file = os.path.join(args.rec_model, "inference.pdmodel")
    rec_params_file = os.path.join(args.rec_model, "inference.pdiparams")
    rec_label_file = args.rec_label_file

    # The cls and rec models of PPOCR now support batch inference.
    # The two variables below are used to set the TRT input shapes, and to
    # configure batch inference after the PPOCR pipeline is initialized.
    cls_batch_size = 1
    rec_batch_size = 6

    # When TRT is used, set a dynamic shape for each of the three models'
    # runtimes, then create the models.
    # Note: create the detection model first, then set the dynamic input of the
    # classification model and create it; likewise for the recognition model.
    # If you want to change the input shape of the detection model yourself, we
    # recommend setting its height and width to multiples of 32.
    det_option = runtime_option
    det_option.set_trt_input_shape("x", [1, 3, 64, 64], [1, 3, 640, 640],
                                   [1, 3, 960, 960])
    # The TRT engine file can be saved to local disk
    #det_option.set_trt_cache_file(args.det_model + "/det_trt_cache.trt")
    global det_model
    det_model = fd.vision.ocr.DBDetector(
        det_model_file, det_params_file, runtime_option=det_option)

    cls_option = runtime_option
    cls_option.set_trt_input_shape("x", [1, 3, 48, 10],
                                   [cls_batch_size, 3, 48, 320],
                                   [cls_batch_size, 3, 48, 1024])
    # The TRT engine file can be saved to local disk
    #cls_option.set_trt_cache_file(args.cls_model + "/cls_trt_cache.trt")
    global cls_model
    cls_model = fd.vision.ocr.Classifier(
        cls_model_file, cls_params_file, runtime_option=cls_option)

    rec_option = runtime_option
    rec_option.set_trt_input_shape("x", [1, 3, 48, 10],
                                   [rec_batch_size, 3, 48, 320],
                                   [rec_batch_size, 3, 48, 2304])
    # The TRT engine file can be saved to local disk
    #rec_option.set_trt_cache_file(args.rec_model + "/rec_trt_cache.trt")
    global rec_model
    rec_model = fd.vision.ocr.Recognizer(
        rec_model_file,
        rec_params_file,
        rec_label_file,
        runtime_option=rec_option)

    # Create PP-OCR, chaining the 3 models; cls_model is optional and can be
    # set to None if not needed
    global ppocr_v3
    ppocr_v3 = fd.vision.ocr.PPOCRv3(
        det_model=det_model, cls_model=cls_model, rec_model=rec_model)

    # Set the inference batch size of the cls and rec models.
    # The value can be -1 or any positive integer.
    # When it is -1, the batch size of the cls and rec models defaults to the
    # number of boxes detected by the det model.
    ppocr_v3.cls_batch_size = cls_batch_size
    ppocr_v3.rec_batch_size = rec_batch_size


def predict(model, img_list):
    result_list = []
    # predict ppocr result
    for image in img_list:
        im = cv2.imread(image)
        result = model.predict(im)
        result_list.append(result)
    return result_list


def process_predict(image):
    # predict ppocr result
    im = cv2.imread(image)
    result = ppocr_v3.predict(im)
    print(result)


class WrapperThread(Thread):
    def __init__(self, func, args):
        super(WrapperThread, self).__init__()
        self.func = func
        self.args = args

    def run(self):
        self.result = self.func(*self.args)

    def get_result(self):
        return self.result


if __name__ == '__main__':
    args = parse_arguments()

    imgs_list = get_image_list(args.image_path)

    # All three models use the same deployment configuration;
    # they can also be configured individually as needed
    runtime_option = build_option(args)

    if args.use_multi_process:
        process_num = args.process_num
        with Pool(
                process_num,
                initializer=load_model,
                initargs=(args, runtime_option)) as pool:
            pool.map(process_predict, imgs_list)
    else:
        load_model(args, runtime_option)
        threads = []
        thread_num = args.thread_num
        image_num_each_thread = int(len(imgs_list) / thread_num)
        # Unless you want an independent model in each thread, model.clone()
        # behaves the same as passing the model itself when creating the
        # thread, because of Python's GIL (Global Interpreter Lock). In
        # addition, model.clone() consumes additional memory to store
        # independent member variables.
        for i in range(thread_num):
            if i == thread_num - 1:
                t = WrapperThread(
                    predict,
                    args=(ppocr_v3.clone(),
                          imgs_list[i * image_num_each_thread:]))
            else:
                t = WrapperThread(
                    predict,
                    args=(ppocr_v3.clone(),
                          imgs_list[i * image_num_each_thread:(i + 1) *
                                    image_num_each_thread]))
            threads.append(t)
            t.start()

        for i in range(thread_num):
            threads[i].join()

        for i in range(thread_num):
            for result in threads[i].get_result():
                print('thread:', i, ', result: ', result)
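One detail of the threaded branch above deserves a note: each of the first `thread_num - 1` threads receives exactly `image_num_each_thread` images, and the last thread takes the remainder, so no image is dropped when `len(imgs_list)` is not evenly divisible by `thread_num`. A quick stand-alone check of that slicing (hypothetical file names):
```python
# Stand-alone check of the per-thread slicing used above (hypothetical names).
imgs_list = ["img_%d.jpg" % i for i in range(7)]
thread_num = 3
n = int(len(imgs_list) / thread_num)  # 2 images for each non-final thread
chunks = [imgs_list[i * n:(i + 1) * n] for i in range(thread_num - 1)]
chunks.append(imgs_list[(thread_num - 1) * n:])  # last thread takes the rest
assert sum(len(c) for c in chunks) == len(imgs_list)
print(chunks)
# [['img_0.jpg', 'img_1.jpg'], ['img_2.jpg', 'img_3.jpg'],
#  ['img_4.jpg', 'img_5.jpg', 'img_6.jpg']]
```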

View File

@@ -0,0 +1,51 @@
English | [简体中文](README_CN.md)
# PaddleClas Models Python Multi-thread/Multi-process Deployment Example
Confirm the following two steps before deployment
- 1. Software and hardware should meet the requirements. Please refer to [FastDeploy Environment Requirements](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. Install the FastDeploy Python whl package. Please refer to [FastDeploy Python Installation](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
This directory provides example file `multi_thread_process.py` to quickly finish multi-thread/multi-process deployment of ResNet50_vd on CPU/GPU, or on GPU accelerated by TensorRT. The script is as follows
```bash
# Download deployment example code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/tutorials/multi_thread/python
# Download the ResNet50_vd model file and test images
wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz
tar -xvf ResNet50_vd_infer.tgz
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
# CPU multi-thread inference
python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device cpu --topk 1 --thread_num 1
# CPU multi-process inference
python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device cpu --topk 1 --use_multi_process True --process_num 1
# GPU multi-thread inference
python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device gpu --topk 1 --thread_num 1
# GPU multi-process inference
python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device gpu --topk 1 --use_multi_process True --process_num 1
# TensorRT multi-thread inference on GPU (Note: the first TensorRT run serializes the model, which takes some time; please be patient)
python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True --topk 1 --thread_num 1
# TensorRT multi-process inference on GPU (Note: the first TensorRT run serializes the model, which takes some time; please be patient)
python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True --topk 1 --use_multi_process True --process_num 1
# IPU multi-thread inference (Note: the first IPU run serializes the model, which takes some time; please be patient)
python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device ipu --topk 1 --thread_num 1
# IPU multi-process inference (Note: the first IPU run serializes the model, which takes some time; please be patient)
python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device ipu --topk 1 --use_multi_process True --process_num 1
```
>> **Notice**: `--image_path` can also be the path of a directory of images
The result returned after running is as follows
```
ClassifyResult(
label_ids: 153,
scores: 0.686229,
)
```

View File

@@ -1,9 +1,10 @@
+[English](README.md) | Simplified Chinese
 # PaddleClas Model Python Multi-thread/Multi-process Deployment Example
 Confirm the following two steps before deployment
-- 1. The software and hardware environment meets the requirements. Refer to [FastDeploy Environment Requirements](../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+- 1. The software and hardware environment meets the requirements. Refer to [FastDeploy Environment Requirements](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
-- 2. Install the FastDeploy Python whl package. Refer to [FastDeploy Python Installation](../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+- 2. Install the FastDeploy Python whl package. Refer to [FastDeploy Python Installation](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
 This directory provides `multi_thread_process.py` to quickly finish multi-thread/multi-process deployment of ResNet50_vd on CPU/GPU, or on GPU accelerated by TensorRT. Run the following script to finish
@@ -20,24 +21,24 @@ wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/Ima
 # CPU multi-thread inference
-python infer.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device cpu --topk 1 --thread_num 1
+python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device cpu --topk 1 --thread_num 1
 # CPU multi-process inference
-python infer.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device cpu --topk 1 --use_multi_process True --process_num 1
+python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device cpu --topk 1 --use_multi_process True --process_num 1
 # GPU multi-thread inference
-python infer.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device gpu --topk 1 --thread_num 1
+python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device gpu --topk 1 --thread_num 1
 # GPU multi-process inference
-python infer.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device gpu --topk 1 --use_multi_process True --process_num 1
+python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device gpu --topk 1 --use_multi_process True --process_num 1
 # TensorRT multi-thread inference on GPU (Note: the first TensorRT run serializes the model, which takes some time; please be patient)
-python infer.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True --topk 1 --thread_num 1
+python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True --topk 1 --thread_num 1
 # TensorRT multi-process inference on GPU (Note: the first TensorRT run serializes the model, which takes some time; please be patient)
-python infer.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True --topk 1 --use_multi_process True --process_num 1
+python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True --topk 1 --use_multi_process True --process_num 1
 # IPU multi-thread inference (Note: the first IPU run serializes the model, which takes some time; please be patient)
-python infer.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device ipu --topk 1 --thread_num 1
+python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device ipu --topk 1 --thread_num 1
 # IPU multi-process inference (Note: the first IPU run serializes the model, which takes some time; please be patient)
-python infer.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device ipu --topk 1 --use_multi_process True --process_num 1
+python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device ipu --topk 1 --use_multi_process True --process_num 1
 ```
 >> **Notice**: `--image_path` can also be the path of a directory of images

View File

@@ -1,4 +1,17 @@
-import numpy as np
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from threading import Thread
 import fastdeploy as fd
 import cv2
@@ -67,6 +80,7 @@ def build_option(args):
     option = fd.RuntimeOption()
     if args.device.lower() == "gpu":
+        option.use_paddle_backend()
         option.use_gpu()
     if args.device.lower() == "ipu":
@@ -77,6 +91,16 @@ def build_option(args):
     return option
+def load_model(args, runtime_option):
+    model_file = os.path.join(args.model, "inference.pdmodel")
+    params_file = os.path.join(args.model, "inference.pdiparams")
+    config_file = os.path.join(args.model, "inference_cls.yaml")
+    global model
+    model = fd.vision.classification.PaddleClasModel(
+        model_file, params_file, config_file, runtime_option=runtime_option)
+    #return model
 def predict(model, img_list, topk):
     result_list = []
     # predict classification result
@@ -91,7 +115,7 @@ def process_predict(image):
     # predict classification result
     im = cv2.imread(image)
     result = model.predict(im, args.topk)
-    return result
+    print(result)
 class WrapperThread(Thread):
@@ -114,19 +138,15 @@ if __name__ == '__main__':
     # configure runtime and load model
     runtime_option = build_option(args)
-    model_file = os.path.join(args.model, "inference.pdmodel")
-    params_file = os.path.join(args.model, "inference.pdiparams")
-    config_file = os.path.join(args.model, "inference_cls.yaml")
-    model = fd.vision.classification.PaddleClasModel(
-        model_file, params_file, config_file, runtime_option=runtime_option)
     if args.use_multi_process:
-        results = []
         process_num = args.process_num
-        with Pool(process_num) as pool:
-            results = pool.map(process_predict, imgs_list)
-        for result in results:
-            print(result)
+        with Pool(
+                process_num,
+                initializer=load_model,
+                initargs=(args, runtime_option)) as pool:
+            pool.map(process_predict, imgs_list)
     else:
+        load_model(args, runtime_option)
         threads = []
         thread_num = args.thread_num
         image_num_each_thread = int(len(imgs_list) / thread_num)
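A closing design note on the threaded path: as the comment in the OCR example above points out, Python's GIL means a per-thread `clone()` does not add parallelism over sharing the single model object; what it buys is independent member variables per thread, at the cost of extra memory. A minimal sketch contrasting the two choices, assuming the PaddleClas model and image downloaded above (an illustration, not repository code):
```python
# Sketch: shared model vs. per-thread clone, using the PaddleClas model from
# the example above (assumes ResNet50_vd_infer/ and the test image exist).
import os
from threading import Thread

import cv2
import fastdeploy as fd

model_dir = "ResNet50_vd_infer"
model = fd.vision.classification.PaddleClasModel(
    os.path.join(model_dir, "inference.pdmodel"),
    os.path.join(model_dir, "inference.pdiparams"),
    os.path.join(model_dir, "inference_cls.yaml"))

images = ["ILSVRC2012_val_00000010.jpeg"] * 4


def worker(m, paths):
    for p in paths:
        print(m.predict(cv2.imread(p), 1))  # topk=1


# Passing `model` itself also works under the GIL; clone() gives each thread
# independent member variables at the cost of additional memory.
threads = [Thread(target=worker, args=(model.clone(), images[i::2]))
           for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```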