add ocr, ppyoloe, picodet examples (#1076)

* add ocr examples

* add ppyoloe examples

* add picodet examples

* remove /ScaleFactor in ppdet/postprocessor.cc
Dantès, 2023-01-10 16:34:26 +08:00 (committed by GitHub)
commit de70e8366c, parent fc314f1696
21 changed files with 922 additions and 13 deletions


@@ -0,0 +1,107 @@
# PaddleDetection SOPHGO Deployment Example
## List of Supported Models
SOPHGO currently supports the deployment of the following models:
- [PP-YOLOE series models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)
- [PicoDet series models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet)
## Preparing and Converting the PP-YOLOE or PicoDet Deployment Model
Before deployment on the SOPHGO-TPU, the Paddle model must be converted into a bmodel. The steps are as follows:
- To convert a Paddle dygraph model to ONNX, refer to [Exporting PaddleDetection Models](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/EXPORT_MODEL.md).
- To convert the ONNX model to a bmodel, refer to [TPU-MLIR](https://github.com/sophgo/tpu-mlir).
## Model Conversion Example
The conversion process is similar for PP-YOLOE and PicoDet. Below, ppyoloe_crn_s_300e_coco is used as an example to show how to convert a Paddle model into a SOPHGO-TPU model.
### Export the ONNX Model
```shell
# Export the Paddle model
python tools/export_model.py -c configs/ppyoloe/ppyoloe_crn_s_300e_coco.yml --output_dir=output_inference -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco.pdparams
# Convert the Paddle model to ONNX
paddle2onnx --model_dir ppyoloe_crn_s_300e_coco \
--model_filename model.pdmodel \
--params_filename model.pdiparams \
--save_file ppyoloe_crn_s_300e_coco.onnx \
--enable_dev_version True
# Fix the ONNX model input shape (run this inside the Paddle2ONNX directory)
python -m paddle2onnx.optimize --input_model ppyoloe_crn_s_300e_coco.onnx \
--output_model ppyoloe_crn_s_300e_coco.onnx \
--input_shape_dict "{'image':[1,3,640,640]}"
```
### Export the bmodel
Taking the conversion to a BM1684x bmodel as an example, download the [TPU-MLIR](https://github.com/sophgo/tpu-mlir) project; for the installation procedure, see the [TPU-MLIR documentation](https://github.com/sophgo/tpu-mlir/blob/master/README.md).
### 1. Installation
``` shell
docker pull sophgo/tpuc_dev:latest
# myname1234 is just an example; any other container name works
docker run --privileged --name myname1234 -v $PWD:/workspace -it sophgo/tpuc_dev:latest
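# Note: per the TPU-MLIR README, run the next two commands inside the
# container, from the tpu-mlir source directory.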
source ./envsetup.sh
./build.sh
```
### 2. Convert the ONNX Model to a bmodel
``` shell
mkdir ppyoloe_crn_s_300e_coco && cd ppyoloe_crn_s_300e_coco
# Download a test image
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
```
Then use Python to generate the npz file needed for the model conversion:
``` python
import cv2
import numpy as np

im = cv2.imread('000000014439.jpg')
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
# [640, 640] is the input size of ppyoloe_crn_s_300e_coco
im_scale_y = 640 / float(im.shape[0])
im_scale_x = 640 / float(im.shape[1])
inputs = {}
inputs['image'] = np.array((im, )).astype('float32')
inputs['scale_factor'] = np.array([im_scale_y, im_scale_x]).astype('float32')
np.savez('inputs.npz', image=inputs['image'], scale_factor=inputs['scale_factor'])
```
``` shell
# Put the ONNX model file ppyoloe_crn_s_300e_coco.onnx in this directory
mkdir workspace && cd workspace
# Convert the ONNX model to an mlir model
model_transform.py \
--model_name ppyoloe_crn_s_300e_coco \
--model_def ../ppyoloe_crn_s_300e_coco.onnx \
--input_shapes [[1,3,640,640],[1,2]] \
--keep_aspect_ratio \
--pixel_format rgb \
--output_names p2o.Div.1,p2o.Concat.29 \
--test_input ../inputs.npz \
--test_result ppyoloe_crn_s_300e_coco_top_outputs.npz \
--mlir ppyoloe_crn_s_300e_coco.mlir
```
### Note
**Since TPU-MLIR does not currently support post-processing operators, the inputs of the post-processing stage must be used as the network outputs.**
Specifically, the output_names can be found with [Netron](https://netron.app/): open the ONNX model to be converted in the web page and search for the NonMaxSuppression node.
The names of boxes and scores listed under INPUTS are exactly the output_names we need.
For example, visualizing the model with Netron gives the following picture:
![](https://user-images.githubusercontent.com/120167928/210939488-a37e6c8b-474c-4948-8362-2066ee7a2ecb.png)
Find the NonMaxSuppression node marked by the blue box; the two node names marked by the red boxes, p2o.Div.1 and p2o.Concat.29, are the ones we need.
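The same names can also be read off programmatically instead of by hand. The snippet below is a minimal sketch, assuming the `onnx` Python package is installed and that it runs next to the exported model file:
``` python
import onnx

model = onnx.load("ppyoloe_crn_s_300e_coco.onnx")
for node in model.graph.node:
    if node.op_type == "NonMaxSuppression":
        # Inputs 0 and 1 of NonMaxSuppression are boxes and scores; the
        # tensor names printed here are the output_names to pass to
        # model_transform.py.
        print("boxes :", node.input[0])
        print("scores:", node.input[1])
```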
``` bash
# Convert the mlir model into an F32 bmodel for BM1684x
model_deploy.py \
--mlir ppyoloe_crn_s_300e_coco.mlir \
--quantize F32 \
--chip bm1684x \
--test_input ppyoloe_crn_s_300e_coco_in_f32.npz \
--test_reference ppyoloe_crn_s_300e_coco_top_outputs.npz \
--model ppyoloe_crn_s_300e_coco_1684x_f32.bmodel
```
This finally produces ppyoloe_crn_s_300e_coco_1684x_f32.bmodel, a bmodel that can run on the BM1684x. If you need to accelerate the model further, the ONNX model can be converted into an INT8 bmodel; for the detailed steps, see the [TPU-MLIR documentation](https://github.com/sophgo/tpu-mlir/blob/master/README.md).
## Other Links
- [C++ Deployment](./cpp)
- [Python Deployment](./python)


@@ -0,0 +1,19 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
# Specify the path of the downloaded and extracted FastDeploy SDK
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
set(ENABLE_LITE_BACKEND OFF)
#set(FDLIB ${FASTDEPLOY_INSTALL_DIR})
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# Add FastDeploy dependency headers
include_directories(${FASTDEPLOY_INCS})
include_directories(${FastDeploy_INCLUDE_DIRS})
add_executable(infer_ppyoloe ${PROJECT_SOURCE_DIR}/infer_ppyoloe.cc)
add_executable(infer_picodet ${PROJECT_SOURCE_DIR}/infer_picodet.cc)
# Link against the FastDeploy libraries
target_link_libraries(infer_ppyoloe ${FASTDEPLOY_LIBS})
target_link_libraries(infer_picodet ${FASTDEPLOY_LIBS})


@@ -0,0 +1,61 @@
# PaddleDetection C++ Deployment Example
This directory provides `infer_ppyoloe.cc` and `infer_picodet.cc`, examples that quickly deploy the PP-YOLOE and PicoDet models with acceleration on a SOPHGO BM1684x board.
Before deployment, confirm the following two steps:
1. The hardware and software environment meets the requirements
2. FastDeploy has been compiled from source for your development environment
For both steps, refer to [Compiling the SOPHGO Deployment Library](../../../../../../docs/cn/build_and_install/sophgo.md).
## Generate the Basic Directory Layout
This example consists of the following parts:
```text
.
├── CMakeLists.txt
├── build # build directory
├── images # directory for test images
├── infer_ppyoloe.cc
├── infer_picodet.cc
└── model # directory for model files
```
## Compilation
### Compile and Copy the SDK to the thirdpartys Directory
Refer to [Compiling the SOPHGO Deployment Library](../../../../../../docs/cn/build_and_install/sophgo.md) to build the SDK. After compilation, the fastdeploy-0.0.3 directory will be generated under the build directory.
### Copy the Model Files and Configuration File to the model Directory
Convert the Paddle model into a SOPHGO bmodel; for the conversion steps, see the [documentation](../README.md).
Copy the converted SOPHGO bmodel file into model.
### Prepare a Test Image in the images Directory
```bash
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
cp 000000014439.jpg ./images
```
### Compile the Example
```bash
cd build
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-0.0.3
make
```
## Run the Example
```bash
# PP-YOLOE inference example
./infer_ppyoloe model images/000000014439.jpg
# PicoDet inference example
./infer_picodet model images/000000014439.jpg
```
- [Model Introduction](../../)
- [Model Conversion](../)


@@ -0,0 +1,60 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sys/time.h>
#include <iostream>
#include <string>
#include "fastdeploy/vision.h"
void SophgoInfer(const std::string& model_dir, const std::string& image_file) {
auto model_file = model_dir + "/picodet_s_416_coco_lcnet_1684x_f32.bmodel";
auto params_file = "";
auto config_file = model_dir + "/infer_cfg.yml";
auto option = fastdeploy::RuntimeOption();
option.UseSophgo();
auto format = fastdeploy::ModelFormat::SOPHGO;
auto model = fastdeploy::vision::detection::PicoDet(
model_file, params_file, config_file, option, format);
model.GetPostprocessor().ApplyDecodeAndNMS();
auto im = cv::imread(image_file);
fastdeploy::vision::DetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::VisDetection(im, res, 0.5);
cv::imwrite("infer_sophgo.jpg", vis_im);
std::cout << "Visualized result saved in ./infer_sophgo.jpg" << std::endl;
}
int main(int argc, char* argv[]) {
if (argc < 3) {
std::cout
<< "Usage: infer_demo path/to/model_dir path/to/image run_option, "
"e.g ./infer_model ./picodet_model_dir ./test.jpeg"
<< std::endl;
return -1;
}
SophgoInfer(argv[1], argv[2]);
return 0;
}


@@ -0,0 +1,60 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sys/time.h>
#include <iostream>
#include <string>
#include "fastdeploy/vision.h"
void SophgoInfer(const std::string& model_dir, const std::string& image_file) {
auto model_file = model_dir + "/ppyoloe_crn_s_300e_coco_1684x_f32.bmodel";
auto params_file = "";
auto config_file = model_dir + "/infer_cfg.yml";
auto option = fastdeploy::RuntimeOption();
option.UseSophgo();
auto format = fastdeploy::ModelFormat::SOPHGO;
auto model = fastdeploy::vision::detection::PPYOLOE(
model_file, params_file, config_file, option, format);
model.GetPostprocessor().ApplyDecodeAndNMS();
auto im = cv::imread(image_file);
fastdeploy::vision::DetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::VisDetection(im, res, 0.5);
cv::imwrite("infer_sophgo.jpg", vis_im);
std::cout << "Visualized result saved in ./infer_sophgo.jpg" << std::endl;
}
int main(int argc, char* argv[]) {
if (argc < 3) {
std::cout
<< "Usage: infer_demo path/to/model_dir path/to/image run_option, "
"e.g ./infer_model ./picodet_model_dir ./test.jpeg"
<< std::endl;
return -1;
}
SophgoInfer(argv[1], argv[2]);
return 0;
}


@@ -0,0 +1,32 @@
# PaddleDetection Python Deployment Example
Before deployment, confirm the following step:
- 1. The hardware and software environment meets the requirements; refer to [FastDeploy Environment Requirements](../../../../../../docs/cn/build_and_install/sophgo.md)
This directory provides `infer_ppyoloe.py` and `infer_picodet.py`, examples that quickly deploy PP-YOLOE and PicoDet on a SOPHGO TPU. Run the following script to complete the deployment:
```bash
# Download the deployment example code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/detection/paddledetection/sophgo/python
# Download a test image
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
# Run inference
# PP-YOLOE inference example
python3 infer_ppyoloe.py --model_file model/ppyoloe_crn_s_300e_coco_1684x_f32.bmodel --config_file model/infer_cfg.yml --image ./000000014439.jpg
# PicoDet inference example
python3 infer_picodet.py --model_file model/picodet_s_416_coco_lcnet_1684x_f32.bmodel --config_file model/infer_cfg.yml --image ./000000014439.jpg
# After the run completes, the detection result is printed,
# and the visualized result is saved in sophgo_result.jpg
```
The infer_cfg.yml passed via --config_file is the inference configuration file exported together with the Paddle model by PaddleDetection's tools/export_model.py.
## Other Documents
- [PP-YOLOE C++ Deployment](../cpp)
- [PicoDet C++ Deployment](../cpp)
- [PicoDet SOPHGO Model Conversion](../README.md)
- [PP-YOLOE SOPHGO Model Conversion](../README.md)


@@ -0,0 +1,59 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_file", required=True, help="Path of sophgo model.")
parser.add_argument("--config_file", required=True, help="Path of config.")
parser.add_argument(
"--image", type=str, required=True, help="Path of test image file.")
return parser.parse_args()
if __name__ == "__main__":
args = parse_arguments()
model_file = args.model_file
params_file = ""
config_file = args.config_file
# Configure the runtime and load the model
runtime_option = fd.RuntimeOption()
runtime_option.use_sophgo()
model = fd.vision.detection.PicoDet(
model_file,
params_file,
config_file,
runtime_option=runtime_option,
model_format=fd.ModelFormat.SOPHGO)
model.postprocessor.apply_decode_and_nms()
# Predict the detection result for the image
im = cv2.imread(args.image)
result = model.predict(im)
print(result)
# Visualize the result
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("sophgo_result.jpg", vis_im)
print("Visualized result save in ./sophgo_result.jpg")


@@ -0,0 +1,59 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_file", required=True, help="Path of sophgo model.")
parser.add_argument("--config_file", required=True, help="Path of config.")
parser.add_argument(
"--image", type=str, required=True, help="Path of test image file.")
return parser.parse_args()
if __name__ == "__main__":
args = parse_arguments()
model_file = args.model_file
params_file = ""
config_file = args.config_file
# Configure the runtime and load the model
runtime_option = fd.RuntimeOption()
runtime_option.use_sophgo()
model = fd.vision.detection.PPYOLOE(
model_file,
params_file,
config_file,
runtime_option=runtime_option,
model_format=fd.ModelFormat.SOPHGO)
model.postprocessor.apply_decode_and_nms()
# Predict the detection result for the image
im = cv2.imread(args.image)
result = model.predict(im)
print(result)
# Visualize the result
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("sophgo_result.jpg", vis_im)
print("Visualized result save in ./sophgo_result.jpg")


@@ -0,0 +1,88 @@
# PPOCRv3 SOPHGO C++ Deployment Example
## List of Supported Models
- The PP-OCRv3 deployment models come from the [PP-OCR series model list](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/models_list.md)
## Preparing and Converting the PPOCRv3 Deployment Models
PPOCRv3 consists of a text detection model (ch_PP-OCRv3_det), a direction classification model (ch_ppocr_mobile_v2.0_cls), and a text recognition model (ch_PP-OCRv3_rec).
Before deployment on the SOPHGO-TPU, the above Paddle models must be converted into bmodels. Taking the ch_PP-OCRv3_det model as an example, the steps are as follows:
- Download the Paddle model: [ch_PP-OCRv3_det](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar)
- To convert the Paddle model to ONNX, refer to [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX)
- To convert the ONNX model to a bmodel, refer to [TPU-MLIR](https://github.com/sophgo/tpu-mlir)
## Model Conversion Example
### Download the ch_PP-OCRv3_det Model and Convert It to ONNX
```shell
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
tar xvf ch_PP-OCRv3_det_infer.tar
# Change the input shape of the ch_PP-OCRv3_det model from dynamic to fixed
python paddle_infer_shape.py --model_dir ch_PP-OCRv3_det_infer \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--save_dir ch_PP-OCRv3_det_infer_fix \
--input_shape_dict="{'x':[1,3,960,608]}"
# Convert the fixed-shape Paddle model to ONNX
paddle2onnx --model_dir ch_PP-OCRv3_det_infer_fix \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--save_file ch_PP-OCRv3_det_infer_fix.onnx \
--enable_dev_version True
```
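Optionally, verify that the export really produced a fixed input shape before moving on. The snippet below is a minimal sketch, assuming the `onnx` Python package is installed and the file above sits in the current directory:
``` python
import onnx

# Inspect the graph inputs of the fixed-shape export
model = onnx.load("ch_PP-OCRv3_det_infer_fix.onnx")
for inp in model.graph.input:
    # dim_value is 0 for a dynamic dimension, so a fully fixed input
    # should print concrete numbers only, e.g. x [1, 3, 960, 608]
    dims = [d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)
```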
### Export the bmodel
Taking the conversion to a BM1684x bmodel as an example, download the [TPU-MLIR](https://github.com/sophgo/tpu-mlir) project; for the installation procedure, see the [TPU-MLIR documentation](https://github.com/sophgo/tpu-mlir/blob/master/README.md).
### 1. Installation
``` shell
docker pull sophgo/tpuc_dev:latest
# myname1234 is just an example; any other container name works
docker run --privileged --name myname1234 -v $PWD:/workspace -it sophgo/tpuc_dev:latest
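# Note: per the TPU-MLIR README, run the next two commands inside the
# container, from the tpu-mlir source directory.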
source ./envsetup.sh
./build.sh
```
### 2. Convert the ONNX Model to a bmodel
``` shell
mkdir ch_PP-OCRv3_det && cd ch_PP-OCRv3_det
# Put test images in this directory, together with the ch_PP-OCRv3_det_infer_fix.onnx converted in the previous step
cp -rf ${REGRESSION_PATH}/dataset/COCO2017 .
cp -rf ${REGRESSION_PATH}/image .
# Put the ONNX model file ch_PP-OCRv3_det_infer_fix.onnx here
mkdir workspace && cd workspace
# Convert the ONNX model to an mlir model; the --output_names argument can be found with Netron
model_transform.py \
--model_name ch_PP-OCRv3_det \
--model_def ../ch_PP-OCRv3_det_infer_fix.onnx \
--input_shapes [[1,3,960,608]] \
--mean 0.0,0.0,0.0 \
--scale 0.0039216,0.0039216,0.0039216 \
--keep_aspect_ratio \
--pixel_format rgb \
--output_names sigmoid_0.tmp_0 \
--test_input ../image/dog.jpg \
--test_result ch_PP-OCRv3_det_top_outputs.npz \
--mlir ch_PP-OCRv3_det.mlir
# Convert the mlir model into an F32 bmodel for BM1684x
model_deploy.py \
--mlir ch_PP-OCRv3_det.mlir \
--quantize F32 \
--chip bm1684x \
--test_input ch_PP-OCRv3_det_in_f32.npz \
--test_reference ch_PP-OCRv3_det_top_outputs.npz \
--model ch_PP-OCRv3_det_1684x_f32.bmodel
```
This finally produces ch_PP-OCRv3_det_1684x_f32.bmodel, a bmodel that can run on the BM1684x. Following the same steps, ch_ppocr_mobile_v2.0_cls and ch_PP-OCRv3_rec can also be converted into bmodels. If you need to accelerate the models further, the ONNX models can be converted into INT8 bmodels; for the detailed steps, see the [TPU-MLIR documentation](https://github.com/sophgo/tpu-mlir/blob/master/README.md).
## Other Links
- [C++ Deployment](./cpp)


@@ -0,0 +1,13 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
# Specify the path of the downloaded and extracted FastDeploy SDK
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# Add FastDeploy dependency headers
include_directories(${FASTDEPLOY_INCS})
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
# Link against the FastDeploy libraries
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})


@@ -0,0 +1,58 @@
# PPOCRv3 C++ Deployment Example
This directory provides `infer.cc`, an example that quickly deploys the PPOCRv3 model with acceleration on a SOPHGO BM1684x board.
Before deployment, confirm the following two steps:
1. The hardware and software environment meets the requirements
2. FastDeploy has been compiled from source for your development environment
For both steps, refer to [Compiling the SOPHGO Deployment Library](../../../../../../docs/cn/build_and_install/sophgo.md).
## Generate the Basic Directory Layout
This example consists of the following parts:
```text
.
├── CMakeLists.txt
├── build # build directory
├── image # directory for test images
├── infer.cc
└── model # directory for model files
```
## Compilation
### Compile and Copy the SDK to the thirdpartys Directory
Refer to [Compiling the SOPHGO Deployment Library](../../../../../../docs/cn/build_and_install/sophgo.md) to build the SDK. After compilation, the fastdeploy-0.0.3 directory will be generated under the build directory.
### Copy the bmodel Files to the model Directory
Convert the Paddle models into SOPHGO bmodels; for the conversion steps, see the [documentation](../README.md).
Copy the converted SOPHGO bmodel files into model.
### Prepare a Test Image in the image Directory, and the Dictionary File
```bash
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg
cp 12.jpg image/
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
```
### Compile the Example
```bash
cd build
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-0.0.3
make
```
## Run the Example
```bash
./infer_demo model ./ppocr_keys_v1.txt image/12.jpg
```
- [Model Introduction](../../../)
- [Model Conversion](../)


@@ -0,0 +1,136 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void InitAndInfer(const std::string& det_model_dir,
const std::string& rec_label_file,
const std::string& image_file,
const fastdeploy::RuntimeOption& option) {
auto det_model_file =
det_model_dir + sep + "ch_PP-OCRv3_det_1684x_f32.bmodel";
auto det_params_file = det_model_dir + sep + "";
auto cls_model_file =
det_model_dir + sep + "ch_ppocr_mobile_v2.0_cls_1684x_f32.bmodel";
auto cls_params_file = det_model_dir + sep + "";
auto rec_model_file =
det_model_dir + sep + "ch_PP-OCRv3_rec_1684x_f32.bmodel";
auto rec_params_file = det_model_dir + sep + "";
auto format = fastdeploy::ModelFormat::SOPHGO;
auto det_option = option;
auto cls_option = option;
auto rec_option = option;
// The cls and rec models can now infer a batch of images.
// Users could initialize the inference batch size and set it after creating
// the PP-OCR pipeline.
int cls_batch_size = 1;
int rec_batch_size = 1;
// If the TRT backend is used, the dynamic shapes will be set as follows.
// We recommend that users set the width and height of the detection model to
// a multiple of 32. We also recommend setting the TRT input shapes as
// follows.
det_option.SetTrtInputShape("x", {1, 3, 64, 64}, {1, 3, 640, 640},
{1, 3, 960, 960});
cls_option.SetTrtInputShape("x", {1, 3, 48, 10}, {cls_batch_size, 3, 48, 320},
{cls_batch_size, 3, 48, 1024});
rec_option.SetTrtInputShape("x", {1, 3, 48, 10}, {rec_batch_size, 3, 48, 320},
{rec_batch_size, 3, 48, 2304});
// Users could save the TRT cache file to disk as follows.
// det_option.SetTrtCacheFile(det_model_dir + sep + "det_trt_cache.trt");
// cls_option.SetTrtCacheFile(cls_model_dir + sep + "cls_trt_cache.trt");
// rec_option.SetTrtCacheFile(rec_model_dir + sep + "rec_trt_cache.trt");
auto det_model = fastdeploy::vision::ocr::DBDetector(
det_model_file, det_params_file, det_option, format);
auto cls_model = fastdeploy::vision::ocr::Classifier(
cls_model_file, cls_params_file, cls_option, format);
auto rec_model = fastdeploy::vision::ocr::Recognizer(
rec_model_file, rec_params_file, rec_label_file, rec_option, format);
// Users could enable static shape inference for the rec model when deploying
// PP-OCR on hardware that cannot support dynamic shape inference well, like
// the Huawei Ascend series.
rec_model.GetPreprocessor().SetStaticShapeInfer(true);
rec_model.GetPreprocessor().SetRecImageShape({3, 48, 584});
assert(det_model.Initialized());
assert(cls_model.Initialized());
assert(rec_model.Initialized());
// The classification model is optional, so the PP-OCR pipeline can also be
// connected in series without it, as follows:
// auto ppocr_v3 = fastdeploy::pipeline::PPOCRv3(&det_model, &rec_model);
auto ppocr_v3 =
fastdeploy::pipeline::PPOCRv3(&det_model, &cls_model, &rec_model);
// Set the inference batch size for the cls and rec models. The value can be
// -1 or any positive integer. When it is set to -1, the inference batch size
// of the cls and rec models equals the number of boxes detected by the det
// model.
ppocr_v3.SetClsBatchSize(cls_batch_size);
ppocr_v3.SetRecBatchSize(rec_batch_size);
if (!ppocr_v3.Initialized()) {
std::cerr << "Failed to initialize PP-OCR." << std::endl;
return;
}
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::OCRResult result;
if (!ppocr_v3.Predict(&im, &result)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << result.Str() << std::endl;
auto vis_im = fastdeploy::vision::VisOcr(im_bak, result);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout << "Usage: infer_demo path/to/model "
"path/to/rec_label_file path/to/image "
"e.g ./infer_demo ./ocr_bmodel "
"./ppocr_keys_v1.txt ./12.jpg"
<< std::endl;
return -1;
}
fastdeploy::RuntimeOption option;
option.UseSophgo();
option.UseSophgoBackend();
std::string model_dir = argv[1];
std::string rec_label_file = argv[2];
std::string test_image = argv[3];
InitAndInfer(model_dir, rec_label_file, test_image, option);
return 0;
}


@@ -0,0 +1,38 @@
# PPOCRv3 Python Deployment Example
Before deployment, confirm the following step:
- 1. The hardware and software environment meets the requirements; refer to [FastDeploy Environment Requirements](../../../../../../docs/cn/build_and_install/sophgo.md)
This directory provides `infer.py`, an example that quickly deploys PPOCRv3 on a SOPHGO TPU. Run the following script to complete the deployment:
```bash
# Download the deployment example code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/ocr/PP-OCRv3/sophgo/python
# Download a test image
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg
# Download the dictionary file
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
# Run inference
python3 infer.py --det_model ocr_bmodel/ch_PP-OCRv3_det_1684x_f32.bmodel \
--cls_model ocr_bmodel/ch_ppocr_mobile_v2.0_cls_1684x_f32.bmodel \
--rec_model ocr_bmodel/ch_PP-OCRv3_rec_1684x_f32.bmodel \
--rec_label_file ../ppocr_keys_v1.txt \
--image ../12.jpg
# After the run completes, the result is as follows
det boxes: [[42,413],[483,391],[484,428],[43,450]]rec text: 上海斯格威铂尔大酒店 rec score:0.952958 cls label: 0 cls score: 1.000000
det boxes: [[187,456],[399,448],[400,480],[188,488]]rec text: 打浦路15号 rec score:0.897335 cls label: 0 cls score: 1.000000
det boxes: [[23,507],[513,488],[515,529],[24,548]]rec text: 绿洲仕格维花园公寓 rec score:0.994589 cls label: 0 cls score: 1.000000
det boxes: [[74,553],[427,542],[428,571],[75,582]]rec text: 打浦路252935号 rec score:0.900663 cls label: 0 cls score: 1.000000
# The visualized result is saved in sophgo_result.jpg
```
## Other Documents
- [PPOCRv3 C++ Deployment](../cpp)
- [PPOCRv3 SOPHGO Model Conversion](../README.md)


@@ -0,0 +1,116 @@
import fastdeploy as fd
import cv2
import os
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--det_model", required=True, help="Path of Detection model of PPOCR.")
parser.add_argument(
"--cls_model",
required=True,
help="Path of Classification model of PPOCR.")
parser.add_argument(
"--rec_model",
required=True,
help="Path of Recognization model of PPOCR.")
parser.add_argument(
"--rec_label_file",
required=True,
help="Path of Recognization label of PPOCR.")
parser.add_argument(
"--image", type=str, required=True, help="Path of test image file.")
return parser.parse_args()
args = parse_arguments()
# Configure the runtime and load the models
runtime_option = fd.RuntimeOption()
runtime_option.use_sophgo()
# Detection model: detects text boxes
det_model_file = args.det_model
det_params_file = ""
# Classification model: text direction classification (optional)
cls_model_file = args.cls_model
cls_params_file = ""
# Recognition model: text recognition
rec_model_file = args.rec_model
rec_params_file = ""
rec_label_file = args.rec_label_file
# The cls and rec models of PPOCR now support batch inference.
# The two variables below are used to set the TRT input shapes and, after the
# PPOCR pipeline is created, to configure batch inference.
cls_batch_size = 1
rec_batch_size = 1
# When TRT is used, set a dynamic shape on each of the three models' runtime
# options and then create the models.
# Note: set the classification model's dynamic input and create it only after
# the detection model has been created; the same applies to the recognition model.
# If you want to change the detection model's input shape yourself, we recommend
# setting its height and width to multiples of 32.
det_option = runtime_option
det_option.set_trt_input_shape("x", [1, 3, 64, 64], [1, 3, 640, 640],
[1, 3, 960, 960])
# The TRT engine file can be saved locally
# det_option.set_trt_cache_file(args.det_model + "/det_trt_cache.trt")
det_model = fd.vision.ocr.DBDetector(
det_model_file,
det_params_file,
runtime_option=det_option,
model_format=fd.ModelFormat.SOPHGO)
cls_option = runtime_option
cls_option.set_trt_input_shape("x", [1, 3, 48, 10],
[cls_batch_size, 3, 48, 320],
[cls_batch_size, 3, 48, 1024])
# The TRT engine file can be saved locally
# cls_option.set_trt_cache_file(args.cls_model + "/cls_trt_cache.trt")
cls_model = fd.vision.ocr.Classifier(
cls_model_file,
cls_params_file,
runtime_option=cls_option,
model_format=fd.ModelFormat.SOPHGO)
rec_option = runtime_option
rec_option.set_trt_input_shape("x", [1, 3, 48, 10],
[rec_batch_size, 3, 48, 320],
[rec_batch_size, 3, 48, 2304])
# The TRT engine file can be saved locally
# rec_option.set_trt_cache_file(args.rec_model + "/rec_trt_cache.trt")
rec_model = fd.vision.ocr.Recognizer(
rec_model_file,
rec_params_file,
rec_label_file,
runtime_option=rec_option,
model_format=fd.ModelFormat.SOPHGO)
# Create the PP-OCR pipeline, chaining the three models; cls_model is optional
# and can be set to None if not needed
ppocr_v3 = fd.vision.ocr.PPOCRv3(
det_model=det_model, cls_model=cls_model, rec_model=rec_model)
# The following two lines enable static shape inference for the rec model;
# its static input shape here is [3, 48, 584]
rec_model.preprocessor.static_shape_infer = True
rec_model.preprocessor.rec_image_shape = [3, 48, 584]
# Set the inference batch size for the cls and rec models.
# The value can be -1, or any positive integer.
# When it is -1, the batch size of the cls and rec models defaults to the
# number of boxes detected by the det model.
ppocr_v3.cls_batch_size = cls_batch_size
ppocr_v3.rec_batch_size = rec_batch_size
# Read the image to predict
im = cv2.imread(args.image)
# Run prediction and print the result
result = ppocr_v3.predict(im)
print(result)
# Visualize the result
vis_im = fd.vision.vis_ppocr(im, result)
cv2.imwrite("sophgo_result.jpg", vis_im)
print("Visualized result save in ./sophgo_result.jpg")


@@ -71,8 +71,8 @@ bool FastDeployModel::InitRuntimeWithSpecifiedBackend() {
}
} else if (use_sophgotpu) {
if (!IsSupported(valid_sophgonpu_backends, runtime_option.backend)) {
FDERROR << "The valid rknpu backends of model " << ModelName() << " are "
<< Str(valid_rknpu_backends) << ", " << runtime_option.backend
FDERROR << "The valid sophgo backends of model " << ModelName() << " are "
<< Str(valid_sophgonpu_backends) << ", " << runtime_option.backend
<< " is not supported." << std::endl;
return false;
}


@@ -180,7 +180,7 @@ bool SophgoBackend::Infer(std::vector<FDTensor>& inputs,
assert(BM_SUCCESS == status);
input_tensors[i].dtype = input_dtypes[i];
input_tensors[i].st_mode = BM_STORE_1N;
-    input_tensors[i].shape = *(net_info_->stages[i].input_shapes);
+    input_tensors[i].shape = net_info_->stages[0].input_shapes[i];
unsigned int input_byte = bmrt_tensor_bytesize(&input_tensors[i]);
bm_memcpy_s2d_partial(handle_, input_tensors[i].device_mem,
(void*)inputs[i].Data(),


@@ -71,6 +71,7 @@ class FASTDEPLOY_DECL PPYOLOE : public PPDetBase {
valid_kunlunxin_backends = {Backend::LITE};
valid_rknpu_backends = {Backend::RKNPU2};
valid_ascend_backends = {Backend::LITE};
+    valid_sophgonpu_backends = {Backend::SOPHGOTPU};
initialized = Initialize();
}
@@ -274,7 +275,8 @@ class FASTDEPLOY_DECL CascadeRCNN : public PPDetBase {
initialized = Initialize();
}
-  virtual std::string ModelName() const { return "PaddleDetection/CascadeRCNN"; }
+  virtual std::string ModelName() const {
+    return "PaddleDetection/CascadeRCNN"; }
};
class FASTDEPLOY_DECL PSSDet : public PPDetBase {


@@ -13,6 +13,7 @@
// limitations under the License.
#include "fastdeploy/vision/detection/ppdet/postprocessor.h"
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
#include "fastdeploy/vision/utils/utils.h"
@@ -202,10 +203,7 @@ bool PaddleDetPostprocessor::ProcessUnDecodeResults(
static_cast<int32_t>(round(ptr[j * 6])));
(*results)[i].scores.push_back(ptr[j * 6 + 1]);
(*results)[i].boxes.emplace_back(std::array<float, 4>(
-          {ptr[j * 6 + 2] / GetScaleFactor()[1],
-           ptr[j * 6 + 3] / GetScaleFactor()[0],
-           ptr[j * 6 + 4] / GetScaleFactor()[1],
-           ptr[j * 6 + 5] / GetScaleFactor()[0]}));
+          {ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]}));
}
offset += (num_boxes[i] * 6);
}


@@ -34,6 +34,7 @@ Classifier::Classifier(const std::string& model_file,
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
valid_kunlunxin_backends = {Backend::LITE};
valid_ascend_backends = {Backend::LITE};
+    valid_sophgonpu_backends = {Backend::SOPHGOTPU};
}
runtime_option = custom_option;
runtime_option.model_format = model_format;


@@ -34,6 +34,7 @@ DBDetector::DBDetector(const std::string& model_file,
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
valid_kunlunxin_backends = {Backend::LITE};
valid_ascend_backends = {Backend::LITE};
+    valid_sophgonpu_backends = {Backend::SOPHGOTPU};
}
runtime_option = custom_option;


@@ -36,6 +36,7 @@ Recognizer::Recognizer(const std::string& model_file,
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
valid_kunlunxin_backends = {Backend::LITE};
valid_ascend_backends = {Backend::LITE};
+    valid_sophgonpu_backends = {Backend::SOPHGOTPU};
}
runtime_option = custom_option;