mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
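[Backend And DOC] Improve the ppseg docs + add multi-input model support to the RKNPU2 backend (#491)
* 11-02/14:35
* Add a check for an incorrect input data format
* Optimize the inference process to reduce the number of memory allocations
* Support multi-input RKNN models
* When an RKNN model's output shape is three-dimensional, the output is forcibly padded to four dimensions. The dimension added by RKNN is now stripped directly, so that models whose post-processing checks the output shape work correctly.
* 11-03/17:25
* Support exporting multi-input RKNN models
* Update various documents
* ppseg now uses the models in FastDeploy for conversion
* 11-03/17:25
* Add the open-source license header
* 11-03/21:48
* Remove unused debug code and add comments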
@@ -1,5 +1,8 @@
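# RKNPU2 Model Deployment

## Environment Setup
RKNPU2 models can only be exported on an x86 Linux platform. For installation steps, see the [RKNPU2 model export environment setup guide](./install_rknn_toolkit2.md).

## Converting ONNX Models to RKNN Models
An ONNX model cannot run directly on the NPU of a Rockchip (RK) chip; it must first be converted to an RKNN model. For the detailed procedure, see the [conversion guide](./export.md).
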
@@ -61,4 +64,3 @@ int infer_scrfd_npu() {
- [RKNPU2 on-board environment installation and configuration](../../build_and_install/rknpu2.md)
- [rknn_toolkit2 installation guide](./install_rknn_toolkit2.md)
- [ONNX-to-RKNN conversion guide](./export.md)

@@ -4,49 +4,103 @@

- [PaddleSeg develop](https://github.com/PaddlePaddle/PaddleSeg/tree/develop)

FastDeploy currently supports deployment of the following models
FastDeploy currently supports deploying the following PPSeg models with RKNPU2 inference:

- [U-Net series models](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/configs/unet/README.md)
- [PP-LiteSeg series models](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/configs/pp_liteseg/README.md)
- [PP-HumanSeg series models](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/contrib/PP-HumanSeg/README.md)
- [FCN series models](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/configs/fcn/README.md)
- [DeepLabV3 series models](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/configs/deeplabv3/README.md)

[Note] If you are deploying **PP-Matting**, **PP-HumanMatting**, or **ModNet**, see [Matting model deployment](../../matting).
| Model | Parameter file size | Input shape | mIoU | mIoU (flip) | mIoU (ms+flip) |
|:------|:--------------------|:------------|:-----|:------------|:---------------|
| [Unet-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/Unet_cityscapes_without_argmax_infer.tgz) | 52MB | 1024x512 | 65.00% | 66.02% | 66.89% |
| [PP-LiteSeg-T(STDC1)-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/PP_LiteSeg_T_STDC1_cityscapes_without_argmax_infer.tgz) | 31MB | 1024x512 | 77.04% | 77.73% | 77.46% |
| [PP-HumanSegV1-Lite (general human segmentation model)](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV1_Lite_infer.tgz) | 543KB | 192x192 | 86.2% | - | - |
| [PP-HumanSegV2-Lite (general human segmentation model)](https://bj.bcebos.com/paddle2onnx/libs/PP_HumanSegV2_Lite_192x192_infer.tgz) | 12MB | 192x192 | 92.52% | - | - |
| [PP-HumanSegV2-Mobile (general human segmentation model)](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV2_Mobile_192x192_infer.tgz) | 29MB | 192x192 | 93.13% | - | - |
| [PP-HumanSegV1-Server (general human segmentation model)](https://bj.bcebos.com/paddlehub/fastdeploy/PP_HumanSegV1_Server_infer.tgz) | 103MB | 512x512 | 96.47% | - | - |
| [Portrait-PP-HumanSegV2_Lite (portrait segmentation model)](https://bj.bcebos.com/paddlehub/fastdeploy/Portrait_PP_HumanSegV2_Lite_256x144_infer.tgz) | 3.6MB | 256x144 | 96.63% | - | - |
| [FCN-HRNet-W18-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/FCN_HRNet_W18_cityscapes_without_argmax_infer.tgz) | 37MB | 1024x512 | 78.97% | 79.49% | 79.74% |
| [Deeplabv3-ResNet101-OS8-cityscapes](https://bj.bcebos.com/paddlehub/fastdeploy/Deeplabv3_ResNet101_OS8_cityscapes_without_argmax_infer.tgz) | 150MB | 1024x512 | 79.90% | 80.22% | 80.47% |

## Preparing and Converting PaddleSeg Deployment Models
Before deploying on RKNPU, the Paddle model must be converted to an RKNN model. The steps are:
* To convert a Paddle dynamic-graph model to an ONNX model, see the [PaddleSeg model export instructions](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6/contrib/PP-HumanSeg)
* To convert the ONNX model to an RKNN model, follow the [conversion guide](../../../../../docs/cn/faq/rknpu2/export.md).

Before deploying on RKNPU, the model must be converted to an RKNN model. The process can generally be reduced to the following steps:
* Paddle dynamic-graph model -> ONNX model -> RKNN model.
* For converting a Paddle dynamic-graph model to an ONNX model, see the [PaddleSeg model export instructions](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6/contrib/PP-HumanSeg).
* For converting the ONNX model to an RKNN model, follow the [conversion guide](../../../../../docs/cn/faq/rknpu2/export.md).
Taking PPHumanSeg as an example, once the ONNX model is available, the steps to convert it for RK3588 are:
* Write the config.yaml file
```yaml
model_path: ./portrait_pp_humansegv2_lite_256x144_pretrained.onnx
output_folder: ./
target_platform: RK3588
normalize:
  mean: [0.5,0.5,0.5]
  std: [0.5,0.5,0.5]
outputs: None
```
* Run the conversion code
## Model Conversion Example

The following uses Portrait-PP-HumanSegV2_Lite (portrait segmentation model) as an example to show how to convert a PPSeg model to an RKNN model.
```bash
python /path/to/fastDeploy/toosl/export.py --config_path=/path/to/fastdeploy/tools/rknpu2/config/ppset_config.yaml
# Clone the Paddle2ONNX repository
git clone https://github.com/PaddlePaddle/Paddle2ONNX

# Download the Paddle static-graph model and fix its input shape
## Enter the directory of the tool that fixes the input shape of a Paddle static-graph model
cd Paddle2ONNX/tools/paddle
## Download and extract the Paddle static-graph model
wget https://bj.bcebos.com/paddlehub/fastdeploy/Portrait_PP_HumanSegV2_Lite_256x144_infer.tgz
tar xvf Portrait_PP_HumanSegV2_Lite_256x144_infer.tgz
python paddle_infer_shape.py --model_dir Portrait_PP_HumanSegV2_Lite_256x144_infer/ \
                             --model_filename model.pdmodel \
                             --params_filename model.pdiparams \
                             --save_dir Portrait_PP_HumanSegV2_Lite_256x144_infer \
                             --input_shape_dict="{'x':[1,3,144,256]}"

# Convert the static-graph model to ONNX. Note: save_file should match the archive name
paddle2onnx --model_dir Portrait_PP_HumanSegV2_Lite_256x144_infer \
            --model_filename model.pdmodel \
            --params_filename model.pdiparams \
            --save_file Portrait_PP_HumanSegV2_Lite_256x144_infer/Portrait_PP_HumanSegV2_Lite_256x144_infer.onnx \
            --enable_dev_version True

# Convert the ONNX model to an RKNN model
# Copy the ONNX model directory to the FastDeploy root directory
cp -r ./Portrait_PP_HumanSegV2_Lite_256x144_infer /path/to/Fastdeploy
# Convert the model; the result is generated in the Portrait_PP_HumanSegV2_Lite_256x144_infer directory
python tools/rknpu2/export.py --config_path tools/rknpu2/config/RK3588/Portrait_PP_HumanSegV2_Lite_256x144_infer.yaml
```
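
The final export command reads one small YAML config per target platform. As a minimal sketch, assuming the field names used by the config files added in this commit (`model_path`, `output_folder`, `target_platform`, `normalize`, `outputs`), the helper below renders such a file as text. `render_export_config` is a hypothetical name for illustration and is not part of FastDeploy.

```python
def render_export_config(model_dir, platform,
                         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Return YAML text shaped like the configs in this commit (hypothetical helper)."""

    def fmt(values):
        # One inner list per model input; this sketch assumes a single input.
        return "[[" + ",".join(str(v) for v in values) + "]]"

    lines = [
        f"model_path: ./{model_dir}/{model_dir}.onnx",
        f"output_folder: ./{model_dir}",
        f"target_platform: {platform}",
        "normalize:",
        f"  mean: {fmt(mean)}",
        f"  std: {fmt(std)}",
        "outputs: None",
    ]
    return "\n".join(lines) + "\n"


config_text = render_export_config(
    "Portrait_PP_HumanSegV2_Lite_256x144_infer", "RK3568")
print(config_text)
```

Only the `target_platform` value differs between the RK3568 and RK3588 configs in this commit, which is why the platform is the one required argument besides the model directory.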

## Download Pre-trained Models
## Modify the yaml Configuration File

For developers' convenience in testing, some models exported from PaddleSeg are provided below (exported with `--input_shape` **specified**, `--output_op none` **specified**, and `--without_argmax` **specified**); they can be downloaded and used directly.
In the **model conversion example**, the model's input shape was fixed, so the corresponding yaml file must be modified as well, as follows:

| Task | Model | Model version (tested versions) | Size | ONNX/RKNN supported | ONNX/RKNN speed (ms) |
|------|-------|---------------------------------|------|---------------------|----------------------|
| Segmentation | PP-LiteSeg | [PP_LiteSeg_T_STDC1_cityscapes](https://bj.bcebos.com/fastdeploy/models/rknn2/PP_LiteSeg_T_STDC1_cityscapes_without_argmax_infer_3588.tgz) | - | True/True | 6634/5598 |
| Segmentation | PP-HumanSegV2Lite | [portrait](https://bj.bcebos.com/fastdeploy/models/rknn2/portrait_pp_humansegv2_lite_256x144_inference_model_without_softmax_3588.tgz) | - | True/True | 456/266 |
| Segmentation | PP-HumanSegV2Lite | [human](https://bj.bcebos.com/fastdeploy/models/rknn2/human_pp_humansegv2_lite_192x192_pretrained_3588.tgz) | - | True/True | 496/256 |
**Original yaml file**
```yaml
Deploy:
  input_shape:
  - -1
  - 3
  - -1
  - -1
  model: model.pdmodel
  output_dtype: float32
  output_op: none
  params: model.pdiparams
  transforms:
  - target_size:
    - 256
    - 144
    type: Resize
  - type: Normalize
```

**Modified yaml file**
```yaml
Deploy:
  input_shape:
  - 1
  - 3
  - 144
  - 256
  model: model.pdmodel
  output_dtype: float32
  output_op: none
  params: model.pdiparams
  transforms:
  - target_size:
    - 256
    - 144
    type: Resize
  - type: Normalize
```
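
The only difference between the two files is that the dynamic `-1` entries of `input_shape` become fixed values. The sketch below shows that substitution in plain Python. It assumes the shape is in NCHW order, which matches the `[1,3,144,256]` value passed to `--input_shape_dict` in the conversion example, whereas `target_size` lists 256 before 144 and so appears to be width-then-height. `fix_input_shape` is a hypothetical helper, not a PaddleSeg or FastDeploy function.

```python
def fix_input_shape(shape, height, width, batch=1):
    """Replace dynamic (-1) dims of an NCHW shape with fixed values."""
    if len(shape) != 4:
        raise ValueError("expected an NCHW shape with 4 entries")
    fixed = [batch, shape[1], height, width]
    # Only overwrite dynamic entries; keep any dim that was already fixed.
    return [f if s == -1 else s for s, f in zip(shape, fixed)]


original = [-1, 3, -1, -1]
fixed_shape = fix_input_shape(original, height=144, width=256)
print(fixed_shape)  # [1, 3, 144, 256]
```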

## Detailed Deployment Documentation
- [RKNN general deployment tutorial](../../../../../docs/cn/faq/rknpu2.md)
- [RKNN general deployment tutorial](../../../../../docs/cn/faq/rknpu2/rknpu2.md)
- [C++ deployment](cpp)
- [Python deployment](python)
@@ -41,13 +41,7 @@ fastdeploy-0.0.3 directory; please move it to the thirdpartys directory.
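
### Copy the model files and configuration file to the model folder
During the Paddle dynamic-graph model -> Paddle static-graph model -> ONNX model process, an ONNX file and the corresponding yaml configuration file are generated. Place the configuration file in the model folder.
The converted RKNN model file must also be copied to model. Pre-converted files are provided here; run the following commands to download and use them (the model files target RK3588; for RK3568 you need to [convert the PPSeg RKNN model](../README.md) again).
```bash
cd model
wget https://bj.bcebos.com/fastdeploy/models/rknn2/human_pp_humansegv2_lite_192x192_pretrained_3588.tgz
tar xvf human_pp_humansegv2_lite_192x192_pretrained_3588.tgz
cp -r ./human_pp_humansegv2_lite_192x192_pretrained_3588 ./model
```
The converted RKNN model file must also be copied to model; run the following commands to download and use it (the model files target RK3588; for RK3568 you need to [convert the PPSeg RKNN model](../README.md) again).

### Place the test images in the image folder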
```bash
|
||||
|
@@ -1,3 +1,16 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include "fastdeploy/vision.h"
|
||||
@@ -40,11 +53,11 @@ std::string GetModelPath(std::string& model_path, const std::string& device) {
|
||||
|
||||
void InferHumanPPHumansegv2Lite(const std::string& device) {
|
||||
std::string model_file =
|
||||
"./model/human_pp_humansegv2_lite_192x192_pretrained_3588/"
|
||||
"human_pp_humansegv2_lite_192x192_pretrained_3588.";
|
||||
"./model/Portrait_PP_HumanSegV2_Lite_256x144_infer/"
|
||||
"Portrait_PP_HumanSegV2_Lite_256x144_infer_rk3588.";
|
||||
std::string params_file;
|
||||
std::string config_file =
|
||||
"./model/human_pp_humansegv2_lite_192x192_pretrained_3588/deploy.yaml";
|
||||
"./model/Portrait_PP_HumanSegV2_Lite_256x144_infer/deploy.yaml";
|
||||
|
||||
fastdeploy::RuntimeOption option = GetOption(device);
|
||||
fastdeploy::ModelFormat format = GetFormat(device);
|
||||
|
@@ -13,17 +13,13 @@
|
||||
git clone https://github.com/PaddlePaddle/FastDeploy.git
|
||||
cd FastDeploy/examples/vision/segmentation/paddleseg/python
|
||||
|
||||
# Download the model
|
||||
wget https://bj.bcebos.com/fastdeploy/models/rknn2/human_pp_humansegv2_lite_192x192_pretrained_3588.tgz
|
||||
tar xvf human_pp_humansegv2_lite_192x192_pretrained_3588.tgz
|
||||
|
||||
# Download the images
|
||||
wget https://paddleseg.bj.bcebos.com/dygraph/pp_humanseg_v2/images.zip
|
||||
unzip images.zip
|
||||
|
||||
# Run inference
|
||||
python3 infer.py --model_file ./human_pp_humansegv2_lite_192x192_pretrained_3588/human_pp_humansegv2_lite_192x192_pretrained_3588.rknn \
|
||||
--config_file ./human_pp_humansegv2_lite_192x192_pretrained_3588/deploy.yaml \
|
||||
python3 infer.py --model_file ./Portrait_PP_HumanSegV2_Lite_256x144_infer/Portrait_PP_HumanSegV2_Lite_256x144_infer_rk3588.rknn \
|
||||
--config_file ./Portrait_PP_HumanSegV2_Lite_256x144_infer/deploy.yaml \
|
||||
--image images/portrait_heng.jpg
|
||||
```
|
||||
|
||||
|
@@ -1,3 +1,16 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import fastdeploy as fd
|
||||
import cv2
|
||||
import os
|
||||
@@ -30,7 +43,11 @@ model_file = args.model_file
|
||||
params_file = ""
|
||||
config_file = args.config_file
|
||||
model = fd.vision.segmentation.PaddleSegModel(
|
||||
model_file, params_file, config_file, runtime_option=runtime_option,model_format=fd.ModelFormat.RKNN)
|
||||
model_file,
|
||||
params_file,
|
||||
config_file,
|
||||
runtime_option=runtime_option,
|
||||
model_format=fd.ModelFormat.RKNN)
|
||||
|
||||
model.disable_normalize_and_permute()
|
||||
|
||||
|
@@ -15,11 +15,27 @@
|
||||
|
||||
namespace fastdeploy {
|
||||
RKNPU2Backend::~RKNPU2Backend() {
|
||||
if (input_attrs != nullptr) {
|
||||
free(input_attrs);
|
||||
// Release memory uniformly here
|
||||
if (input_attrs_ != nullptr) {
|
||||
free(input_attrs_);
|
||||
}
|
||||
if (output_attrs != nullptr) {
|
||||
free(output_attrs);
|
||||
|
||||
if (output_attrs_ != nullptr) {
|
||||
free(output_attrs_);
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < io_num.n_input; i++) {
|
||||
rknn_destroy_mem(ctx, input_mems_[i]);
|
||||
}
|
||||
if(input_mems_ != nullptr){
|
||||
free(input_mems_);
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < io_num.n_output; i++) {
|
||||
rknn_destroy_mem(ctx, output_mems_[i]);
|
||||
}
|
||||
if(output_mems_ != nullptr){
|
||||
free(output_mems_);
|
||||
}
|
||||
}
|
||||
/***************************************************************
|
||||
@@ -150,56 +166,85 @@ bool RKNPU2Backend::GetModelInputOutputInfos() {
|
||||
}
|
||||
|
||||
// Get detailed input parameters
|
||||
input_attrs =
|
||||
input_attrs_ =
|
||||
(rknn_tensor_attr*)malloc(sizeof(rknn_tensor_attr) * io_num.n_input);
|
||||
memset(input_attrs, 0, io_num.n_input * sizeof(rknn_tensor_attr));
|
||||
memset(input_attrs_, 0, io_num.n_input * sizeof(rknn_tensor_attr));
|
||||
inputs_desc_.resize(io_num.n_input);
|
||||
|
||||
// create input tensor memory
|
||||
// rknn_tensor_mem* input_mems[io_num.n_input];
|
||||
input_mems_ = (rknn_tensor_mem**)malloc(sizeof(rknn_tensor_mem*) * io_num.n_input);
|
||||
|
||||
// get input info and copy to input tensor info
|
||||
for (uint32_t i = 0; i < io_num.n_input; i++) {
|
||||
input_attrs[i].index = i;
|
||||
input_attrs_[i].index = i;
|
||||
// query info
|
||||
ret = rknn_query(ctx, RKNN_QUERY_INPUT_ATTR, &(input_attrs[i]),
|
||||
ret = rknn_query(ctx, RKNN_QUERY_INPUT_ATTR, &(input_attrs_[i]),
|
||||
sizeof(rknn_tensor_attr));
|
||||
if (ret != RKNN_SUCC) {
|
||||
printf("rknn_init error! ret=%d\n", ret);
|
||||
return false;
|
||||
}
|
||||
std::string temp_name = input_attrs[i].name;
|
||||
std::vector<int> temp_shape{};
|
||||
temp_shape.resize(input_attrs[i].n_dims);
|
||||
for (int j = 0; j < input_attrs[i].n_dims; j++) {
|
||||
temp_shape[j] = (int)input_attrs[i].dims[j];
|
||||
if((input_attrs_[i].fmt != RKNN_TENSOR_NHWC) &&
|
||||
(input_attrs_[i].fmt != RKNN_TENSOR_UNDEFINED)){
|
||||
FDERROR << "rknpu2_backend only support input format is NHWC or UNDEFINED" << std::endl;
|
||||
}
|
||||
|
||||
// copy input_attrs_ to input tensor info
|
||||
std::string temp_name = input_attrs_[i].name;
|
||||
std::vector<int> temp_shape{};
|
||||
temp_shape.resize(input_attrs_[i].n_dims);
|
||||
for (int j = 0; j < input_attrs_[i].n_dims; j++) {
|
||||
temp_shape[j] = (int)input_attrs_[i].dims[j];
|
||||
}
|
||||
FDDataType temp_dtype =
|
||||
fastdeploy::RKNPU2Backend::RknnTensorTypeToFDDataType(
|
||||
input_attrs[i].type);
|
||||
input_attrs_[i].type);
|
||||
TensorInfo temp_input_info = {temp_name, temp_shape, temp_dtype};
|
||||
inputs_desc_[i] = temp_input_info;
|
||||
}
|
||||
|
||||
// Get detailed output parameters
|
||||
output_attrs =
|
||||
output_attrs_ =
|
||||
(rknn_tensor_attr*)malloc(sizeof(rknn_tensor_attr) * io_num.n_output);
|
||||
memset(output_attrs, 0, io_num.n_output * sizeof(rknn_tensor_attr));
|
||||
memset(output_attrs_, 0, io_num.n_output * sizeof(rknn_tensor_attr));
|
||||
outputs_desc_.resize(io_num.n_output);
|
||||
|
||||
// Create output tensor memory
|
||||
output_mems_ = (rknn_tensor_mem**)malloc(sizeof(rknn_tensor_mem*) * io_num.n_output);;
|
||||
|
||||
for (uint32_t i = 0; i < io_num.n_output; i++) {
|
||||
output_attrs[i].index = i;
|
||||
output_attrs_[i].index = i;
|
||||
// query info
|
||||
ret = rknn_query(ctx, RKNN_QUERY_OUTPUT_ATTR, &(output_attrs[i]),
|
||||
ret = rknn_query(ctx, RKNN_QUERY_OUTPUT_ATTR, &(output_attrs_[i]),
|
||||
sizeof(rknn_tensor_attr));
|
||||
if (ret != RKNN_SUCC) {
|
||||
FDERROR << "rknn_query fail! ret = " << ret << std::endl;
|
||||
return false;
|
||||
}
|
||||
std::string temp_name = output_attrs[i].name;
|
||||
std::vector<int> temp_shape{};
|
||||
temp_shape.resize(output_attrs[i].n_dims);
|
||||
for (int j = 0; j < output_attrs[i].n_dims; j++) {
|
||||
temp_shape[j] = (int)output_attrs[i].dims[j];
|
||||
|
||||
// If the output dimension is 3, the runtime will automatically change it to 4.
|
||||
// Obviously, this is wrong, and manual correction is required here.
|
||||
int n_dims = output_attrs_[i].n_dims;
|
||||
if((n_dims == 4) && (output_attrs_[i].dims[3] == 1)){
|
||||
n_dims--;
|
||||
FDWARNING << "The output["
|
||||
<< i
|
||||
<< "].shape[3] is 1, remove this dim."
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// copy output_attrs_ to output tensor
|
||||
std::string temp_name = output_attrs_[i].name;
|
||||
std::vector<int> temp_shape{};
|
||||
temp_shape.resize(n_dims);
|
||||
for (int j = 0; j < n_dims; j++) {
|
||||
temp_shape[j] = (int)output_attrs_[i].dims[j];
|
||||
}
|
||||
|
||||
FDDataType temp_dtype =
|
||||
fastdeploy::RKNPU2Backend::RknnTensorTypeToFDDataType(
|
||||
output_attrs[i].type);
|
||||
output_attrs_[i].type);
|
||||
TensorInfo temp_input_info = {temp_name, temp_shape, temp_dtype};
|
||||
outputs_desc_[i] = temp_input_info;
|
||||
}
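
The output-shape correction in this hunk can be restated in a few lines of Python. This is only a sketch of the rule visible in the C++ condition (`n_dims == 4` and `dims[3] == 1`); the function name is hypothetical.

```python
def trim_padded_output_shape(dims):
    """Drop the trailing dim when a 4-D shape ends in 1 (mirrors the C++ check)."""
    dims = list(dims)
    if len(dims) == 4 and dims[3] == 1:
        return dims[:3]
    return dims


print(trim_padded_output_shape([1, 2, 3, 1]))      # [1, 2, 3]
print(trim_padded_output_shape([1, 2, 144, 256]))  # [1, 2, 144, 256]
```

One consequence worth keeping in mind: the rule cannot tell a padded three-dimensional output from a genuine four-dimensional output whose last dimension is 1, so both are trimmed.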
|
||||
@@ -254,75 +299,50 @@ bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,
|
||||
return false;
|
||||
}
|
||||
|
||||
// the input size only can be one
|
||||
if (inputs.size() > 1) {
|
||||
FDERROR << "[RKNPU2Backend] Size of the inputs only support 1."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!this->infer_init){
|
||||
for (uint32_t i = 0; i < io_num.n_input; i++) {
|
||||
// Judge whether the input and output types are the same
|
||||
rknn_tensor_type input_type =
|
||||
fastdeploy::RKNPU2Backend::FDDataTypeToRknnTensorType(inputs[0].dtype);
|
||||
if (input_type != input_attrs[0].type) {
|
||||
fastdeploy::RKNPU2Backend::FDDataTypeToRknnTensorType(inputs[i].dtype);
|
||||
if (input_type != input_attrs_[i].type) {
|
||||
FDWARNING << "The input tensor type != model's inputs type."
|
||||
<< "The input_type need " << get_type_string(input_attrs[0].type)
|
||||
<< ",but inputs[0].type is " << get_type_string(input_type)
|
||||
<< "The input_type need " << get_type_string(input_attrs_[i].type)
|
||||
<< ",but inputs["<< i << "].type is " << get_type_string(input_type)
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
rknn_tensor_format input_layout =
|
||||
RKNN_TENSOR_NHWC; // RK3588 only support NHWC
|
||||
input_attrs[0].type = input_type;
|
||||
input_attrs[0].fmt = input_layout;
|
||||
input_attrs[0].size = inputs[0].Nbytes();
|
||||
input_attrs[0].size_with_stride = inputs[0].Nbytes();
|
||||
input_attrs[0].pass_through = 0;
|
||||
|
||||
// create input tensor memory
|
||||
rknn_tensor_mem* input_mems[1];
|
||||
input_mems[0] = rknn_create_mem(ctx, inputs[0].Nbytes());
|
||||
if (input_mems[0] == nullptr) {
|
||||
FDERROR << "rknn_create_mem input_mems error." << std::endl;
|
||||
// Create input tensor memory
|
||||
input_attrs_[i].type = input_type;
|
||||
input_attrs_[i].size = inputs[0].Nbytes();
|
||||
input_attrs_[i].size_with_stride = inputs[0].Nbytes();
|
||||
input_attrs_[i].pass_through = 0;
|
||||
input_mems_[i] = rknn_create_mem(ctx, inputs[i].Nbytes());
|
||||
if (input_mems_[i] == nullptr) {
|
||||
FDERROR << "rknn_create_mem input_mems_ error." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Copy input data to input tensor memory
|
||||
uint32_t width = input_attrs[0].dims[2];
|
||||
uint32_t stride = input_attrs[0].w_stride;
|
||||
if (width == stride) {
|
||||
if (inputs[0].Data() == nullptr) {
|
||||
FDERROR << "inputs[0].Data is NULL." << std::endl;
|
||||
return false;
|
||||
}
|
||||
memcpy(input_mems[0]->virt_addr, inputs[0].Data(), inputs[0].Nbytes());
|
||||
} else {
|
||||
FDERROR << "[RKNPU2Backend] only support width == stride." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Create output tensor memory
|
||||
rknn_tensor_mem* output_mems[io_num.n_output];
|
||||
for (uint32_t i = 0; i < io_num.n_output; ++i) {
|
||||
// Most post-processing does not support the fp16 format.
|
||||
// The unified output here is float32
|
||||
uint32_t output_size = output_attrs[i].n_elems * sizeof(float);
|
||||
output_mems[i] = rknn_create_mem(ctx, output_size);
|
||||
}
|
||||
|
||||
// Set input tensor memory
|
||||
ret = rknn_set_io_mem(ctx, input_mems[0], &input_attrs[0]);
|
||||
ret = rknn_set_io_mem(ctx, input_mems_[i], &input_attrs_[i]);
|
||||
if (ret != RKNN_SUCC) {
|
||||
FDERROR << "input tensor memory rknn_set_io_mem fail! ret=" << ret
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Set output tensor memory
|
||||
for (uint32_t i = 0; i < io_num.n_output; ++i) {
|
||||
// Most post-processing does not support the fp16 format.
|
||||
// The unified output here is float32
|
||||
uint32_t output_size = output_attrs_[i].n_elems * sizeof(float);
|
||||
output_mems_[i] = rknn_create_mem(ctx, output_size);
|
||||
if (output_mems_[i] == nullptr) {
|
||||
FDERROR << "rknn_create_mem output_mems_ error." << std::endl;
|
||||
return false;
|
||||
}
|
||||
// default output type is depend on model, this requires float32 to compute top5
|
||||
output_attrs[i].type = RKNN_TENSOR_FLOAT32;
|
||||
ret = rknn_set_io_mem(ctx, output_mems[i], &output_attrs[i]);
|
||||
output_attrs_[i].type = RKNN_TENSOR_FLOAT32;
|
||||
ret = rknn_set_io_mem(ctx, output_mems_[i], &output_attrs_[i]);
|
||||
// set output memory and attribute
|
||||
if (ret != RKNN_SUCC) {
|
||||
FDERROR << "output tensor memory rknn_set_io_mem fail! ret=" << ret
|
||||
@@ -331,13 +351,32 @@ bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,
|
||||
}
|
||||
}
|
||||
|
||||
this->infer_init = true;
|
||||
}
|
||||
|
||||
// Copy input data to input tensor memory
|
||||
for (uint32_t i = 0; i < io_num.n_input; i++) {
|
||||
uint32_t width = input_attrs_[i].dims[2];
|
||||
uint32_t stride = input_attrs_[i].w_stride;
|
||||
if (width == stride) {
|
||||
if (inputs[i].Data() == nullptr) {
|
||||
FDERROR << "inputs[0].Data is NULL." << std::endl;
|
||||
return false;
|
||||
}
|
||||
memcpy(input_mems_[i]->virt_addr, inputs[i].Data(), inputs[i].Nbytes());
|
||||
} else {
|
||||
FDERROR << "[RKNPU2Backend] only support width == stride." << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// run rknn
|
||||
ret = rknn_run(ctx, nullptr);
|
||||
if (ret != RKNN_SUCC) {
|
||||
FDERROR << "rknn run error! ret=" << ret << std::endl;
|
||||
return false;
|
||||
}
|
||||
rknn_destroy_mem(ctx, input_mems[0]);
|
||||
|
||||
// get result
|
||||
outputs->resize(outputs_desc_.size());
|
||||
@@ -349,9 +388,8 @@ bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,
|
||||
}
|
||||
(*outputs)[i].Resize(temp_shape, outputs_desc_[i].dtype,
|
||||
outputs_desc_[i].name);
|
||||
memcpy((*outputs)[i].MutableData(), (float*)output_mems[i]->virt_addr,
|
||||
memcpy((*outputs)[i].MutableData(), (float*)output_mems_[i]->virt_addr,
|
||||
(*outputs)[i].Nbytes());
|
||||
rknn_destroy_mem(ctx, output_mems[i]);
|
||||
}
|
||||
|
||||
return true;
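
The restructured `Infer` allocates input and output memory once, guarded by `infer_init`, and only copies data on later calls. The Python sketch below illustrates that allocate-once, copy-every-call pattern under simplified assumptions: plain `bytearray` buffers stand in for `rknn_tensor_mem`, and the class and attribute names are hypothetical.

```python
class ReusableBuffers:
    """Allocate per-input buffers on the first call and reuse them afterwards."""

    def __init__(self):
        self.initialized = False
        self.buffers = []
        self.allocations = 0  # counts buffer allocations, for illustration only

    def infer(self, inputs):
        if not self.initialized:
            # One-time setup: one buffer per input, sized from that input.
            self.buffers = [bytearray(len(data)) for data in inputs]
            self.allocations += len(inputs)
            self.initialized = True
        # Every call: copy fresh input bytes into the existing buffers.
        for buf, data in zip(self.buffers, inputs):
            if len(data) != len(buf):
                raise ValueError("input size changed after initialization")
            buf[:] = data
        return [bytes(buf) for buf in self.buffers]


backend = ReusableBuffers()
backend.infer([b"\x01\x02", b"\x03"])
result = backend.infer([b"\x04\x05", b"\x06"])
print(backend.allocations)  # 2
```

Because the buffers are sized on the first call, the sketch rejects inputs whose size changes later; a fixed input shape, as set during model conversion, keeps that check from firing.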
|
||||
|
@@ -86,8 +86,13 @@ class RKNPU2Backend : public BaseBackend {
|
||||
std::vector<TensorInfo> inputs_desc_;
|
||||
std::vector<TensorInfo> outputs_desc_;
|
||||
|
||||
rknn_tensor_attr* input_attrs = nullptr;
|
||||
rknn_tensor_attr* output_attrs = nullptr;
|
||||
rknn_tensor_attr* input_attrs_ = nullptr;
|
||||
rknn_tensor_attr* output_attrs_ = nullptr;
|
||||
|
||||
rknn_tensor_mem** input_mems_;
|
||||
rknn_tensor_mem** output_mems_;
|
||||
|
||||
bool infer_init = false;
|
||||
|
||||
RKNPU2BackendOption option_;
|
||||
|
||||
|
@@ -0,0 +1,7 @@
model_path: ./Portrait_PP_HumanSegV2_Lite_256x144_infer/Portrait_PP_HumanSegV2_Lite_256x144_infer.onnx
output_folder: ./Portrait_PP_HumanSegV2_Lite_256x144_infer
target_platform: RK3568
normalize:
  mean: [[0.5,0.5,0.5]]
  std: [[0.5,0.5,0.5]]
outputs: None
@@ -0,0 +1,7 @@
model_path: ./Portrait_PP_HumanSegV2_Lite_256x144_infer/Portrait_PP_HumanSegV2_Lite_256x144_infer.onnx
output_folder: ./Portrait_PP_HumanSegV2_Lite_256x144_infer
target_platform: RK3588
normalize:
  mean: [[0.5,0.5,0.5]]
  std: [[0.5,0.5,0.5]]
outputs: None
@@ -1,7 +0,0 @@
model_path: ./portrait_pp_humansegv2_lite_256x144_pretrained.onnx
output_folder: ./
target_platform: RK3588
normalize:
  mean: [0.5,0.5,0.5]
  std: [0.5,0.5,0.5]
outputs: None
@@ -2,10 +2,6 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
@@ -38,9 +34,15 @@ if __name__ == "__main__":
|
||||
model = RKNN(config.verbose)
|
||||
|
||||
# Config
|
||||
mean_values = [[255 * mean for mean in yaml_config["normalize"]["mean"]]]
|
||||
std_values = [[255 * std for std in yaml_config["normalize"]["std"]]]
|
||||
model.config(mean_values=mean_values,
|
||||
if yaml_config["normalize"] == "None":
|
||||
model.config(target_platform=yaml_config["target_platform"])
|
||||
else:
|
||||
mean_values = [[256 * mean for mean in mean_ls]
|
||||
for mean_ls in yaml_config["normalize"]["mean"]]
|
||||
std_values = [[256 * std for std in std_ls]
|
||||
for std_ls in yaml_config["normalize"]["std"]]
|
||||
model.config(
|
||||
mean_values=mean_values,
|
||||
std_values=std_values,
|
||||
target_platform=yaml_config["target_platform"])
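
The change in this hunk turns `mean` and `std` into nested lists, which appears to be how the script supports one normalization group per model input, and it skips normalization entirely when the config value is the string `"None"`. The sketch below reproduces that logic without the RKNN toolkit so it can be run anywhere. The function name is hypothetical, and the `scale` parameter is an assumption added for illustration: the lines in this hunk show both a factor of 255 and a factor of 256.

```python
def scale_normalize_values(normalize, scale=256):
    """Scale nested mean/std lists from the 0-1 range to pixel range.

    Returns None when normalization is disabled (the string "None").
    """
    if normalize == "None":
        return None
    mean_values = [[scale * m for m in group] for group in normalize["mean"]]
    std_values = [[scale * s for s in group] for group in normalize["std"]]
    return mean_values, std_values


cfg = {"mean": [[0.5, 0.5, 0.5]], "std": [[0.5, 0.5, 0.5]]}
print(scale_normalize_values(cfg))
# ([[128.0, 128.0, 128.0]], [[128.0, 128.0, 128.0]])
```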
|
||||
|
||||
@@ -50,8 +52,8 @@ if __name__ == "__main__":
|
||||
if yaml_config["outputs"] == "None":
|
||||
ret = model.load_onnx(model=yaml_config["model_path"])
|
||||
else:
|
||||
ret = model.load_onnx(model=yaml_config["model_path"],
|
||||
outputs=yaml_config["outputs"])
|
||||
ret = model.load_onnx(
|
||||
model=yaml_config["model_path"], outputs=yaml_config["outputs"])
|
||||
assert ret == 0, "Load model failed!"
|
||||
|
||||
# Build model
|
||||
|