From ada54bfd471b564f77ce57498d6dd997000b4080 Mon Sep 17 00:00:00 2001 From: huangjianhui <852142024@qq.com> Date: Wed, 14 Dec 2022 19:18:53 +0800 Subject: [PATCH] [Other]Update python && cpp multi_thread examples (#876) * Refactor PaddleSeg with preprocessor && postprocessor * Fix bugs * Delete redundancy code * Modify by comments * Refactor according to comments * Add batch evaluation * Add single test script * Add ppliteseg single test script && fix eval(raise) error * fix bug * Fix evaluation segmentation.py batch predict * Fix segmentation evaluation bug * Fix evaluation segmentation bugs * Update segmentation result docs * Update old predict api and DisableNormalizeAndPermute * Update resize segmentation label map with cv::INTER_NEAREST * Add Model Clone function for PaddleClas && PaddleDet && PaddleSeg * Add multi thread demo * Add python model clone function * Add multi thread python && C++ example * Fix bug * Update python && cpp multi_thread examples * Add cpp && python directory * Add README.md for examples * Delete redundant code Co-authored-by: Jason --- fastdeploy/vision/vision_pybind.cc | 65 +++++++++++++++ tutorials/multi_thread/cpp/CMakeLists.txt | 14 ++++ tutorials/multi_thread/cpp/README.md | 79 +++++++++++++++++++ .../{ => multi_thread/cpp}/multi_thread.cc | 66 ++++++++++++---- tutorials/multi_thread/python/README.md | 77 ++++++++++++++++++ .../python/multi_thread_process.py} | 72 +++++++++++------ 6 files changed, 334 insertions(+), 39 deletions(-) create mode 100644 tutorials/multi_thread/cpp/CMakeLists.txt create mode 100644 tutorials/multi_thread/cpp/README.md rename tutorials/{ => multi_thread/cpp}/multi_thread.cc (64%) create mode 100644 tutorials/multi_thread/python/README.md rename tutorials/{multi_thread.py => multi_thread/python/multi_thread_process.py} (59%) diff --git a/fastdeploy/vision/vision_pybind.cc b/fastdeploy/vision/vision_pybind.cc index 55312d1a3..cecd4f7c3 100644 --- a/fastdeploy/vision/vision_pybind.cc +++ 
b/fastdeploy/vision/vision_pybind.cc @@ -37,6 +37,21 @@ void BindVision(pybind11::module& m) { .def(pybind11::init()) .def_readwrite("data", &vision::Mask::data) .def_readwrite("shape", &vision::Mask::shape) + .def(pybind11::pickle( + [](const vision::Mask &m) { + return pybind11::make_tuple(m.data, m.shape); + }, + [](pybind11::tuple t) { + if (t.size() != 2) + throw std::runtime_error("vision::Mask pickle with invalid state!"); + + vision::Mask m; + m.data = t[0].cast>(); + m.shape = t[1].cast>(); + + return m; + } + )) .def("__repr__", &vision::Mask::Str) .def("__str__", &vision::Mask::Str); @@ -44,6 +59,21 @@ void BindVision(pybind11::module& m) { .def(pybind11::init()) .def_readwrite("label_ids", &vision::ClassifyResult::label_ids) .def_readwrite("scores", &vision::ClassifyResult::scores) + .def(pybind11::pickle( + [](const vision::ClassifyResult &c) { + return pybind11::make_tuple(c.label_ids, c.scores); + }, + [](pybind11::tuple t) { + if (t.size() != 2) + throw std::runtime_error("vision::ClassifyResult pickle with invalid state!"); + + vision::ClassifyResult c; + c.label_ids = t[0].cast>(); + c.scores = t[1].cast>(); + + return c; + } + )) .def("__repr__", &vision::ClassifyResult::Str) .def("__str__", &vision::ClassifyResult::Str); @@ -54,6 +84,24 @@ void BindVision(pybind11::module& m) { .def_readwrite("label_ids", &vision::DetectionResult::label_ids) .def_readwrite("masks", &vision::DetectionResult::masks) .def_readwrite("contain_masks", &vision::DetectionResult::contain_masks) + .def(pybind11::pickle( + [](const vision::DetectionResult &d) { + return pybind11::make_tuple(d.boxes, d.scores, d.label_ids, d.masks, d.contain_masks); + }, + [](pybind11::tuple t) { + if (t.size() != 5) + throw std::runtime_error("vision::DetectionResult pickle with Invalid state!"); + + vision::DetectionResult d; + d.boxes = t[0].cast>>(); + d.scores = t[1].cast>(); + d.label_ids = t[2].cast>(); + d.masks = t[3].cast>(); + d.contain_masks = t[4].cast(); + + return d; + } + )) 
.def("__repr__", &vision::DetectionResult::Str) .def("__str__", &vision::DetectionResult::Str); @@ -104,6 +152,23 @@ void BindVision(pybind11::module& m) { .def_readwrite("score_map", &vision::SegmentationResult::score_map) .def_readwrite("shape", &vision::SegmentationResult::shape) .def_readwrite("contain_score_map", &vision::SegmentationResult::contain_score_map) + .def(pybind11::pickle( + [](const vision::SegmentationResult &s) { + return pybind11::make_tuple(s.label_map, s.score_map, s.shape, s.contain_score_map); + }, + [](pybind11::tuple t) { + if (t.size() != 4) + throw std::runtime_error("vision::SegmentationResult pickle with Invalid state!"); + + vision::SegmentationResult s; + s.label_map = t[0].cast>(); + s.score_map = t[1].cast>(); + s.shape = t[2].cast>(); + s.contain_score_map = t[3].cast(); + + return s; + } + )) .def("__repr__", &vision::SegmentationResult::Str) .def("__str__", &vision::SegmentationResult::Str); diff --git a/tutorials/multi_thread/cpp/CMakeLists.txt b/tutorials/multi_thread/cpp/CMakeLists.txt new file mode 100644 index 000000000..d6882f897 --- /dev/null +++ b/tutorials/multi_thread/cpp/CMakeLists.txt @@ -0,0 +1,14 @@ +PROJECT(multi_thread_demo C CXX) +CMAKE_MINIMUM_REQUIRED (VERSION 3.10) + +# 指定下载解压后的fastdeploy库路径 +option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.") + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) + +# 添加FastDeploy依赖头文件 +include_directories(${FASTDEPLOY_INCS}) + +add_executable(multi_thread_demo ${PROJECT_SOURCE_DIR}/multi_thread.cc) +# 添加FastDeploy库依赖 +target_link_libraries(multi_thread_demo ${FASTDEPLOY_LIBS} pthread) diff --git a/tutorials/multi_thread/cpp/README.md b/tutorials/multi_thread/cpp/README.md new file mode 100644 index 000000000..066340467 --- /dev/null +++ b/tutorials/multi_thread/cpp/README.md @@ -0,0 +1,79 @@ +# PaddleClas C++部署示例 + +本目录下提供`infer.cc`快速完成PaddleClas系列模型在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。 + +在部署前,需确认以下两个步骤 + +- 1. 
软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) +- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) + +以Linux上ResNet50_vd推理为例,在本目录执行如下命令即可完成编译测试,支持此模型需保证FastDeploy版本0.7.0以上(x.x.x>=0.7.0) + +```bash +mkdir build +cd build +# 下载FastDeploy预编译库,用户可在上文提到的`FastDeploy预编译库`中自行选择合适的版本使用 +wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz +tar xvf fastdeploy-linux-x64-x.x.x.tgz +cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x +make -j + +# 下载ResNet50_vd模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz +tar -xvf ResNet50_vd_infer.tgz +wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg + + +# CPU推理 +./multi_thread_demo ResNet50_vd_infer ILSVRC2012_val_00000010.jpeg 0 1 +# GPU推理 +./multi_thread_demo ResNet50_vd_infer ILSVRC2012_val_00000010.jpeg 1 1 +# GPU上TensorRT推理 +./multi_thread_demo ResNet50_vd_infer ILSVRC2012_val_00000010.jpeg 2 1 +``` + +以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考: +- [如何在Windows中使用FastDeploy C++ SDK](../../../../../docs/cn/faq/use_sdk_on_windows.md) + +## PaddleClas C++接口 + +### PaddleClas类 + +```c++ +fastdeploy::vision::classification::PaddleClasModel( + const string& model_file, + const string& params_file, + const string& config_file, + const RuntimeOption& runtime_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::PADDLE) +``` + +PaddleClas模型加载和初始化,其中model_file, params_file为训练模型导出的Paddle inference文件,具体请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/inference_deployment/export_model.md#2-%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA) + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径 +> * **config_file**(str): 推理部署配置文件 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(ModelFormat): 
模型格式,默认为Paddle格式 + +#### Predict函数 + +> ```c++ +> PaddleClasModel::Predict(cv::Mat* im, ClassifyResult* result, int topk = 1) +> ``` +> +> 模型预测接口,输入图像直接输出检测结果。 +> +> **参数** +> +> > * **im**: 输入图像,注意需为HWC,BGR格式 +> > * **result**: 分类结果,包括label_id,以及相应的置信度, ClassifyResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/) +> > * **topk**(int):返回预测概率最高的topk个分类结果,默认为1 + + +- [模型介绍](../../) +- [Python部署](../python) +- [视觉模型预测结果](../../../../../docs/api/vision_results/) +- [如何切换模型推理后端引擎](../../../../../docs/cn/faq/how_to_change_backend.md) diff --git a/tutorials/multi_thread.cc b/tutorials/multi_thread/cpp/multi_thread.cc similarity index 64% rename from tutorials/multi_thread.cc rename to tutorials/multi_thread/cpp/multi_thread.cc index 6cc01b5d3..9c9d0ec5a 100644 --- a/tutorials/multi_thread.cc +++ b/tutorials/multi_thread/cpp/multi_thread.cc @@ -6,21 +6,44 @@ const char sep = '\\'; const char sep = '/'; #endif -void predict(fastdeploy::vision::classification::PaddleClasModel *model, int thread_id, const std::string& image_file) { - auto im = cv::imread(image_file); +void Predict(fastdeploy::vision::classification::PaddleClasModel *model, int thread_id, const std::vector<std::string>& images) { + for (auto const &image_file : images) { + auto im = cv::imread(image_file); - fastdeploy::vision::ClassifyResult res; - if (!model->Predict(im, &res)) { - std::cerr << "Failed to predict." << std::endl; - return; + fastdeploy::vision::ClassifyResult res; + if (!model->Predict(im, &res)) { + std::cerr << "Failed to predict." 
<< std::endl; + return; + } + + // print res + std::cout << "Thread Id: " << thread_id << std::endl; + std::cout << res.Str() << std::endl; } - - // print res - std::cout << "Thread Id: " << thread_id << std::endl; - std::cout << res.Str() << std::endl; } -void CpuInfer(const std::string& model_dir, const std::string& image_file, int thread_num) { +void GetImageList(std::vector<std::vector<std::string>>* image_list, const std::string& image_file_path, int thread_num){ + std::vector<cv::String> images; + cv::glob(image_file_path, images, false); + // count: number of files matched by the glob; each thread takes count / thread_num of them and the last thread also takes the remainder + size_t count = images.size(); + size_t num = thread_num > 0 ? count / thread_num : 0; + for (int i = 0; i < thread_num; i++) { + std::vector<std::string> temp_list; + if (i == thread_num - 1) { + for (size_t j = i*num; j < count; j++){ + temp_list.push_back(images[j]); + } + } else { + for (size_t j = 0; j < num; j++){ + temp_list.push_back(images[i * num + j]); + } + } + (*image_list)[i] = temp_list; + } +} + +void CpuInfer(const std::string& model_dir, const std::string& image_file_path, int thread_num) { auto model_file = model_dir + sep + "inference.pdmodel"; auto params_file = model_dir + sep + "inference.pdiparams"; auto config_file = model_dir + sep + "inference_cls.yaml"; @@ -39,9 +62,12 @@ void CpuInfer(const std::string& model_dir, const std::string& image_file, int t models.emplace_back(std::move(model.Clone())); } + std::vector<std::vector<std::string>> image_list(thread_num); + GetImageList(&image_list, image_file_path, thread_num); + std::vector<std::thread> threads; for (int i = 0; i < thread_num; ++i) { - threads.emplace_back(predict, models[i].get(), i, image_file); + threads.emplace_back(Predict, models[i].get(), i, image_list[i]); } for (int i = 0; i < thread_num; ++i) { @@ -49,7 +75,7 @@ void CpuInfer(const std::string& model_dir, const std::string& image_file, int t } } -void GpuInfer(const std::string& model_dir, const std::string& image_file, int thread_num) { +void GpuInfer(const std::string& model_dir, const std::string& image_file_path, int thread_num) { auto 
model_file = model_dir + sep + "inference.pdmodel"; auto params_file = model_dir + sep + "inference.pdiparams"; auto config_file = model_dir + sep + "inference_cls.yaml"; @@ -68,9 +94,12 @@ void GpuInfer(const std::string& model_dir, const std::string& image_file, int t models.emplace_back(std::move(model.Clone())); } + std::vector<std::vector<std::string>> image_list(thread_num); + GetImageList(&image_list, image_file_path, thread_num); + std::vector<std::thread> threads; for (int i = 0; i < thread_num; ++i) { - threads.emplace_back(predict, models[i].get(), i, image_file); + threads.emplace_back(Predict, models[i].get(), i, image_list[i]); } for (int i = 0; i < thread_num; ++i) { @@ -78,7 +107,7 @@ void GpuInfer(const std::string& model_dir, const std::string& image_file, int t } } -void TrtInfer(const std::string& model_dir, const std::string& image_file, int thread_num) { +void TrtInfer(const std::string& model_dir, const std::string& image_file_path, int thread_num) { auto model_file = model_dir + sep + "inference.pdmodel"; auto params_file = model_dir + sep + "inference.pdiparams"; auto config_file = model_dir + sep + "inference_cls.yaml"; @@ -99,9 +128,12 @@ void TrtInfer(const std::string& model_dir, const std::string& image_file, int t models.emplace_back(std::move(model.Clone())); } + std::vector<std::vector<std::string>> image_list(thread_num); + GetImageList(&image_list, image_file_path, thread_num); + std::vector<std::thread> threads; for (int i = 0; i < thread_num; ++i) { - threads.emplace_back(predict, models[i].get(), i, image_file); + threads.emplace_back(Predict, models[i].get(), i, image_list[i]); } for (int i = 0; i < thread_num; ++i) { @@ -112,7 +144,7 @@ void TrtInfer(const std::string& model_dir, const std::string& image_file, int t int main(int argc, char **argv) { if (argc < 5) { std::cout << "Usage: infer_demo path/to/model path/to/image run_option thread_num, " - "e.g ./infer_demo ./ResNet50_vd ./test.jpeg 0 3" + "e.g ./multi_thread_demo ./ResNet50_vd ./test.jpeg 0 3" << std::endl; std::cout << "The data type of 
run_option is int, 0: run with cpu; 1: run " "with gpu; 2: run with gpu and use tensorrt backend." diff --git a/tutorials/multi_thread/python/README.md b/tutorials/multi_thread/python/README.md new file mode 100644 index 000000000..9d17e6f65 --- /dev/null +++ b/tutorials/multi_thread/python/README.md @@ -0,0 +1,77 @@ +# PaddleClas模型 Python部署示例 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) +- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) + +本目录下提供`multi_thread_process.py`快速完成ResNet50_vd在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成 + +```bash +#下载部署示例代码 +git clone https://github.com/PaddlePaddle/FastDeploy.git +cd FastDeploy/tutorials/multi_thread/python + +# 下载ResNet50_vd模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz +tar -xvf ResNet50_vd_infer.tgz +wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg + +# CPU推理 +python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device cpu --topk 1 +# GPU推理 +python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device gpu --topk 1 +# GPU上使用TensorRT推理 (注意:TensorRT推理第一次运行,有序列化模型的操作,有一定耗时,需要耐心等待) +python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True --topk 1 +# IPU推理(注意:IPU推理首次运行会有序列化模型的操作,有一定耗时,需要耐心等待) +python multi_thread_process.py --model ResNet50_vd_infer --image_path ILSVRC2012_val_00000010.jpeg --device ipu --topk 1 +``` + +运行完成后返回结果如下所示 +```bash +ClassifyResult( +label_ids: 153, +scores: 0.686229, +) +``` + +## PaddleClasModel Python接口 + +```python +fd.vision.classification.PaddleClasModel(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE) +``` + +PaddleClas模型加载和初始化,其中model_file, params_file为训练模型导出的Paddle 
inference文件,具体请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/inference_deployment/export_model.md#2-%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA) + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径 +> * **config_file**(str): 推理部署配置文件 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(ModelFormat): 模型格式,默认为Paddle格式 + +### predict函数 + +> ```python +> PaddleClasModel.predict(input_image, topk=1) +> ``` +> +> 模型预测接口,输入图像直接输出分类topk结果。 +> +> **参数** +> +> > * **input_image**(np.ndarray): 输入数据,注意需为HWC,BGR格式 +> > * **topk**(int):返回预测概率最高的topk个分类结果,默认为1 + +> **返回** +> +> > 返回`fastdeploy.vision.ClassifyResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/) + + +## 其它文档 + +- [PaddleClas 模型介绍](..) +- [PaddleClas C++部署](../cpp) +- [模型预测结果说明](../../../../../docs/api/vision_results/) +- [如何切换模型推理后端引擎](../../../../../docs/cn/faq/how_to_change_backend.md) diff --git a/tutorials/multi_thread.py b/tutorials/multi_thread/python/multi_thread_process.py similarity index 59% rename from tutorials/multi_thread.py rename to tutorials/multi_thread/python/multi_thread_process.py index 27d3b3331..edeee6a06 100644 --- a/tutorials/multi_thread.py +++ b/tutorials/multi_thread/python/multi_thread_process.py @@ -4,6 +4,7 @@ import fastdeploy as fd import cv2 import os import psutil +from multiprocessing import Pool def parse_arguments(): @@ -31,6 +32,13 @@ def parse_arguments(): default=False, help="Wether to use tensorrt.") parser.add_argument("--thread_num", type=int, default=1, help="thread num") + parser.add_argument( + "--use_multi_process", + type=ast.literal_eval, + default=False, + help="Wether to use multi process.") + parser.add_argument( + "--process_num", type=int, default=1, help="process num") return parser.parse_args() @@ -71,7 +79,7 @@ def build_option(args): def predict(model, img_list, topk): result_list = [] - # 预测图片分类结果 + # predict classification result 
for image in img_list: im = cv2.imread(image) result = model.predict(im, topk) @@ -79,6 +87,13 @@ def predict(model, img_list, topk): return result_list +def process_predict(image): + # predict classification result + im = cv2.imread(image) + result = model.predict(im, args.topk) + return result + + class WrapperThread(Thread): def __init__(self, func, args): super(WrapperThread, self).__init__() @@ -95,9 +110,8 @@ class WrapperThread(Thread): if __name__ == '__main__': args = parse_arguments() - thread_num = args.thread_num imgs_list = get_image_list(args.image_path) - # 配置runtime,加载模型 + # configure runtime and load model runtime_option = build_option(args) model_file = os.path.join(args.model, "inference.pdmodel") @@ -105,24 +119,38 @@ if __name__ == '__main__': config_file = os.path.join(args.model, "inference_cls.yaml") model = fd.vision.classification.PaddleClasModel( model_file, params_file, config_file, runtime_option=runtime_option) - threads = [] - image_num_each_thread = int(len(imgs_list) / thread_num) - for i in range(thread_num): - if i == thread_num - 1: - t = WrapperThread( - predict, - args=(model, imgs_list[i * image_num_each_thread:], i)) - else: - t = WrapperThread( - predict, - args=(model.clone(), imgs_list[i * image_num_each_thread:( - i + 1) * image_num_each_thread - 1], i)) - threads.append(t) - t.start() + if args.use_multi_process: + results = [] + process_num = args.process_num + with Pool(process_num) as pool: + results = pool.map(process_predict, imgs_list) + for result in results: + print(result) + else: + threads = [] + thread_num = args.thread_num + image_num_each_thread = int(len(imgs_list) / thread_num) + # unless you want independent model in each thread, actually model.clone() + # is the same as model when creating thread because of the existence of + # GIL(Global Interpreter Lock) in python. 
In addition, model.clone() will consume + # additional memory to store independent member variables + for i in range(thread_num): + if i == thread_num - 1: + t = WrapperThread( + predict, + args=(model.clone(), imgs_list[i * image_num_each_thread:], + args.topk)) + else: + t = WrapperThread( + predict, + args=(model.clone(), imgs_list[i * image_num_each_thread:( + i + 1) * image_num_each_thread], args.topk)) + threads.append(t) + t.start() - for i in range(thread_num): - threads[i].join() + for i in range(thread_num): + threads[i].join() - for i in range(thread_num): - for result in threads[i].get_result(): - print('thread:', i, ', result: ', result)