mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
add ocr, ppyoloe, picodet examples (#1076)
* add ocr examples * add ppyoloe examples add picodet examples * remove /ScaleFactor in ppdet/postprocessor.cc
This commit is contained in:
107
examples/vision/detection/paddledetection/sophgo/README.md
Normal file
107
examples/vision/detection/paddledetection/sophgo/README.md
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
# PaddleDetection SOPHGO部署示例
|
||||||
|
|
||||||
|
## 支持模型列表
|
||||||
|
|
||||||
|
目前SOPHGO支持如下模型的部署
|
||||||
|
- [PP-YOLOE系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)
|
||||||
|
- [PicoDet系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet)
|
||||||
|
|
||||||
|
## 准备PP-YOLOE或者PicoDet部署模型以及转换模型
|
||||||
|
|
||||||
|
SOPHGO-TPU部署模型前需要将Paddle模型转换成bmodel模型,具体步骤如下:
|
||||||
|
- Paddle动态图模型转换为ONNX模型,请参考[PaddleDetection导出模型](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/EXPORT_MODEL.md).
|
||||||
|
- ONNX模型转换bmodel模型的过程,请参考[TPU-MLIR](https://github.com/sophgo/tpu-mlir)
|
||||||
|
|
||||||
|
## 模型转换example
|
||||||
|
|
||||||
|
PP-YOLOE和PicoDet模型转换过程类似,下面以ppyoloe_crn_s_300e_coco为例子,教大家如何转换Paddle模型到SOPHGO-TPU模型
|
||||||
|
|
||||||
|
### 导出ONNX模型
|
||||||
|
```shell
|
||||||
|
#导出paddle模型
|
||||||
|
python tools/export_model.py -c configs/ppyoloe/ppyoloe_crn_s_300e_coco.yml --output_dir=output_inference -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco.pdparams
|
||||||
|
|
||||||
|
#paddle模型转ONNX模型
|
||||||
|
paddle2onnx --model_dir ppyoloe_crn_s_300e_coco \
|
||||||
|
--model_filename model.pdmodel \
|
||||||
|
--params_filename model.pdiparams \
|
||||||
|
--save_file ppyoloe_crn_s_300e_coco.onnx \
|
||||||
|
--enable_dev_version True
|
||||||
|
|
||||||
|
#进入Paddle2ONNX文件夹,固定ONNX模型shape
|
||||||
|
python -m paddle2onnx.optimize --input_model ppyoloe_crn_s_300e_coco.onnx \
|
||||||
|
--output_model ppyoloe_crn_s_300e_coco.onnx \
|
||||||
|
--input_shape_dict "{'image':[1,3,640,640]}"
|
||||||
|
|
||||||
|
```
|
||||||
|
### 导出bmodel模型
|
||||||
|
|
||||||
|
以转化BM1684x的bmodel模型为例子,我们需要下载[TPU-MLIR](https://github.com/sophgo/tpu-mlir)工程,安装过程具体参见[TPU-MLIR文档](https://github.com/sophgo/tpu-mlir/blob/master/README.md)。
|
||||||
|
### 1. 安装
|
||||||
|
``` shell
|
||||||
|
docker pull sophgo/tpuc_dev:latest
|
||||||
|
|
||||||
|
# myname1234是一个示例,也可以设置其他名字
|
||||||
|
docker run --privileged --name myname1234 -v $PWD:/workspace -it sophgo/tpuc_dev:latest
|
||||||
|
|
||||||
|
source ./envsetup.sh
|
||||||
|
./build.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. ONNX模型转换为bmodel模型
|
||||||
|
``` shell
|
||||||
|
mkdir ppyoloe_crn_s_300e_coco && cd ppyoloe_crn_s_300e_coco
|
||||||
|
|
||||||
|
# 下载测试图片,并将图片转换为npz格式
|
||||||
|
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
|
||||||
|
|
||||||
|
#使用python获得模型转换所需要的npz文件
|
||||||
|
im = cv2.imread('000000014439.jpg')
|
||||||
|
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
|
||||||
|
#[640 640]为ppyoloe_crn_s_300e_coco的输入大小
|
||||||
|
im_scale_y = 640 / float(im.shape[0])
|
||||||
|
im_scale_x = 640 / float(im.shape[1])
|
||||||
|
inputs = {}
|
||||||
|
inputs['image'] = np.array((im, )).astype('float32')
|
||||||
|
inputs['scale_factor'] = np.array([im_scale_y, im_scale_x]).astype('float32')
|
||||||
|
np.savez('inputs.npz', image = inputs['image'], scale_factor = inputs['scale_factor'])
|
||||||
|
|
||||||
|
#放入onnx模型文件ppyoloe_crn_s_300e_coco.onnx
|
||||||
|
|
||||||
|
mkdir workspace && cd workspace
|
||||||
|
|
||||||
|
# 将ONNX模型转换为mlir模型
|
||||||
|
model_transform.py \
|
||||||
|
--model_name ppyoloe_crn_s_300e_coco \
|
||||||
|
--model_def ../ppyoloe_crn_s_300e_coco.onnx \
|
||||||
|
--input_shapes [[1,3,640,640],[1,2]] \
|
||||||
|
--keep_aspect_ratio \
|
||||||
|
--pixel_format rgb \
|
||||||
|
--output_names p2o.Div.1,p2o.Concat.29 \
|
||||||
|
--test_input ../inputs.npz \
|
||||||
|
--test_result ppyoloe_crn_s_300e_coco_top_outputs.npz \
|
||||||
|
--mlir ppyoloe_crn_s_300e_coco.mlir
|
||||||
|
```
|
||||||
|
### 注意
|
||||||
|
**由于TPU-MLIR当前不支持后处理算法,所以需要查看后处理的输入作为网络的输出**
|
||||||
|
具体方法为:output_names需要通过[Netron](https://netron.app/)查看,在网页中打开需要转换的ONNX模型,搜索NonMaxSuppression节点
|
||||||
|
查看INPUTS中boxes和scores的名字,这两个名字就是我们所需的output_names
|
||||||
|
例如使用Netron可视化后,可以得到如下图片
|
||||||
|

|
||||||
|
找到蓝色方框标记的NonMaxSuppression节点,可以看到红色方框标记的两个节点名称为p2o.Div.1,p2o.Concat.29
|
||||||
|
|
||||||
|
``` bash
|
||||||
|
# 将mlir模型转换为BM1684x的F32 bmodel模型
|
||||||
|
model_deploy.py \
|
||||||
|
--mlir ppyoloe_crn_s_300e_coco.mlir \
|
||||||
|
--quantize F32 \
|
||||||
|
--chip bm1684x \
|
||||||
|
--test_input ppyoloe_crn_s_300e_coco_in_f32.npz \
|
||||||
|
--test_reference ppyoloe_crn_s_300e_coco_top_outputs.npz \
|
||||||
|
--model ppyoloe_crn_s_300e_coco_1684x_f32.bmodel
|
||||||
|
```
|
||||||
|
最终获得可以在BM1684x上能够运行的bmodel模型ppyoloe_crn_s_300e_coco_1684x_f32.bmodel。如果需要进一步对模型进行加速,可以将ONNX模型转换为INT8 bmodel,具体步骤参见[TPU-MLIR文档](https://github.com/sophgo/tpu-mlir/blob/master/README.md)。
|
||||||
|
|
||||||
|
## 其他链接
|
||||||
|
- [Cpp部署](./cpp)
|
||||||
|
- [python部署](./python)
|
@@ -0,0 +1,19 @@
|
|||||||
|
# cmake_minimum_required() must come before project() so that policy defaults
# are established before the project and its languages are configured.
cmake_minimum_required(VERSION 3.10)
project(infer_demo C CXX)

# Path of the downloaded and extracted FastDeploy SDK.
# option() is meant for booleans only, so the path is declared as a PATH cache
# entry; a value passed with -DFASTDEPLOY_INSTALL_DIR=... is still honoured.
set(FASTDEPLOY_INSTALL_DIR "" CACHE PATH "Path of downloaded fastdeploy sdk.")
if(NOT FASTDEPLOY_INSTALL_DIR)
  message(FATAL_ERROR
    "FASTDEPLOY_INSTALL_DIR is not set. "
    "Pass -DFASTDEPLOY_INSTALL_DIR=/path/to/fastdeploy-sdk to cmake.")
endif()

set(ENABLE_LITE_BACKEND OFF)

# Expected to define FASTDEPLOY_INCS and FASTDEPLOY_LIBS, which are used below.
include("${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake")

add_executable(infer_ppyoloe "${PROJECT_SOURCE_DIR}/infer_ppyoloe.cc")
add_executable(infer_picodet "${PROJECT_SOURCE_DIR}/infer_picodet.cc")

# Attach the FastDeploy headers and libraries to each example target instead of
# leaking them into the whole directory scope.
# NOTE(review): FastDeploy_INCLUDE_DIRS is kept from the original file; it may
# be empty depending on how the SDK was located - confirm whether it is needed.
foreach(example_target IN ITEMS infer_ppyoloe infer_picodet)
  target_include_directories(${example_target} PRIVATE
    ${FASTDEPLOY_INCS}
    ${FastDeploy_INCLUDE_DIRS}
  )
  target_link_libraries(${example_target} PRIVATE ${FASTDEPLOY_LIBS})
endforeach()
|
@@ -0,0 +1,61 @@
|
|||||||
|
# PaddleDetection C++部署示例
|
||||||
|
|
||||||
|
本目录下提供`infer_ppyoloe.cc`和`infer_picodet.cc`快速完成PP-YOLOE模型和PicoDet模型在SOPHGO BM1684x板子上加速部署的示例。
|
||||||
|
|
||||||
|
在部署前,需确认以下两个步骤:
|
||||||
|
|
||||||
|
1. 软硬件环境满足要求
|
||||||
|
2. 根据开发环境,从头编译FastDeploy仓库
|
||||||
|
|
||||||
|
以上步骤请参考[SOPHGO部署库编译](../../../../../../docs/cn/build_and_install/sophgo.md)实现
|
||||||
|
|
||||||
|
## 生成基本目录文件
|
||||||
|
|
||||||
|
该例程由以下几个部分组成
|
||||||
|
```text
|
||||||
|
.
|
||||||
|
├── CMakeLists.txt
|
||||||
|
├── build # 编译文件夹
|
||||||
|
├── image # 存放图片的文件夹
|
||||||
|
├── infer_ppyoloe.cc
|
||||||
|
├── infer_picodet.cc
|
||||||
|
└── model # 存放模型文件的文件夹
|
||||||
|
```
|
||||||
|
|
||||||
|
## 编译
|
||||||
|
|
||||||
|
### 编译并拷贝SDK到thirdpartys文件夹
|
||||||
|
|
||||||
|
请参考[SOPHGO部署库编译](../../../../../../docs/cn/build_and_install/sophgo.md)仓库编译SDK,编译完成后,将在build目录下生成fastdeploy-0.0.3目录.
|
||||||
|
|
||||||
|
### 拷贝模型文件,以及配置文件至model文件夹
|
||||||
|
将Paddle模型转换为SOPHGO bmodel模型,转换步骤参考[文档](../README.md)
|
||||||
|
将转换后的SOPHGO bmodel模型文件拷贝至model中
|
||||||
|
|
||||||
|
### 准备测试图片至image文件夹
|
||||||
|
```bash
|
||||||
|
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
|
||||||
|
mkdir -p images && cp 000000014439.jpg ./images/
|
||||||
|
```
|
||||||
|
|
||||||
|
### 编译example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd build
|
||||||
|
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-0.0.3
|
||||||
|
make
|
||||||
|
```
|
||||||
|
|
||||||
|
## 运行例程
|
||||||
|
|
||||||
|
```bash
|
||||||
|
#ppyoloe推理示例
|
||||||
|
./infer_ppyoloe model images/000000014439.jpg
|
||||||
|
|
||||||
|
#picodet推理示例
|
||||||
|
./infer_picodet model images/000000014439.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
- [模型介绍](../../)
|
||||||
|
- [模型转换](../)
|
@@ -0,0 +1,60 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
#include <sys/time.h>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "fastdeploy/vision.h"
|
||||||
|
|
||||||
|
void SophgoInfer(const std::string& model_dir, const std::string& image_file) {
|
||||||
|
auto model_file = model_dir + "/picodet_s_416_coco_lcnet_1684x_f32.bmodel";
|
||||||
|
auto params_file = "";
|
||||||
|
auto config_file = model_dir + "/infer_cfg.yml";
|
||||||
|
|
||||||
|
auto option = fastdeploy::RuntimeOption();
|
||||||
|
option.UseSophgo();
|
||||||
|
|
||||||
|
auto format = fastdeploy::ModelFormat::SOPHGO;
|
||||||
|
|
||||||
|
auto model = fastdeploy::vision::detection::PicoDet(
|
||||||
|
model_file, params_file, config_file, option, format);
|
||||||
|
|
||||||
|
model.GetPostprocessor().ApplyDecodeAndNMS();
|
||||||
|
|
||||||
|
auto im = cv::imread(image_file);
|
||||||
|
|
||||||
|
fastdeploy::vision::DetectionResult res;
|
||||||
|
if (!model.Predict(&im, &res)) {
|
||||||
|
std::cerr << "Failed to predict." << std::endl;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << res.Str() << std::endl;
|
||||||
|
auto vis_im = fastdeploy::vision::VisDetection(im, res, 0.5);
|
||||||
|
cv::imwrite("infer_sophgo.jpg", vis_im);
|
||||||
|
std::cout << "Visualized result saved in ./infer_sophgo.jpg" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Entry point: expects the model directory and the image path as the two
// command-line arguments and forwards them to SophgoInfer.
// Returns -1 when too few arguments are given, 0 otherwise.
int main(int argc, char* argv[]) {
  if (argc < 3) {
    // The usage text lists only the arguments this program actually reads.
    std::cout << "Usage: infer_picodet path/to/model_dir path/to/image, "
                 "e.g ./infer_picodet ./model ./test.jpeg"
              << std::endl;
    return -1;
  }
  SophgoInfer(argv[1], argv[2]);
  return 0;
}
|
@@ -0,0 +1,60 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
#include <sys/time.h>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "fastdeploy/vision.h"
|
||||||
|
|
||||||
|
void SophgoInfer(const std::string& model_dir, const std::string& image_file) {
|
||||||
|
auto model_file = model_dir + "/ppyoloe_crn_s_300e_coco_1684x_f32.bmodel";
|
||||||
|
auto params_file = "";
|
||||||
|
auto config_file = model_dir + "/infer_cfg.yml";
|
||||||
|
|
||||||
|
auto option = fastdeploy::RuntimeOption();
|
||||||
|
option.UseSophgo();
|
||||||
|
|
||||||
|
auto format = fastdeploy::ModelFormat::SOPHGO;
|
||||||
|
|
||||||
|
auto model = fastdeploy::vision::detection::PPYOLOE(
|
||||||
|
model_file, params_file, config_file, option, format);
|
||||||
|
|
||||||
|
model.GetPostprocessor().ApplyDecodeAndNMS();
|
||||||
|
|
||||||
|
auto im = cv::imread(image_file);
|
||||||
|
|
||||||
|
fastdeploy::vision::DetectionResult res;
|
||||||
|
if (!model.Predict(&im, &res)) {
|
||||||
|
std::cerr << "Failed to predict." << std::endl;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << res.Str() << std::endl;
|
||||||
|
auto vis_im = fastdeploy::vision::VisDetection(im, res, 0.5);
|
||||||
|
cv::imwrite("infer_sophgo.jpg", vis_im);
|
||||||
|
std::cout << "Visualized result saved in ./infer_sophgo.jpg" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Entry point: expects the model directory and the image path as the two
// command-line arguments and forwards them to SophgoInfer.
// Returns -1 when too few arguments are given, 0 otherwise.
int main(int argc, char* argv[]) {
  if (argc < 3) {
    // The usage text lists only the arguments this program actually reads.
    std::cout << "Usage: infer_ppyoloe path/to/model_dir path/to/image, "
                 "e.g ./infer_ppyoloe ./model ./test.jpeg"
              << std::endl;
    return -1;
  }
  SophgoInfer(argv[1], argv[2]);
  return 0;
}
|
@@ -0,0 +1,32 @@
|
|||||||
|
# PaddleDetection Python部署示例
|
||||||
|
|
||||||
|
在部署前,需确认以下步骤
|
||||||
|
|
||||||
|
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../../docs/cn/build_and_install/sophgo.md)
|
||||||
|
|
||||||
|
本目录下提供`infer_ppyoloe.py`和`infer_picodet.py`快速完成 PP-YOLOE 和 PicoDet 在SOPHGO TPU上部署的示例。执行如下脚本即可完成
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 下载部署示例代码
|
||||||
|
git clone https://github.com/PaddlePaddle/FastDeploy.git
|
||||||
|
cd FastDeploy/examples/vision/detection/paddledetection/sophgo/python
|
||||||
|
|
||||||
|
# 下载图片
|
||||||
|
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
|
||||||
|
|
||||||
|
# 推理
|
||||||
|
#ppyoloe推理示例
|
||||||
|
python3 infer_ppyoloe.py --model_file model/ppyoloe_crn_s_300e_coco_1684x_f32.bmodel --config_file model/infer_cfg.yml --image ./000000014439.jpg
|
||||||
|
|
||||||
|
#picodet推理示例
|
||||||
|
python3 infer_picodet.py --model_file model/picodet_s_416_coco_lcnet_1684x_f32.bmodel --config_file model/infer_cfg.yml --image ./000000014439.jpg
|
||||||
|
|
||||||
|
# 运行完成后返回结果如下所示
|
||||||
|
可视化结果存储在sophgo_result.jpg中
|
||||||
|
```
|
||||||
|
|
||||||
|
## 其它文档
|
||||||
|
- [PP-YOLOE C++部署](../cpp)
|
||||||
|
- [PicoDet C++部署](../cpp)
|
||||||
|
- [转换PicoDet SOPHGO模型文档](../README.md)
|
||||||
|
- [转换PP-YOLOE SOPHGO模型文档](../README.md)
|
@@ -0,0 +1,59 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import fastdeploy as fd
|
||||||
|
import cv2
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arguments():
    """Parse the command-line options of this example.

    Returns:
        argparse.Namespace with the attributes ``model_file``,
        ``config_file`` and ``image``; all three options are required.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_file", required=True, help="Path of sophgo model.")
    parser.add_argument("--config_file", required=True, help="Path of config.")
    parser.add_argument(
        "--image", type=str, required=True, help="Path of test image file.")
    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    args = parse_arguments()

    model_file = args.model_file
    # No separate params file is passed for the SOPHGO model format.
    params_file = ""
    config_file = args.config_file

    # Configure the runtime for SOPHGO and load the model.
    runtime_option = fd.RuntimeOption()
    runtime_option.use_sophgo()

    model = fd.vision.detection.PicoDet(
        model_file,
        params_file,
        config_file,
        runtime_option=runtime_option,
        model_format=fd.ModelFormat.SOPHGO)

    # Box decoding and NMS are applied in the postprocessor.
    model.postprocessor.apply_decode_and_nms()

    # Run detection on the input image.
    # cv2.imread returns None when the file is missing or unreadable.
    im = cv2.imread(args.image)
    if im is None:
        raise SystemExit("Failed to read image: {}".format(args.image))
    result = model.predict(im)
    print(result)

    # Visualize the result.
    vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
    cv2.imwrite("sophgo_result.jpg", vis_im)
    print("Visualized result save in ./sophgo_result.jpg")
|
@@ -0,0 +1,59 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import fastdeploy as fd
|
||||||
|
import cv2
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arguments():
    """Parse the command-line options of this example.

    Returns:
        argparse.Namespace with the attributes ``model_file``,
        ``config_file`` and ``image``; all three options are required.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_file", required=True, help="Path of sophgo model.")
    parser.add_argument("--config_file", required=True, help="Path of config.")
    parser.add_argument(
        "--image", type=str, required=True, help="Path of test image file.")
    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    args = parse_arguments()

    model_file = args.model_file
    # No separate params file is passed for the SOPHGO model format.
    params_file = ""
    config_file = args.config_file

    # Configure the runtime for SOPHGO and load the model.
    runtime_option = fd.RuntimeOption()
    runtime_option.use_sophgo()

    model = fd.vision.detection.PPYOLOE(
        model_file,
        params_file,
        config_file,
        runtime_option=runtime_option,
        model_format=fd.ModelFormat.SOPHGO)

    # Box decoding and NMS are applied in the postprocessor.
    model.postprocessor.apply_decode_and_nms()

    # Run detection on the input image.
    # cv2.imread returns None when the file is missing or unreadable.
    im = cv2.imread(args.image)
    if im is None:
        raise SystemExit("Failed to read image: {}".format(args.image))
    result = model.predict(im)
    print(result)

    # Visualize the result.
    vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
    cv2.imwrite("sophgo_result.jpg", vis_im)
    print("Visualized result save in ./sophgo_result.jpg")
|
88
examples/vision/ocr/PP-OCRv3/sophgo/README.md
Normal file
88
examples/vision/ocr/PP-OCRv3/sophgo/README.md
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
# PPOCRv3 SOPHGO C++部署示例
|
||||||
|
|
||||||
|
## 支持模型列表
|
||||||
|
|
||||||
|
- PP-OCRv3部署模型实现来自[PP-OCR系列模型列表](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/models_list.md)
|
||||||
|
|
||||||
|
## 准备PPOCRv3部署模型以及转换模型
|
||||||
|
|
||||||
|
PPOCRv3包括文本框检测模型(ch_PP-OCRv3_det)、方向分类模型(ch_ppocr_mobile_v2.0_cls)、文字识别模型(ch_PP-OCRv3_rec)
|
||||||
|
SOPHGO-TPU部署模型前需要将以上Paddle模型转换成bmodel模型,我们以ch_PP-OCRv3_det模型为例,具体步骤如下:
|
||||||
|
- 下载Paddle模型[ch_PP-OCRv3_det](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar)
|
||||||
|
- Paddle模型转换为ONNX模型,请参考[Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX)
|
||||||
|
- ONNX模型转换bmodel模型的过程,请参考[TPU-MLIR](https://github.com/sophgo/tpu-mlir)
|
||||||
|
|
||||||
|
## 模型转换example
|
||||||
|
|
||||||
|
### 下载ch_PP-OCRv3_det模型,并转换为ONNX模型
|
||||||
|
```shell
|
||||||
|
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
|
||||||
|
tar xvf ch_PP-OCRv3_det_infer.tar
|
||||||
|
|
||||||
|
# 修改ch_PP-OCRv3_det模型的输入shape,由动态输入变成固定输入
|
||||||
|
python paddle_infer_shape.py --model_dir ch_PP-OCRv3_det_infer \
|
||||||
|
--model_filename inference.pdmodel \
|
||||||
|
--params_filename inference.pdiparams \
|
||||||
|
--save_dir ch_PP-OCRv3_det_infer_fix \
|
||||||
|
--input_shape_dict="{'x':[1,3,960,608]}"
|
||||||
|
|
||||||
|
#将固定输入的Paddle模型转换成ONNX模型
|
||||||
|
paddle2onnx --model_dir ch_PP-OCRv3_det_infer_fix \
|
||||||
|
--model_filename inference.pdmodel \
|
||||||
|
--params_filename inference.pdiparams \
|
||||||
|
--save_file ch_PP-OCRv3_det_infer_fix.onnx \
|
||||||
|
--enable_dev_version True
|
||||||
|
```
|
||||||
|
|
||||||
|
### 导出bmodel模型
|
||||||
|
|
||||||
|
以转换BM1684x的bmodel模型为例子,我们需要下载[TPU-MLIR](https://github.com/sophgo/tpu-mlir)工程,安装过程具体参见[TPU-MLIR文档](https://github.com/sophgo/tpu-mlir/blob/master/README.md)。
|
||||||
|
### 1. 安装
|
||||||
|
``` shell
|
||||||
|
docker pull sophgo/tpuc_dev:latest
|
||||||
|
|
||||||
|
# myname1234是一个示例,也可以设置其他名字
|
||||||
|
docker run --privileged --name myname1234 -v $PWD:/workspace -it sophgo/tpuc_dev:latest
|
||||||
|
|
||||||
|
source ./envsetup.sh
|
||||||
|
./build.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. ONNX模型转换为bmodel模型
|
||||||
|
``` shell
|
||||||
|
mkdir ch_PP-OCRv3_det && cd ch_PP-OCRv3_det
|
||||||
|
|
||||||
|
#在该文件夹中放入测试图片,同时将上一步转换的ch_PP-OCRv3_det_infer_fix.onnx放入该文件夹中
|
||||||
|
cp -rf ${REGRESSION_PATH}/dataset/COCO2017 .
|
||||||
|
cp -rf ${REGRESSION_PATH}/image .
|
||||||
|
#放入onnx模型文件ch_PP-OCRv3_det_infer_fix.onnx
|
||||||
|
|
||||||
|
mkdir workspace && cd workspace
|
||||||
|
|
||||||
|
#将ONNX模型转换为mlir模型,其中参数--output_names可以通过NETRON查看
|
||||||
|
model_transform.py \
|
||||||
|
--model_name ch_PP-OCRv3_det \
|
||||||
|
--model_def ../ch_PP-OCRv3_det_infer_fix.onnx \
|
||||||
|
--input_shapes [[1,3,960,608]] \
|
||||||
|
--mean 0.0,0.0,0.0 \
|
||||||
|
--scale 0.0039216,0.0039216,0.0039216 \
|
||||||
|
--keep_aspect_ratio \
|
||||||
|
--pixel_format rgb \
|
||||||
|
--output_names sigmoid_0.tmp_0 \
|
||||||
|
--test_input ../image/dog.jpg \
|
||||||
|
--test_result ch_PP-OCRv3_det_top_outputs.npz \
|
||||||
|
--mlir ch_PP-OCRv3_det.mlir
|
||||||
|
|
||||||
|
#将mlir模型转换为BM1684x的F32 bmodel模型
|
||||||
|
model_deploy.py \
|
||||||
|
--mlir ch_PP-OCRv3_det.mlir \
|
||||||
|
--quantize F32 \
|
||||||
|
--chip bm1684x \
|
||||||
|
--test_input ch_PP-OCRv3_det_in_f32.npz \
|
||||||
|
--test_reference ch_PP-OCRv3_det_top_outputs.npz \
|
||||||
|
--model ch_PP-OCRv3_det_1684x_f32.bmodel
|
||||||
|
```
|
||||||
|
最终获得可以在BM1684x上能够运行的bmodel模型ch_PP-OCRv3_det_1684x_f32.bmodel。按照上面同样的方法,可以将ch_ppocr_mobile_v2.0_cls,ch_PP-OCRv3_rec转换为bmodel的格式。如果需要进一步对模型进行加速,可以将ONNX模型转换为INT8 bmodel,具体步骤参见[TPU-MLIR文档](https://github.com/sophgo/tpu-mlir/blob/master/README.md)。
|
||||||
|
|
||||||
|
## 其他链接
|
||||||
|
- [Cpp部署](./cpp)
|
13
examples/vision/ocr/PP-OCRv3/sophgo/cpp/CMakeLists.txt
Normal file
13
examples/vision/ocr/PP-OCRv3/sophgo/cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# cmake_minimum_required() must come before project() so that policy defaults
# are established before the project and its languages are configured.
cmake_minimum_required(VERSION 3.10)
project(infer_demo C CXX)

# Path of the downloaded and extracted FastDeploy SDK.
# option() is meant for booleans only, so the path is declared as a PATH cache
# entry; a value passed with -DFASTDEPLOY_INSTALL_DIR=... is still honoured.
set(FASTDEPLOY_INSTALL_DIR "" CACHE PATH "Path of downloaded fastdeploy sdk.")
if(NOT FASTDEPLOY_INSTALL_DIR)
  message(FATAL_ERROR
    "FASTDEPLOY_INSTALL_DIR is not set. "
    "Pass -DFASTDEPLOY_INSTALL_DIR=/path/to/fastdeploy-sdk to cmake.")
endif()

# Expected to define FASTDEPLOY_INCS and FASTDEPLOY_LIBS, which are used below.
include("${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake")

add_executable(infer_demo "${PROJECT_SOURCE_DIR}/infer.cc")

# Attach the FastDeploy headers and libraries to the target itself instead of
# leaking them into the whole directory scope.
target_include_directories(infer_demo PRIVATE ${FASTDEPLOY_INCS})
target_link_libraries(infer_demo PRIVATE ${FASTDEPLOY_LIBS})
|
58
examples/vision/ocr/PP-OCRv3/sophgo/cpp/README.md
Normal file
58
examples/vision/ocr/PP-OCRv3/sophgo/cpp/README.md
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
# PPOCRv3 C++部署示例
|
||||||
|
|
||||||
|
本目录下提供`infer.cc`快速完成PPOCRv3模型在SOPHGO BM1684x板子上加速部署的示例。
|
||||||
|
|
||||||
|
在部署前,需确认以下两个步骤:
|
||||||
|
|
||||||
|
1. 软硬件环境满足要求
|
||||||
|
2. 根据开发环境,从头编译FastDeploy仓库
|
||||||
|
|
||||||
|
以上步骤请参考[SOPHGO部署库编译](../../../../../../docs/cn/build_and_install/sophgo.md)实现
|
||||||
|
|
||||||
|
## 生成基本目录文件
|
||||||
|
|
||||||
|
该例程由以下几个部分组成
|
||||||
|
```text
|
||||||
|
.
|
||||||
|
├── CMakeLists.txt
|
||||||
|
├── build # 编译文件夹
|
||||||
|
├── image # 存放图片的文件夹
|
||||||
|
├── infer.cc
|
||||||
|
└── model # 存放模型文件的文件夹
|
||||||
|
```
|
||||||
|
|
||||||
|
## 编译
|
||||||
|
|
||||||
|
### 编译并拷贝SDK到thirdpartys文件夹
|
||||||
|
|
||||||
|
请参考[SOPHGO部署库编译](../../../../../../docs/cn/build_and_install/sophgo.md)仓库编译SDK,编译完成后,将在build目录下生成fastdeploy-0.0.3目录.
|
||||||
|
|
||||||
|
### 拷贝bmodel模型文件至model文件夹
|
||||||
|
将Paddle模型转换为SOPHGO bmodel模型,转换步骤参考[文档](../README.md)
|
||||||
|
将转换后的SOPHGO bmodel模型文件拷贝至model中
|
||||||
|
|
||||||
|
### 准备测试图片至image文件夹,以及字典文件
|
||||||
|
```bash
|
||||||
|
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg
|
||||||
|
cp 12.jpg image/
|
||||||
|
|
||||||
|
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
### 编译example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd build
|
||||||
|
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-0.0.3
|
||||||
|
make
|
||||||
|
```
|
||||||
|
|
||||||
|
## 运行例程
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./infer_demo model ./ppocr_keys_v1.txt image/12.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
- [模型介绍](../../../)
|
||||||
|
- [模型转换](../)
|
136
examples/vision/ocr/PP-OCRv3/sophgo/cpp/infer.cc
Normal file
136
examples/vision/ocr/PP-OCRv3/sophgo/cpp/infer.cc
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/vision.h"
|
||||||
|
// Path separator used when joining the model directory with a file name.
// _WIN32 is the macro compilers predefine when targeting Windows; WIN32 is
// only defined by some headers/build setups, so both spellings are accepted.
#if defined(_WIN32) || defined(WIN32)
const char sep = '\\';
#else
const char sep = '/';
#endif
|
||||||
|
|
||||||
|
void InitAndInfer(const std::string& det_model_dir,
|
||||||
|
const std::string& rec_label_file,
|
||||||
|
const std::string& image_file,
|
||||||
|
const fastdeploy::RuntimeOption& option) {
|
||||||
|
auto det_model_file =
|
||||||
|
det_model_dir + sep + "ch_PP-OCRv3_det_1684x_f32.bmodel";
|
||||||
|
auto det_params_file = det_model_dir + sep + "";
|
||||||
|
|
||||||
|
auto cls_model_file =
|
||||||
|
det_model_dir + sep + "ch_ppocr_mobile_v2.0_cls_1684x_f32.bmodel";
|
||||||
|
auto cls_params_file = det_model_dir + sep + "";
|
||||||
|
|
||||||
|
auto rec_model_file =
|
||||||
|
det_model_dir + sep + "ch_PP-OCRv3_rec_1684x_f32.bmodel";
|
||||||
|
auto rec_params_file = det_model_dir + sep + "";
|
||||||
|
|
||||||
|
auto format = fastdeploy::ModelFormat::SOPHGO;
|
||||||
|
|
||||||
|
auto det_option = option;
|
||||||
|
auto cls_option = option;
|
||||||
|
auto rec_option = option;
|
||||||
|
|
||||||
|
// The cls and rec model can inference a batch of images now.
|
||||||
|
// User could initialize the inference batch size and set them after create
|
||||||
|
// PPOCR model.
|
||||||
|
int cls_batch_size = 1;
|
||||||
|
int rec_batch_size = 1;
|
||||||
|
|
||||||
|
// If use TRT backend, the dynamic shape will be set as follow.
|
||||||
|
// We recommend that users set the length and height of the detection model to
|
||||||
|
// a multiple of 32. We also recommend that users set the Trt input shape as
|
||||||
|
// follow.
|
||||||
|
det_option.SetTrtInputShape("x", {1, 3, 64, 64}, {1, 3, 640, 640},
|
||||||
|
{1, 3, 960, 960});
|
||||||
|
cls_option.SetTrtInputShape("x", {1, 3, 48, 10}, {cls_batch_size, 3, 48, 320},
|
||||||
|
{cls_batch_size, 3, 48, 1024});
|
||||||
|
rec_option.SetTrtInputShape("x", {1, 3, 48, 10}, {rec_batch_size, 3, 48, 320},
|
||||||
|
{rec_batch_size, 3, 48, 2304});
|
||||||
|
|
||||||
|
// Users could save TRT cache file to disk as follow.
|
||||||
|
// det_option.SetTrtCacheFile(det_model_dir + sep + "det_trt_cache.trt");
|
||||||
|
// cls_option.SetTrtCacheFile(cls_model_dir + sep + "cls_trt_cache.trt");
|
||||||
|
// rec_option.SetTrtCacheFile(rec_model_dir + sep + "rec_trt_cache.trt");
|
||||||
|
|
||||||
|
auto det_model = fastdeploy::vision::ocr::DBDetector(
|
||||||
|
det_model_file, det_params_file, det_option, format);
|
||||||
|
auto cls_model = fastdeploy::vision::ocr::Classifier(
|
||||||
|
cls_model_file, cls_params_file, cls_option, format);
|
||||||
|
auto rec_model = fastdeploy::vision::ocr::Recognizer(
|
||||||
|
rec_model_file, rec_params_file, rec_label_file, rec_option, format);
|
||||||
|
|
||||||
|
// Users could enable static shape infer for rec model when deploy PP-OCR on
|
||||||
|
// hardware which can not support dynamic shape infer well, like Huawei Ascend
|
||||||
|
// series.
|
||||||
|
rec_model.GetPreprocessor().SetStaticShapeInfer(true);
|
||||||
|
rec_model.GetPreprocessor().SetRecImageShape({3, 48, 584});
|
||||||
|
|
||||||
|
assert(det_model.Initialized());
|
||||||
|
assert(cls_model.Initialized());
|
||||||
|
assert(rec_model.Initialized());
|
||||||
|
|
||||||
|
// The classification model is optional, so the PP-OCR can also be connected
|
||||||
|
// in series as follows auto ppocr_v3 =
|
||||||
|
// fastdeploy::pipeline::PPOCRv3(&det_model, &rec_model);
|
||||||
|
auto ppocr_v3 =
|
||||||
|
fastdeploy::pipeline::PPOCRv3(&det_model, &cls_model, &rec_model);
|
||||||
|
|
||||||
|
// Set inference batch size for cls model and rec model, the value could be -1
|
||||||
|
// and 1 to positive infinity. When inference batch size is set to -1, it
|
||||||
|
// means that the inference batch size of the cls and rec models will be the
|
||||||
|
// same as the number of boxes detected by the det model.
|
||||||
|
ppocr_v3.SetClsBatchSize(cls_batch_size);
|
||||||
|
ppocr_v3.SetRecBatchSize(rec_batch_size);
|
||||||
|
|
||||||
|
if (!ppocr_v3.Initialized()) {
|
||||||
|
std::cerr << "Failed to initialize PP-OCR." << std::endl;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto im = cv::imread(image_file);
|
||||||
|
auto im_bak = im.clone();
|
||||||
|
|
||||||
|
fastdeploy::vision::OCRResult result;
|
||||||
|
if (!ppocr_v3.Predict(&im, &result)) {
|
||||||
|
std::cerr << "Failed to predict." << std::endl;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << result.Str() << std::endl;
|
||||||
|
|
||||||
|
auto vis_im = fastdeploy::vision::VisOcr(im_bak, result);
|
||||||
|
cv::imwrite("vis_result.jpg", vis_im);
|
||||||
|
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
if (argc < 4) {
|
||||||
|
std::cout << "Usage: infer_demo path/to/model "
|
||||||
|
"path/to/rec_label_file path/to/image "
|
||||||
|
"e.g ./infer_demo ./ocr_bmodel "
|
||||||
|
"./ppocr_keys_v1.txt ./12.jpg"
|
||||||
|
<< std::endl;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
fastdeploy::RuntimeOption option;
|
||||||
|
option.UseSophgo();
|
||||||
|
option.UseSophgoBackend();
|
||||||
|
|
||||||
|
std::string model_dir = argv[1];
|
||||||
|
std::string rec_label_file = argv[2];
|
||||||
|
std::string test_image = argv[3];
|
||||||
|
InitAndInfer(model_dir, rec_label_file, test_image, option);
|
||||||
|
return 0;
|
||||||
|
}
|
38
examples/vision/ocr/PP-OCRv3/sophgo/python/README.md
Normal file
38
examples/vision/ocr/PP-OCRv3/sophgo/python/README.md
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
# PPOCRv3 Python部署示例
|
||||||
|
|
||||||
|
在部署前,需确认以下步骤
|
||||||
|
|
||||||
|
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../../docs/cn/build_and_install/sophgo.md)
|
||||||
|
|
||||||
|
本目录下提供`infer.py`快速完成 PPOCRv3 在SOPHGO TPU上部署的示例。执行如下脚本即可完成
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 下载部署示例代码
|
||||||
|
git clone https://github.com/PaddlePaddle/FastDeploy.git
|
||||||
|
cd FastDeploy/examples/vision/ocr/PP-OCRv3/sophgo/python
|
||||||
|
|
||||||
|
# 下载图片
|
||||||
|
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg
|
||||||
|
|
||||||
|
#下载字典文件
|
||||||
|
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
|
||||||
|
|
||||||
|
# 推理
|
||||||
|
python3 infer.py --det_model ocr_bmodel/ch_PP-OCRv3_det_1684x_f32.bmodel \
|
||||||
|
--cls_model ocr_bmodel/ch_ppocr_mobile_v2.0_cls_1684x_f32.bmodel \
|
||||||
|
--rec_model ocr_bmodel/ch_PP-OCRv3_rec_1684x_f32.bmodel \
|
||||||
|
--rec_label_file ./ppocr_keys_v1.txt \
|
||||||
|
--image ./12.jpg
|
||||||
|
|
||||||
|
# 运行完成后返回结果如下所示
|
||||||
|
det boxes: [[42,413],[483,391],[484,428],[43,450]]rec text: 上海斯格威铂尔大酒店 rec score:0.952958 cls label: 0 cls score: 1.000000
|
||||||
|
det boxes: [[187,456],[399,448],[400,480],[188,488]]rec text: 打浦路15号 rec score:0.897335 cls label: 0 cls score: 1.000000
|
||||||
|
det boxes: [[23,507],[513,488],[515,529],[24,548]]rec text: 绿洲仕格维花园公寓 rec score:0.994589 cls label: 0 cls score: 1.000000
|
||||||
|
det boxes: [[74,553],[427,542],[428,571],[75,582]]rec text: 打浦路252935号 rec score:0.900663 cls label: 0 cls score: 1.000000
|
||||||
|
|
||||||
|
可视化结果保存在sophgo_result.jpg中
|
||||||
|
```
|
||||||
|
|
||||||
|
## 其它文档
|
||||||
|
- [PPOCRv3 C++部署](../cpp)
|
||||||
|
- [转换 PPOCRv3 SOPHGO模型文档](../README.md)
|
116
examples/vision/ocr/PP-OCRv3/sophgo/python/infer.py
Normal file
116
examples/vision/ocr/PP-OCRv3/sophgo/python/infer.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
import fastdeploy as fd
|
||||||
|
import cv2
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arguments():
    """Parse the command-line options of the PP-OCRv3 SOPHGO demo.

    Reads the options from ``sys.argv``. All five options are required;
    argparse exits with a usage message when one is missing.

    Returns:
        argparse.Namespace: carries the attributes ``det_model``,
        ``cls_model``, ``rec_model``, ``rec_label_file`` and ``image``.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--det_model", required=True, help="Path of Detection model of PPOCR.")
    parser.add_argument(
        "--cls_model",
        required=True,
        help="Path of Classification model of PPOCR.")
    parser.add_argument(
        "--rec_model",
        required=True,
        help="Path of Recognization model of PPOCR.")
    parser.add_argument(
        "--rec_label_file",
        required=True,
        help="Path of Recognization label of PPOCR.")
    parser.add_argument(
        "--image", type=str, required=True, help="Path of test image file.")

    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
args = parse_arguments()

# Configure the runtime to run on SOPHGO hardware.
runtime_option = fd.RuntimeOption()
runtime_option.use_sophgo()

# Detection model: locates the text boxes.
det_model_file = args.det_model
det_params_file = ""
# Classification model: text-direction classifier (optional).
cls_model_file = args.cls_model
cls_params_file = ""
# Recognition model: reads the text inside each box.
rec_model_file = args.rec_model
rec_params_file = ""
rec_label_file = args.rec_label_file

# The cls and rec models support batched inference. These two values are used
# for the TRT input shapes below and, once the pipeline is created, for the
# batch-size settings of the PP-OCR pipeline.
cls_batch_size = 1
rec_batch_size = 1

# NOTE(review): the set_trt_input_shape calls below look carried over from the
# TensorRT example; presumably they have no effect on the SOPHGO backend --
# confirm before removing them.
# The detection model must be created before the classifier's shape is set and
# the classifier is created; the same holds for the recognizer, because the
# three *_option names all refer to the same RuntimeOption object.
# When changing the detection input shape, keeping height and width multiples
# of 32 is recommended.
det_option = runtime_option
det_option.set_trt_input_shape("x", [1, 3, 64, 64], [1, 3, 640, 640],
                               [1, 3, 960, 960])
# The TRT engine file can optionally be cached locally:
# det_option.set_trt_cache_file(args.det_model + "/det_trt_cache.trt")
det_model = fd.vision.ocr.DBDetector(
    det_model_file,
    det_params_file,
    runtime_option=det_option,
    model_format=fd.ModelFormat.SOPHGO)

cls_option = runtime_option
cls_option.set_trt_input_shape("x", [1, 3, 48, 10],
                               [cls_batch_size, 3, 48, 320],
                               [cls_batch_size, 3, 48, 1024])
# The TRT engine file can optionally be cached locally:
# cls_option.set_trt_cache_file(args.cls_model + "/cls_trt_cache.trt")
cls_model = fd.vision.ocr.Classifier(
    cls_model_file,
    cls_params_file,
    runtime_option=cls_option,
    model_format=fd.ModelFormat.SOPHGO)

rec_option = runtime_option
rec_option.set_trt_input_shape("x", [1, 3, 48, 10],
                               [rec_batch_size, 3, 48, 320],
                               [rec_batch_size, 3, 48, 2304])
# The TRT engine file can optionally be cached locally:
# rec_option.set_trt_cache_file(args.rec_model + "/rec_trt_cache.trt")
rec_model = fd.vision.ocr.Recognizer(
    rec_model_file,
    rec_params_file,
    rec_label_file,
    runtime_option=rec_option,
    model_format=fd.ModelFormat.SOPHGO)

# Build the PP-OCR pipeline from the three models. cls_model is optional and
# may be set to None when it is not needed.
ppocr_v3 = fd.vision.ocr.PPOCRv3(
    det_model=det_model, cls_model=cls_model, rec_model=rec_model)

# Enable static-shape inference for the rec model; its static input here is
# [3, 48, 584].
rec_model.preprocessor.static_shape_infer = True
rec_model.preprocessor.rec_image_shape = [3, 48, 584]

# Inference batch size of the cls and rec models. Valid values are -1 and any
# positive integer; -1 means "use the number of boxes found by the det model".
ppocr_v3.cls_batch_size = cls_batch_size
ppocr_v3.rec_batch_size = rec_batch_size

# Load the input image. cv2.imread returns None instead of raising when the
# file is missing or cannot be decoded, so fail early with a clear message.
im = cv2.imread(args.image)
if im is None:
    raise RuntimeError("Failed to read image: {}".format(args.image))

# Run the pipeline and print the result.
result = ppocr_v3.predict(im)

print(result)

# Visualize the result and save it.
vis_im = fd.vision.vis_ppocr(im, result)
cv2.imwrite("sophgo_result.jpg", vis_im)
print("Visualized result save in ./sophgo_result.jpg")
|
@@ -71,8 +71,8 @@ bool FastDeployModel::InitRuntimeWithSpecifiedBackend() {
|
|||||||
}
|
}
|
||||||
} else if (use_sophgotpu) {
|
} else if (use_sophgotpu) {
|
||||||
if (!IsSupported(valid_sophgonpu_backends, runtime_option.backend)) {
|
if (!IsSupported(valid_sophgonpu_backends, runtime_option.backend)) {
|
||||||
FDERROR << "The valid rknpu backends of model " << ModelName() << " are "
|
FDERROR << "The valid sophgo backends of model " << ModelName() << " are "
|
||||||
<< Str(valid_rknpu_backends) << ", " << runtime_option.backend
|
<< Str(valid_sophgonpu_backends) << ", " << runtime_option.backend
|
||||||
<< " is not supported." << std::endl;
|
<< " is not supported." << std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@@ -180,7 +180,7 @@ bool SophgoBackend::Infer(std::vector<FDTensor>& inputs,
|
|||||||
assert(BM_SUCCESS == status);
|
assert(BM_SUCCESS == status);
|
||||||
input_tensors[i].dtype = input_dtypes[i];
|
input_tensors[i].dtype = input_dtypes[i];
|
||||||
input_tensors[i].st_mode = BM_STORE_1N;
|
input_tensors[i].st_mode = BM_STORE_1N;
|
||||||
input_tensors[i].shape = *(net_info_->stages[i].input_shapes);
|
input_tensors[i].shape = net_info_->stages[0].input_shapes[i];
|
||||||
unsigned int input_byte = bmrt_tensor_bytesize(&input_tensors[i]);
|
unsigned int input_byte = bmrt_tensor_bytesize(&input_tensors[i]);
|
||||||
bm_memcpy_s2d_partial(handle_, input_tensors[i].device_mem,
|
bm_memcpy_s2d_partial(handle_, input_tensors[i].device_mem,
|
||||||
(void*)inputs[i].Data(),
|
(void*)inputs[i].Data(),
|
||||||
|
@@ -71,6 +71,7 @@ class FASTDEPLOY_DECL PPYOLOE : public PPDetBase {
|
|||||||
valid_kunlunxin_backends = {Backend::LITE};
|
valid_kunlunxin_backends = {Backend::LITE};
|
||||||
valid_rknpu_backends = {Backend::RKNPU2};
|
valid_rknpu_backends = {Backend::RKNPU2};
|
||||||
valid_ascend_backends = {Backend::LITE};
|
valid_ascend_backends = {Backend::LITE};
|
||||||
|
valid_sophgonpu_backends = {Backend::SOPHGOTPU};
|
||||||
initialized = Initialize();
|
initialized = Initialize();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -274,7 +275,8 @@ class FASTDEPLOY_DECL CascadeRCNN : public PPDetBase {
|
|||||||
initialized = Initialize();
|
initialized = Initialize();
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual std::string ModelName() const { return "PaddleDetection/CascadeRCNN"; }
|
virtual std::string ModelName() const {
|
||||||
|
return "PaddleDetection/CascadeRCNN"; }
|
||||||
};
|
};
|
||||||
|
|
||||||
class FASTDEPLOY_DECL PSSDet : public PPDetBase {
|
class FASTDEPLOY_DECL PSSDet : public PPDetBase {
|
||||||
|
@@ -13,6 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "fastdeploy/vision/detection/ppdet/postprocessor.h"
|
#include "fastdeploy/vision/detection/ppdet/postprocessor.h"
|
||||||
|
|
||||||
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
|
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
|
||||||
#include "fastdeploy/vision/utils/utils.h"
|
#include "fastdeploy/vision/utils/utils.h"
|
||||||
|
|
||||||
@@ -202,21 +203,18 @@ bool PaddleDetPostprocessor::ProcessUnDecodeResults(
|
|||||||
static_cast<int32_t>(round(ptr[j * 6])));
|
static_cast<int32_t>(round(ptr[j * 6])));
|
||||||
(*results)[i].scores.push_back(ptr[j * 6 + 1]);
|
(*results)[i].scores.push_back(ptr[j * 6 + 1]);
|
||||||
(*results)[i].boxes.emplace_back(std::array<float, 4>(
|
(*results)[i].boxes.emplace_back(std::array<float, 4>(
|
||||||
{ptr[j * 6 + 2] / GetScaleFactor()[1],
|
{ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]}));
|
||||||
ptr[j * 6 + 3] / GetScaleFactor()[0],
|
|
||||||
ptr[j * 6 + 4] / GetScaleFactor()[1],
|
|
||||||
ptr[j * 6 + 5] / GetScaleFactor()[0]}));
|
|
||||||
}
|
}
|
||||||
offset += (num_boxes[i] * 6);
|
offset += (num_boxes[i] * 6);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> PaddleDetPostprocessor::GetScaleFactor(){
|
std::vector<float> PaddleDetPostprocessor::GetScaleFactor() {
|
||||||
return scale_factor_;
|
return scale_factor_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void PaddleDetPostprocessor::SetScaleFactor(float* scale_factor_value){
|
void PaddleDetPostprocessor::SetScaleFactor(float* scale_factor_value) {
|
||||||
for (int i = 0; i < scale_factor_.size(); ++i) {
|
for (int i = 0; i < scale_factor_.size(); ++i) {
|
||||||
scale_factor_[i] = scale_factor_value[i];
|
scale_factor_[i] = scale_factor_value[i];
|
||||||
}
|
}
|
||||||
|
@@ -34,6 +34,7 @@ Classifier::Classifier(const std::string& model_file,
|
|||||||
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
|
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
|
||||||
valid_kunlunxin_backends = {Backend::LITE};
|
valid_kunlunxin_backends = {Backend::LITE};
|
||||||
valid_ascend_backends = {Backend::LITE};
|
valid_ascend_backends = {Backend::LITE};
|
||||||
|
valid_sophgonpu_backends = {Backend::SOPHGOTPU};
|
||||||
}
|
}
|
||||||
runtime_option = custom_option;
|
runtime_option = custom_option;
|
||||||
runtime_option.model_format = model_format;
|
runtime_option.model_format = model_format;
|
||||||
|
@@ -34,6 +34,7 @@ DBDetector::DBDetector(const std::string& model_file,
|
|||||||
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
|
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
|
||||||
valid_kunlunxin_backends = {Backend::LITE};
|
valid_kunlunxin_backends = {Backend::LITE};
|
||||||
valid_ascend_backends = {Backend::LITE};
|
valid_ascend_backends = {Backend::LITE};
|
||||||
|
valid_sophgonpu_backends = {Backend::SOPHGOTPU};
|
||||||
}
|
}
|
||||||
|
|
||||||
runtime_option = custom_option;
|
runtime_option = custom_option;
|
||||||
|
@@ -36,6 +36,7 @@ Recognizer::Recognizer(const std::string& model_file,
|
|||||||
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
|
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
|
||||||
valid_kunlunxin_backends = {Backend::LITE};
|
valid_kunlunxin_backends = {Backend::LITE};
|
||||||
valid_ascend_backends = {Backend::LITE};
|
valid_ascend_backends = {Backend::LITE};
|
||||||
|
valid_sophgonpu_backends = {Backend::SOPHGOTPU};
|
||||||
}
|
}
|
||||||
|
|
||||||
runtime_option = custom_option;
|
runtime_option = custom_option;
|
||||||
|
Reference in New Issue
Block a user