Add PaddleSeg doc and infer.cc demo (#114)

* Update README.md

* Update README.md

* Update README.md

* Create README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Add evaluation calculate time and fix some bugs

* Update classification __init__

* Move to ppseg

* Add segmentation doc

* Add PaddleClas infer.py

* Update PaddleClas infer.py

* Delete .infer.py.swp

* Add PaddleClas infer.cc

* Update README.md

* Update README.md

* Update README.md

* Update infer.py

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Add PaddleSeg doc and infer.cc demo

* Update README.md

* Update README.md

* Update README.md

Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
huangjianhui
2022-08-15 15:24:38 +08:00
committed by GitHub
parent 773d6bb938
commit a016ef99ce
10 changed files with 159 additions and 150 deletions

View File

@@ -14,7 +14,7 @@ PaddleSegModel::PaddleSegModel(const std::string& model_file,
const Frontend& model_format) { const Frontend& model_format) {
config_file_ = config_file; config_file_ = config_file;
valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT}; valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
runtime_option = custom_option; runtime_option = custom_option;
runtime_option.model_format = model_format; runtime_option.model_format = model_format;
runtime_option.model_file = model_file; runtime_option.model_file = model_file;

View File

@@ -21,7 +21,7 @@
PaddleClas模型导出请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/inference_deployment/export_model.md#2-%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA) PaddleClas模型导出请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/inference_deployment/export_model.md#2-%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA)
注意PaddleClas导出的模型仅包含`inference.pdmodel``inference.pdiparams`两个文,但为了满足部署的需求,同时也需准备其提供的通用[inference_cls.yaml](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/deploy/configs/inference_cls.yaml)文件FastDeploy会从yaml文件中获取模型在推理时需要的预处理信息开发者可直接下载此文件使用。但需根据自己的需求修改yaml文件中的配置参数具体可比照PaddleClas模型训练[config](https://github.com/PaddlePaddle/PaddleClas/tree/release/2.4/ppcls/configs/ImageNet)中的infer部分的配置信息进行修改。 注意PaddleClas导出的模型仅包含`inference.pdmodel``inference.pdiparams`两个文,但为了满足部署的需求,同时也需准备其提供的通用[inference_cls.yaml](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/deploy/configs/inference_cls.yaml)文件FastDeploy会从yaml文件中获取模型在推理时需要的预处理信息开发者可直接下载此文件使用。但需根据自己的需求修改yaml文件中的配置参数具体可比照PaddleClas模型训练[config](https://github.com/PaddlePaddle/PaddleClas/tree/release/2.4/ppcls/configs/ImageNet)中的infer部分的配置信息进行修改。
## 下载预训练模型 ## 下载预训练模型

View File

@@ -8,22 +8,21 @@
本目录下提供`infer.py`快速完成ResNet50_vd在CPU/GPU以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成 本目录下提供`infer.py`快速完成ResNet50_vd在CPU/GPU以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成
``` ```
#下载部署示例代码
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/classification/paddleclas/python
# 下载ResNet50_vd模型文件和测试图片 # 下载ResNet50_vd模型文件和测试图片
wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz
tar -xvf ResNet50_vd_infer.tgz tar -xvf ResNet50_vd_infer.tgz
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
#下载部署示例代码
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd examples/vision/classification/paddleclas/python
# CPU推理 # CPU推理
python infer.py --model ResNet50_vd_infer --image ILSVRC2012_val_00000010.jpeg --device cpu python infer.py --model ResNet50_vd_infer --image ILSVRC2012_val_00000010.jpeg --device cpu --topk 1
# GPU推理 # GPU推理
python infer.py --model ResNet50_vd_infer --image ILSVRC2012_val_00000010.jpeg --device gpu python infer.py --model ResNet50_vd_infer --image ILSVRC2012_val_00000010.jpeg --device gpu --topk 1
# GPU上使用TensorRT推理 注意TensorRT推理第一次运行有序列化模型的操作有一定耗时需要耐心等待 # GPU上使用TensorRT推理 注意TensorRT推理第一次运行有序列化模型的操作有一定耗时需要耐心等待
python infer.py --model ResNet50_vd_infer --image ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True python infer.py --model ResNet50_vd_infer --image ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True --topk 1
``` ```
运行完成后返回结果如下所示 运行完成后返回结果如下所示

View File

@@ -15,7 +15,7 @@ wget https://bj.bcebos.com/paddlehub/fastdeploy/libs/0.2.0/fastdeploy-linux-x64-
tar xvf fastdeploy-linux-x64-gpu-0.2.0.tgz tar xvf fastdeploy-linux-x64-gpu-0.2.0.tgz
cd fastdeploy-linux-x64-gpu-0.2.0/examples/vision/detection/paddledetection cd fastdeploy-linux-x64-gpu-0.2.0/examples/vision/detection/paddledetection
mkdir build && cd build mkdir build && cd build
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/../../../../../../fastdeploy-linux-x64-gpu-0.2.0 cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/../../../../../../../fastdeploy-linux-x64-gpu-0.2.0
make -j make -j
# 下载PPYOLOE模型文件和测试图片 # 下载PPYOLOE模型文件和测试图片

View File

@@ -1,54 +1,36 @@
# PaddleClas 模型部署 # PaddleSeg 模型部署
## 模型版本说明 ## 模型版本说明
- [PaddleClas Release/2.4](https://github.com/PaddlePaddle/PaddleClas/tree/release/2.4) - [PaddleSeg Release/2.6](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6)
目前FastDeploy支持如下模型的部署 目前FastDeploy支持如下模型的部署
- [PP-LCNet系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/PP-LCNet.md) - [U-Net系列模型](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/configs/unet/README.md)
- [PP-LCNetV2系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/PP-LCNetV2.md) - [PP-LiteSeg系列模型](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/configs/pp_liteseg/README.md)
- [EfficientNet系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/EfficientNet_and_ResNeXt101_wsl.md) - [PP-HumanSeg系列模型](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/contrib/PP-HumanSeg/README.md)
- [GhostNet系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/Mobile.md) - [FCN系列模型](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/configs/fcn/README.md)
- [MobileNet系列模型(包含v1,v2,v3)](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/Mobile.md) - [DeepLabV3系列模型](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/configs/deeplabv3/README.md)
- [ShuffleNet系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/Mobile.md)
- [SqueezeNet系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/Others.md)
- [Inception系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/Inception.md)
- [PP-HGNet系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/PP-HGNet.md)
- [ResNet系列模型包含vd系列](https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/models/ResNet_and_vd.md)
## 准备PaddleClas部署模型 ## 准备PaddleSeg部署模型
PaddleClas模型导出,请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/inference_deployment/export_model.md#2-%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA) PaddleSeg模型导出,请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/docs/model_export_cn.md)
注意PaddleClas导出的模型包含`inference.pdmodel``inference.pdiparams`两个文档,但为了满足部署的需求,同时也需准备其提供的通用[inference_cls.yaml](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/deploy/configs/inference_cls.yaml)文件FastDeploy会从yaml文件中获取模型在推理时需要的预处理信息开发者可直接下载此文件使用。但需根据自己的需求修改yaml文件中的配置参数具体可比照PaddleClas模型训练[config](https://github.com/PaddlePaddle/PaddleClas/tree/release/2.4/ppcls/configs/ImageNet)中的infer部分的配置信息进行修改 注意:在使用PaddleSeg模型导出时可指定`--input_shape`参数若预测输入图片尺寸并不固定建议使用默认值即不指定该参数。PaddleSeg导出的模型包含`model.pdmodel``model.pdiparams``deploy.yaml`三个文件FastDeploy会从yaml文件中获取模型在推理时需要的预处理信息。
## 下载预训练模型 ## 下载预训练模型
为了方便开发者的测试下面提供了PaddleClas导出的部分模型(含inference_cls.yaml文件),开发者可直接下载使用。 为了方便开发者的测试下面提供了PaddleSeg导出的部分模型(导出方式为:**不指定**`input_shape``with_softmax`**指定**`without_argmax`),开发者可直接下载使用。
| 模型 | 参数文件大小 |输入Shape | Top1 | Top5 | | 模型 | 参数文件大小 |输入Shape | mIoU | mIoU (flip) | mIoU (ms+flip) |
|:---------------------------------------------------------------- |:----- |:----- | :----- | :----- | |:---------------------------------------------------------------- |:----- |:----- | :----- | :----- | :----- |
| [PPLCNet_x1_0](https://bj.bcebos.com/paddlehub/fastdeploy/PPLCNet_x1_0_infer.tgz) | 12MB | 224x224 |71.32% | 90.03% | | [Unet-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/Unet_cityscapes_without_argmax_infer.tgz) | 52MB | 1024x512 | 65.00% | 66.02% | 66.89% |
| [PPLCNetV2_base](https://bj.bcebos.com/paddlehub/fastdeploy/PPLCNetV2_base_infer.tgz) | 26MB | 224x224 |77.04% | 93.27% | | [PP-LiteSeg-T(STDC1)-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/PP_LiteSeg_T_STDC1_cityscapes_without_argmax_infer.tgz) | 31MB | 1024x512 |73.10% | 73.89% | - |
| [EfficientNetB7](https://bj.bcebos.com/paddlehub/fastdeploy/EfficientNetB7_infer.tgz) | 255MB | 600x600 | 84.3% | 96.9% | | [PP-HumanSegV1-Lite](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV1_Lite_infer.tgz) | 543KB | 192x192 | 86.2% | - | - |
| [EfficientNetB0_small](https://bj.bcebos.com/paddlehub/fastdeploy/EfficientNetB0_small_infer.tgz)| 18MB | 224x224 | 75.8% | 75.8% | | [PP-HumanSegV1-Server](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV1_Server_infer.tgz) | 103MB | 512x512 | 96.47% | - | - |
| [GhostNet_x1_3_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/GhostNet_x1_3_ssld_infer.tgz) | 29MB | 224x224 | 75.7% | 92.5% | | [FCN-HRNet-W18-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/FCN_HRNet_W18_cityscapes_without_argmax_infer.tgz) | 37MB | 1024x512 | 78.97% | 79.49% | 79.74% |
| [GhostNet_x0_5_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/GhostNet_x0_5_infer.tgz) | 10MB | 224x224 | 66.8% | 86.9% | | [Deeplabv3-ResNet50-OS8-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/Deeplabv3_ResNet50_OS8_cityscapes_without_argmax_infer.tgz) | 150MB | 1024x512 | 79.90% | 80.22% | 80.47% |
| [MobileNetV1_x0_25](https://bj.bcebos.com/paddlehub/fastdeploy/MobileNetV1_x0_25_infer.tgz) | 1.9MB | 224x224 | 51.4% | 75.5% |
| [MobileNetV1_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/MobileNetV1_ssld_infer.tgz) | 17MB | 224x224 | 77.9% | 93.9% |
| [MobileNetV2_x0_25](https://bj.bcebos.com/paddlehub/fastdeploy/MobileNetV2_x0_25_infer.tgz) | 5.9MB | 224x224 | 53.2% | 76.5% |
| [MobileNetV2_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/MobileNetV2_ssld_infer.tgz) | 14MB | 224x224 | 76.74% | 93.39% |
| [MobileNetV3_small_x0_35_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/MobileNetV3_small_x0_35_ssld_infer.tgz) | 6.4MB | 224x224 | 55.55% | 77.71% |
| [MobileNetV3_large_x1_0_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/MobileNetV3_large_x1_0_ssld_infer.tgz) | 22MB | 224x224 | 78.96% | 94.48% |
| [ShuffleNetV2_x0_25](https://bj.bcebos.com/paddlehub/fastdeploy/ShuffleNetV2_x0_25_infer.tgz) | 2.4MB | 224x224 | 49.9% | 73.79% |
| [ShuffleNetV2_x2_0](https://bj.bcebos.com/paddlehub/fastdeploy/ShuffleNetV2_x2_0_infer.tgz) | 29MB | 224x224 | 73.15% | 91.2% |
| [SqueezeNet1_1](https://bj.bcebos.com/paddlehub/fastdeploy/SqueezeNet1_1_infer.tgz) | 4.8MB | 224x224 | 60.1% | 81.9% |
| [InceptionV3](https://bj.bcebos.com/paddlehub/fastdeploy/InceptionV3_infer.tgz) | 92MB | 299x299 | 79.14% | 94.59% |
| [PPHGNet_tiny_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/PPHGNet_tiny_ssld_infer.tgz) | 57MB | 224x224 | 81.95% | 96.12% |
| [PPHGNet_base_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/PPHGNet_base_ssld_infer.tgz) | 274MB | 224x224 | 85.0% | 97.35% |
| [ResNet50_vd](https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz) | 98MB | 224x224 | 79.12% | 94.44% |
## 详细部署文档 ## 详细部署文档

View File

@@ -1,6 +1,6 @@
# YOLOv7 C++部署示例 # PaddleSeg C++部署示例
本目录下提供`infer.cc`快速完成YOLOv7在CPU/GPU以及GPU上通过TensorRT加速部署的示例。 本目录下提供`infer.cc`快速完成Unet在CPU/GPU以及GPU上通过TensorRT加速部署的示例。
在部署前,需确认以下两个步骤 在部署前,需确认以下两个步骤
@@ -12,51 +12,58 @@
``` ```
mkdir build mkdir build
cd build cd build
wget https://xxx.tgz wget https://bj.bcebos.com/paddlehub/fastdeploy/libs/0.2.0/fastdeploy-linux-x64-gpu-0.2.0.tgz
tar xvf fastdeploy-linux-x64-0.2.0.tgz tar xvf fastdeploy-linux-x64-gpu-0.2.0.tgz
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-0.2.0 cd fastdeploy-linux-x64-gpu-0.2.0/examples/vision/segmentation/paddleseg/cpp/build
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/../../../../../../../fastdeploy-linux-x64-gpu-0.2.0
make -j make -j
#下载官方转换好的yolov7模型文件和测试图片 # 下载Unet模型文件和测试图片
wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov7.onnx wget https://bj.bcebos.com/paddlehub/fastdeploy/Unet_cityscapes_without_argmax_infer.tgz
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000087038.jpg tar -xvf Unet_cityscapes_without_argmax_infer.tgz
wget https://paddleseg.bj.bcebos.com/dygraph/demo/cityscapes_demo.png
# CPU推理 # CPU推理
./infer_demo yolov7.onnx 000000087038.jpg 0 ./infer_demo Unet_cityscapes_without_argmax_infer infer.cc cityscapes_demo.png 0
# GPU推理 # GPU推理
./infer_demo yolov7.onnx 000000087038.jpg 1 ./infer_demo Unet_cityscapes_without_argmax_infer infer.cc cityscapes_demo.png 1
# GPU上TensorRT推理 # GPU上TensorRT推理
./infer_demo yolov7.onnx 000000087038.jpg 2 ./infer_demo Unet_cityscapes_without_argmax_infer infer.cc cityscapes_demo.png 2
``` ```
## YOLOv7 C++接口 运行完成可视化结果如下图所示
<div align="center">
<img src="https://user-images.githubusercontent.com/16222477/184588768-45ee673b-ef1f-40f4-9fbd-6b1a9ce17c59.png", width=512px, height=256px />
</div>
### YOLOv7类 ## PaddleSeg C++接口
### PaddleSeg类
``` ```
fastdeploy::vision::detection::YOLOv7( fastdeploy::vision::segmentation::PaddleSegModel(
const string& model_file, const string& model_file,
const string& params_file = "", const string& params_file = "",
const string& config_file,
const RuntimeOption& runtime_option = RuntimeOption(), const RuntimeOption& runtime_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX) const Frontend& model_format = Frontend::PADDLE)
``` ```
YOLOv7模型加载和初始化其中model_file为导出的ONNX模型格式。 PaddleSegModel模型加载和初始化其中model_file为导出的Paddle模型格式。
**参数** **参数**
> * **model_file**(str): 模型文件路径 > * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径当模型格式为ONNX时此参数传入空字符串即可 > * **params_file**(str): 参数文件路径
> * **config_file**(str): 推理部署配置文件
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置 > * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(Frontend): 模型格式,默认为ONNX格式 > * **model_format**(Frontend): 模型格式,默认为Paddle格式
#### Predict函数 #### Predict函数
> ``` > ```
> YOLOv7::Predict(cv::Mat* im, DetectionResult* result, > PaddleSegModel::Predict(cv::Mat* im, DetectionResult* result)
> float conf_threshold = 0.25,
> float nms_iou_threshold = 0.5)
> ``` > ```
> >
> 模型预测接口,输入图像直接输出检测结果。 > 模型预测接口,输入图像直接输出检测结果。
@@ -64,13 +71,16 @@ YOLOv7模型加载和初始化其中model_file为导出的ONNX模型格式。
> **参数** > **参数**
> >
> > * **im**: 输入图像注意需为HWCBGR格式 > > * **im**: 输入图像注意需为HWCBGR格式
> > * **result**: 检测结果,包括检测框,各个框的置信度, DetectionResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/) > > * **result**: 分割结果,包括分割预测的标签以及标签对应的概率值, SegmentationResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/)
> > * **conf_threshold**: 检测框置信度过滤阈值
> > * **nms_iou_threshold**: NMS处理过程中iou阈值
### 类成员变量 ### 类成员属性
#### 预处理参数
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
> > * **size**(vector<int>): 通过此参数修改预处理过程中resize的大小包含两个整型元素表示[width, height], 默认值为[640, 640] > > * **is_vertical_screen**(bool): PP-HumanSeg系列模型通过设置此参数为`True`表明输入图片是竖屏即height大于width的图片
#### 后处理参数
> > * **with_softmax**(bool): 当模型导出时,并未指定`with_softmax`参数,可通过此设置此参数为`True`将预测的输出分割标签label_map对应的概率结果(score_map)做softmax归一化处理
- [模型介绍](../../) - [模型介绍](../../)
- [Python部署](../python) - [Python部署](../python)

View File

@@ -14,34 +14,45 @@
#include "fastdeploy/vision.h" #include "fastdeploy/vision.h"
void CpuInfer(const std::string& model_file, const std::string& params_file, #ifdef WIN32
const std::string& config_file, const std::string& image_file) { const char sep = '\\';
auto option = fastdeploy::RuntimeOption(); #else
option.UseCpu() auto model = const char sep = '/';
fastdeploy::vision::classification::PaddleClasModel( #endif
model_file, params_file, config_file, option);
void CpuInfer(const std::string& model_dir, const std::string& image_file) {
auto model_file = model_dir + sep + "model.pdmodel";
auto params_file = model_dir + sep + "model.pdiparams";
auto config_file = model_dir + sep + "deploy.yaml";
auto model = fastdeploy::vision::segmentation::PaddleSegModel(
model_file, params_file, config_file);
if (!model.Initialized()) { if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl; std::cerr << "Failed to initialize." << std::endl;
return; return;
} }
auto im = cv::imread(image_file); auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::ClassifyResult res; fastdeploy::vision::SegmentationResult res;
if (!model.Predict(&im, &res)) { if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl; std::cerr << "Failed to predict." << std::endl;
return; return;
} }
// print res auto vis_im = fastdeploy::vision::Visualize::VisSegmentation(im_bak, res);
res.Str(); cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
} }
void GpuInfer(const std::string& model_file, const std::string& params_file, void GpuInfer(const std::string& model_dir, const std::string& image_file) {
const std::string& config_file, const std::string& image_file) { auto model_file = model_dir + sep + "model.pdmodel";
auto params_file = model_dir + sep + "model.pdiparams";
auto config_file = model_dir + sep + "deploy.yaml";
auto option = fastdeploy::RuntimeOption(); auto option = fastdeploy::RuntimeOption();
option.UseGpu(); option.UseGpu();
auto model = fastdeploy::vision::classification::PaddleClasModel( auto model = fastdeploy::vision::segmentation::PaddleSegModel(
model_file, params_file, config_file, option); model_file, params_file, config_file, option);
if (!model.Initialized()) { if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl; std::cerr << "Failed to initialize." << std::endl;
@@ -49,25 +60,30 @@ void GpuInfer(const std::string& model_file, const std::string& params_file,
} }
auto im = cv::imread(image_file); auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::ClassifyResult res; fastdeploy::vision::SegmentationResult res;
if (!model.Predict(&im, &res)) { if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl; std::cerr << "Failed to predict." << std::endl;
return; return;
} }
// print res auto vis_im = fastdeploy::vision::Visualize::VisSegmentation(im_bak, res);
res.Str(); cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
} }
void TrtInfer(const std::string& model_file, const std::string& params_file, void TrtInfer(const std::string& model_dir, const std::string& image_file) {
const std::string& config_file, const std::string& image_file) { auto model_file = model_dir + sep + "model.pdmodel";
auto params_file = model_dir + sep + "model.pdiparams";
auto config_file = model_dir + sep + "deploy.yaml";
auto option = fastdeploy::RuntimeOption(); auto option = fastdeploy::RuntimeOption();
option.UseGpu(); option.UseGpu();
option.UseTrtBackend(); option.UseTrtBackend();
option.SetTrtInputShape("inputs", [ 1, 3, 224, 224 ], [ 1, 3, 224, 224 ], option.SetTrtInputShape("x", {1, 3, 256, 256}, {1, 3, 1024, 1024},
[ 1, 3, 224, 224 ]); {1, 3, 2048, 2048});
auto model = fastdeploy::vision::classification::PaddleClasModel( auto model = fastdeploy::vision::segmentation::PaddleSegModel(
model_file, params_file, config_file, option); model_file, params_file, config_file, option);
if (!model.Initialized()) { if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl; std::cerr << "Failed to initialize." << std::endl;
@@ -75,40 +91,37 @@ void TrtInfer(const std::string& model_file, const std::string& params_file,
} }
auto im = cv::imread(image_file); auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::ClassifyResult res; fastdeploy::vision::SegmentationResult res;
if (!model.Predict(&im, &res)) { if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl; std::cerr << "Failed to predict." << std::endl;
return; return;
} }
// print res auto vis_im = fastdeploy::vision::Visualize::VisSegmentation(im_bak, res);
res.Str(); cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
} }
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
if (argc < 4) { if (argc < 4) {
std::cout << "Usage: infer_demo path/to/model path/to/image run_option, " std::cout
"e.g ./infer_demo ./ResNet50_vd ./test.jpeg 0" << "Usage: infer_demo path/to/model_dir path/to/image run_option, "
<< std::endl; "e.g ./infer_model ./ppseg_model_dir ./test.jpeg 0"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run " std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu; 2: run with gpu and use tensorrt backend." "with gpu; 2: run with gpu and use tensorrt backend."
<< std::endl; << std::endl;
return -1; return -1;
} }
std::string model_file = if (std::atoi(argv[3]) == 0) {
argv[1] + "/" + "model.pdmodel" std::string params_file = CpuInfer(argv[1], argv[2]);
argv[1] + "/" + "model.pdiparams" std::string config_file = } else if (std::atoi(argv[3]) == 1) {
argv[1] + "/" + "inference_cls.yaml" std::string image_file = GpuInfer(argv[1], argv[2]);
argv[2] if (std::atoi(argv[3]) == 0) { } else if (std::atoi(argv[3]) == 2) {
CpuInfer(model_file, params_file, config_file, image_file); TrtInfer(argv[1], argv[2]);
}
else if (std::atoi(argv[3]) == 1) {
GpuInfer(model_file, params_file, config_file, image_file);
}
else if (std::atoi(argv[3]) == 2) {
TrtInfer(model_file, params_file, config_file, image_file);
} }
return 0; return 0;
} }

View File

@@ -1,46 +1,43 @@
# PaddleClas模型 Python部署示例 # PaddleSeg Python部署示例
在部署前,需确认以下两个步骤 在部署前,需确认以下两个步骤
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/quick_start/requirements.md) - 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/quick_start/requirements.md)
- 2. FastDeploy Python whl包安装参考[FastDeploy Python安装](../../../../../docs/quick_start/install.md) - 2. FastDeploy Python whl包安装参考[FastDeploy Python安装](../../../../../docs/quick_start/install.md)
本目录下提供`infer.py`快速完成ResNet50_vd在CPU/GPU以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成 本目录下提供`infer.py`快速完成Unet在CPU/GPU以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成
``` ```
# 下载ResNet50_vd模型文件和测试图片
wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz
tar -xvf ResNet50_vd_infer.tgz
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
#下载部署示例代码 #下载部署示例代码
git clone https://github.com/PaddlePaddle/FastDeploy.git git clone https://github.com/PaddlePaddle/FastDeploy.git
cd examples/vision/classification/paddleclas/python cd FastDeploy/examples/vision/segmentation/paddleseg/python
# 下载Unet模型文件和测试图片
wget https://bj.bcebos.com/paddlehub/fastdeploy/Unet_cityscapes_without_argmax_infer.tgz
tar -xvf Unet_cityscapes_without_argmax_infer.tgz
wget https://paddleseg.bj.bcebos.com/dygraph/demo/cityscapes_demo.png
# CPU推理 # CPU推理
python infer.py --model ResNet50_vd_infer --image ILSVRC2012_val_00000010.jpeg --device cpu python infer.py --model Unet_cityscapes_without_argmax_infer --image cityscapes_demo.png --device cpu
# GPU推理 # GPU推理
python infer.py --model ResNet50_vd_infer --image ILSVRC2012_val_00000010.jpeg --device gpu python infer.py --model Unet_cityscapes_without_argmax_infer --image cityscapes_demo.png --device gpu
# GPU上使用TensorRT推理 注意TensorRT推理第一次运行有序列化模型的操作有一定耗时需要耐心等待 # GPU上使用TensorRT推理 注意TensorRT推理第一次运行有序列化模型的操作有一定耗时需要耐心等待
python infer.py --model ResNet50_vd_infer --image ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True python infer.py --model Unet_cityscapes_without_argmax_infer --image cityscapes_demo.png --device gpu --use_trt True
``` ```
运行完成后返回结果如下所示 运行完成可视化结果如下所示
``` <div align="center">
ClassifyResult( <img src="https://user-images.githubusercontent.com/16222477/184588768-45ee673b-ef1f-40f4-9fbd-6b1a9ce17c59.png", width=512px, height=256px />
label_ids: 153, </div>
scores: 0.686229,
)
```
## PaddleClasModel Python接口 ## PaddleSegModel Python接口
``` ```
fd.vision.classification.PaddleClasModel(model_file, params_file, config_file, runtime_option=None, model_format=Frontend.PADDLE) fd.vision.segmentation.PaddleSegModel(model_file, params_file, config_file, runtime_option=None, model_format=Frontend.PADDLE)
``` ```
PaddleClas模型加载和初始化其中model_file, params_file为训练模型导出的Paddle inference文件具体请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/inference_deployment/export_model.md#2-%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA) PaddleSeg模型加载和初始化其中model_file, params_file以及config_file为训练模型导出的Paddle inference文件具体请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/docs/model_export_cn.md)
**参数** **参数**
@@ -53,7 +50,7 @@ PaddleClas模型加载和初始化其中model_file, params_file为训练模
### predict函数 ### predict函数
> ``` > ```
> PaddleClasModel.predict(input_image, topk=1) > PaddleSegModel.predict(input_image)
> ``` > ```
> >
> 模型预测结口,输入图像直接输出检测结果。 > 模型预测结口,输入图像直接输出检测结果。
@@ -61,15 +58,22 @@ PaddleClas模型加载和初始化其中model_file, params_file为训练模
> **参数** > **参数**
> >
> > * **input_image**(np.ndarray): 输入数据注意需为HWCBGR格式 > > * **input_image**(np.ndarray): 输入数据注意需为HWCBGR格式
> > * **topk**(int):返回预测概率最高的topk个分类结果
> **返回** > **返回**
> >
> > 返回`fastdeploy.vision.ClassifyResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/) > > 返回`fastdeploy.vision.SegmentationResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
### 类成员属性
#### 预处理参数
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
> > * **is_vertical_screen**(bool): PP-HumanSeg系列模型通过设置此参数为`true`表明输入图片是竖屏即height大于width的图片
#### 后处理参数
> > * **with_softmax**(bool): 当模型导出时,并未指定`with_softmax`参数,可通过此设置此参数为`true`将预测的输出分割标签label_map对应的概率结果(score_map)做softmax归一化处理
## 其它文档 ## 其它文档
- [PaddleClas 模型介绍](..) - [PaddleSeg 模型介绍](..)
- [PaddleClas C++部署](../cpp) - [PaddleSeg C++部署](../cpp)
- [模型预测结果说明](../../../../../docs/api/vision_results/) - [模型预测结果说明](../../../../../docs/api/vision_results/)

View File

@@ -8,11 +8,9 @@ def parse_arguments():
import ast import ast
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--model", required=True, help="Path of PaddleClas model.") "--model", required=True, help="Path of PaddleSeg model.")
parser.add_argument( parser.add_argument(
"--image", type=str, required=True, help="Path of test image file.") "--image", type=str, required=True, help="Path of test image file.")
parser.add_argument(
"--topk", type=int, default=1, help="Return topk results.")
parser.add_argument( parser.add_argument(
"--device", "--device",
type=str, type=str,
@@ -43,14 +41,17 @@ args = parse_arguments()
# 配置runtime加载模型 # 配置runtime加载模型
runtime_option = build_option(args) runtime_option = build_option(args)
model_file = os.path.join(args.model, "inference.pdmodel") model_file = os.path.join(args.model, "model.pdmodel")
params_file = os.path.join(args.model, "inference.pdiparams") params_file = os.path.join(args.model, "model.pdiparams")
config_file = os.path.join(args.model, "inference_cls.yaml") config_file = os.path.join(args.model, "deploy.yaml")
#model = fd.vision.classification.PaddleClasModel(model_file, params_file, config_file, runtime_option=runtime_option) model = fd.vision.segmentation.PaddleSegModel(
model = fd.vision.classification.ResNet50vd(
model_file, params_file, config_file, runtime_option=runtime_option) model_file, params_file, config_file, runtime_option=runtime_option)
# 预测图片分类结果 # 预测图片分类结果
im = cv2.imread(args.image) im = cv2.imread(args.image)
result = model.predict(im, args.topk) result = model.predict(im)
print(result) print(result)
# 可视化结果
vis_im = fd.vision.visualize.vis_segmentation(im, result)
cv2.imwrite("vis_img.png", vis_im)

View File

@@ -23,9 +23,9 @@ class PaddleSegModel(FastDeployModel):
model_file, model_file,
params_file, params_file,
config_file, config_file,
backend_option=None, runtime_option=None,
model_format=Frontend.PADDLE): model_format=Frontend.PADDLE):
super(Model, self).__init__(backend_option) super(PaddleSegModel, self).__init__(runtime_option)
assert model_format == Frontend.PADDLE, "PaddleSeg only support model format of Frontend.Paddle now." assert model_format == Frontend.PADDLE, "PaddleSeg only support model format of Frontend.Paddle now."
self._model = C.vision.segmentation.PaddleSegModel( self._model = C.vision.segmentation.PaddleSegModel(