mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 09:07:10 +08:00
* [RKNPU2]support rknpu2 ClasModel #957 * [RKNPU2]support rknpu2 ClasModel #957 * [RKNPU2]support rknpu2 add Resnet50_vd example #957 * [RKNPU2]support rknpu2 add Resnet50_vd example #957 * [RKNPU2]support rknpu2, improve doc #957 * [RKNPU2]support rknpu2, improve doc #957 * [RKNPU2]support rknpu2, improve doc #957 * [RKNPU2]support rknpu2, improve doc #957 * [RKNPU2]support rknpu2, improve doc #957 * [RKNPU2]support rknpu2, improve doc #957 * [RKNPU2]support rknpu2, improve doc #957
This commit is contained in:
@@ -23,6 +23,7 @@ ONNX模型不能直接调用RK芯片中的NPU进行运算,需要把ONNX模型
|
|||||||
| Segmentation | PP-HumanSegV2Lite | portrait | 133/43 |
|
| Segmentation | PP-HumanSegV2Lite | portrait | 133/43 |
|
||||||
| Segmentation | PP-HumanSegV2Lite | human | 133/43 |
|
| Segmentation | PP-HumanSegV2Lite | human | 133/43 |
|
||||||
| Face Detection | SCRFD | SCRFD-2.5G-kps-640 | 108/42 |
|
| Face Detection | SCRFD | SCRFD-2.5G-kps-640 | 108/42 |
|
||||||
|
| Classification | ResNet | ResNet50_vd | -/92 |
|
||||||
|
|
||||||
## RKNPU2 Backend推理使用教程
|
## RKNPU2 Backend推理使用教程
|
||||||
|
|
||||||
|
57
examples/vision/classification/paddleclas/rknpu2/README.md
Normal file
57
examples/vision/classification/paddleclas/rknpu2/README.md
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
# PaddleClas 模型RKNPU2部署
|
||||||
|
|
||||||
|
## 转换模型
|
||||||
|
下面以 ResNet50_vd为例子,教大家如何转换分类模型到RKNN模型。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 安装 paddle2onnx
|
||||||
|
pip install paddle2onnx
|
||||||
|
|
||||||
|
# 下载ResNet50_vd模型文件和测试图片
|
||||||
|
wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz
|
||||||
|
tar -xvf ResNet50_vd_infer.tgz
|
||||||
|
|
||||||
|
# 静态图转ONNX模型,注意,这里的save_file请和压缩包名对齐
|
||||||
|
paddle2onnx --model_dir ResNet50_vd_infer \
|
||||||
|
--model_filename inference.pdmodel \
|
||||||
|
--params_filename inference.pdiparams \
|
||||||
|
--save_file ResNet50_vd_infer/ResNet50_vd_infer.onnx \
|
||||||
|
--enable_dev_version True \
|
||||||
|
--opset_version 12 \
|
||||||
|
--enable_onnx_checker True
|
||||||
|
|
||||||
|
# 固定shape,注意这里的inputs得对应netron.app展示的 inputs 的 name,有可能是image 或者 x
|
||||||
|
python -m paddle2onnx.optimize --input_model ResNet50_vd_infer/ResNet50_vd_infer.onnx \
|
||||||
|
--output_model ResNet50_vd_infer/ResNet50_vd_infer.onnx \
|
||||||
|
--input_shape_dict "{'inputs':[1,3,224,224]}"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 编写模型导出配置文件
|
||||||
|
以转化RK3588的RKNN模型为例子,我们需要编辑tools/rknpu2/config/ResNet50_vd_infer_rknn.yaml,来转换ONNX模型到RKNN模型。
|
||||||
|
|
||||||
|
默认的 mean=0, std=1是在内存做normalize,如果你需要在NPU上执行normalize操作,请根据你的模型配置normalize参数,例如:
|
||||||
|
```yaml
|
||||||
|
model_path: ./ResNet50_vd_infer.onnx
|
||||||
|
output_folder: ./
|
||||||
|
target_platform: RK3588
|
||||||
|
normalize:
|
||||||
|
mean: [[0.485,0.456,0.406]]
|
||||||
|
std: [[0.229,0.224,0.225]]
|
||||||
|
outputs: []
|
||||||
|
outputs_nodes: []
|
||||||
|
do_quantization: False
|
||||||
|
dataset:
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
# ONNX模型转RKNN模型
|
||||||
|
```shell
|
||||||
|
python tools/rknpu2/export.py \
|
||||||
|
--config_path tools/rknpu2/config/ResNet50_vd_infer_rknn.yaml \
|
||||||
|
--target_platform rk3588
|
||||||
|
```
|
||||||
|
|
||||||
|
## 其他链接
|
||||||
|
- [Cpp部署](./cpp)
|
||||||
|
- [Python部署](./python)
|
||||||
|
- [视觉模型预测结果](../../../../../docs/api/vision_results/)
|
@@ -0,0 +1,37 @@
|
|||||||
|
CMAKE_MINIMUM_REQUIRED(VERSION 3.10)
|
||||||
|
project(rknpu_test)
|
||||||
|
|
||||||
|
set(CMAKE_CXX_STANDARD 14)
|
||||||
|
|
||||||
|
# 指定下载解压后的fastdeploy库路径
|
||||||
|
set(FASTDEPLOY_INSTALL_DIR "thirdpartys/fastdeploy-0.0.3")
|
||||||
|
|
||||||
|
include(${FASTDEPLOY_INSTALL_DIR}/FastDeployConfig.cmake)
|
||||||
|
include_directories(${FastDeploy_INCLUDE_DIRS})
|
||||||
|
add_executable(rknpu_test infer.cc)
|
||||||
|
target_link_libraries(rknpu_test
|
||||||
|
${FastDeploy_LIBS}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
set(CMAKE_INSTALL_PREFIX ${CMAKE_SOURCE_DIR}/build/install)
|
||||||
|
|
||||||
|
install(TARGETS rknpu_test DESTINATION ./)
|
||||||
|
|
||||||
|
install(DIRECTORY ppclas_model_dir DESTINATION ./)
|
||||||
|
install(DIRECTORY images DESTINATION ./)
|
||||||
|
|
||||||
|
file(GLOB FASTDEPLOY_LIBS ${FASTDEPLOY_INSTALL_DIR}/lib/*)
|
||||||
|
message("${FASTDEPLOY_LIBS}")
|
||||||
|
install(PROGRAMS ${FASTDEPLOY_LIBS} DESTINATION lib)
|
||||||
|
|
||||||
|
file(GLOB ONNXRUNTIME_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/onnxruntime/lib/*)
|
||||||
|
install(PROGRAMS ${ONNXRUNTIME_LIBS} DESTINATION lib)
|
||||||
|
|
||||||
|
install(DIRECTORY ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/opencv/lib DESTINATION ./)
|
||||||
|
|
||||||
|
file(GLOB PADDLETOONNX_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/paddle2onnx/lib/*)
|
||||||
|
install(PROGRAMS ${PADDLETOONNX_LIBS} DESTINATION lib)
|
||||||
|
|
||||||
|
file(GLOB RKNPU2_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/rknpu2_runtime/RK3588/lib/*)
|
||||||
|
install(PROGRAMS ${RKNPU2_LIBS} DESTINATION lib)
|
@@ -0,0 +1,78 @@
|
|||||||
|
# PaddleClas C++部署示例
|
||||||
|
|
||||||
|
本目录下用于展示 ResNet50_vd 模型在RKNPU2上的部署,以下的部署过程以 ResNet50_vd 为例子。
|
||||||
|
|
||||||
|
在部署前,需确认以下两个步骤:
|
||||||
|
|
||||||
|
1. 软硬件环境满足要求
|
||||||
|
2. 根据开发环境,下载预编译部署库或者从头编译FastDeploy仓库
|
||||||
|
|
||||||
|
以上步骤请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)实现
|
||||||
|
|
||||||
|
## 生成基本目录文件
|
||||||
|
|
||||||
|
该例程由以下几个部分组成
|
||||||
|
```text
|
||||||
|
.
|
||||||
|
├── CMakeLists.txt
|
||||||
|
├── build # 编译文件夹
|
||||||
|
├── images # 存放图片的文件夹
|
||||||
|
├── infer.cc
|
||||||
|
├── ppclas_model_dir # 存放模型文件的文件夹
|
||||||
|
└── thirdpartys # 存放sdk的文件夹
|
||||||
|
```
|
||||||
|
|
||||||
|
首先需要先生成目录结构
|
||||||
|
```bash
|
||||||
|
mkdir build
|
||||||
|
mkdir images
|
||||||
|
mkdir ppclas_model_dir
|
||||||
|
mkdir thirdpartys
|
||||||
|
```
|
||||||
|
|
||||||
|
## 编译
|
||||||
|
|
||||||
|
### 编译并拷贝SDK到thirdpartys文件夹
|
||||||
|
|
||||||
|
请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)仓库编译SDK,编译完成后,将在build目录下生成
|
||||||
|
fastdeploy-0.0.3目录,请移动它至thirdpartys目录下.
|
||||||
|
|
||||||
|
### 拷贝模型文件,以及配置文件至model文件夹
|
||||||
|
在Paddle动态图模型 -> Paddle静态图模型 -> ONNX模型的过程中,将生成ONNX文件以及对应的yaml配置文件,请将配置文件存放到model文件夹内。
|
||||||
|
转换为RKNN后的模型文件也需要拷贝至model,转换方案: ([ResNet50_vd RKNN模型](../README.md))。
|
||||||
|
|
||||||
|
### 准备测试图片至image文件夹
|
||||||
|
```bash
|
||||||
|
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
|
||||||
|
```
|
||||||
|
|
||||||
|
### 编译example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd build
|
||||||
|
cmake ..
|
||||||
|
make -j8
|
||||||
|
make install
|
||||||
|
```
|
||||||
|
|
||||||
|
## 运行例程
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd ./build/install
|
||||||
|
./rknpu_test ./ppclas_model_dir ./images/ILSVRC2012_val_00000010.jpeg
|
||||||
|
```
|
||||||
|
|
||||||
|
## 运行结果展示
|
||||||
|
ClassifyResult(
|
||||||
|
label_ids: 153,
|
||||||
|
scores: 0.684570,
|
||||||
|
)
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
RKNPU上对模型的输入要求是使用NHWC格式,且图片归一化操作会在转RKNN模型时,内嵌到模型中,因此我们在使用FastDeploy部署时,
|
||||||
|
DisablePermute(C++)或`disable_permute(Python),在预处理阶段禁用数据格式的转换。
|
||||||
|
|
||||||
|
## 其它文档
|
||||||
|
- [ResNet50_vd Python 部署](../python)
|
||||||
|
- [模型预测结果说明](../../../../../../docs/api/vision_results/)
|
||||||
|
- [转换ResNet50_vd RKNN模型文档](../README.md)
|
58
examples/vision/classification/paddleclas/rknpu2/cpp/infer.cc
Executable file
58
examples/vision/classification/paddleclas/rknpu2/cpp/infer.cc
Executable file
@@ -0,0 +1,58 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/vision.h"
|
||||||
|
|
||||||
|
void RKNPU2Infer(const std::string& model_dir, const std::string& image_file) {
|
||||||
|
auto model_file = model_dir + "/ResNet50_vd_infer_rk3588.rknn";
|
||||||
|
auto params_file = "";
|
||||||
|
auto config_file = model_dir + "/inference_cls.yaml";
|
||||||
|
|
||||||
|
auto option = fastdeploy::RuntimeOption();
|
||||||
|
option.UseRKNPU2();
|
||||||
|
|
||||||
|
auto format = fastdeploy::ModelFormat::RKNN;
|
||||||
|
|
||||||
|
auto model = fastdeploy::vision::classification::PaddleClasModel(
|
||||||
|
model_file, params_file, config_file,option,format);
|
||||||
|
if (!model.Initialized()) {
|
||||||
|
std::cerr << "Failed to initialize." << std::endl;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
model.GetPreprocessor().DisablePermute();
|
||||||
|
fastdeploy::TimeCounter tc;
|
||||||
|
tc.Start();
|
||||||
|
auto im = cv::imread(image_file);
|
||||||
|
fastdeploy::vision::ClassifyResult res;
|
||||||
|
if (!model.Predict(im, &res)) {
|
||||||
|
std::cerr << "Failed to predict." << std::endl;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// print res
|
||||||
|
std::cout << res.Str() << std::endl;
|
||||||
|
tc.End();
|
||||||
|
tc.PrintInfo("PPClas in RKNPU2");
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
if (argc < 3) {
|
||||||
|
std::cout
|
||||||
|
<< "Usage: rknpu_test path/to/model_dir path/to/image run_option, "
|
||||||
|
"e.g ./rknpu_test ./ppclas_model_dir ./images/ILSVRC2012_val_00000010.jpeg"
|
||||||
|
<< std::endl;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
RKNPU2Infer(argv[1], argv[2]);
|
||||||
|
return 0;
|
||||||
|
}
|
@@ -0,0 +1,35 @@
|
|||||||
|
# PaddleClas Python部署示例
|
||||||
|
|
||||||
|
在部署前,需确认以下两个步骤
|
||||||
|
|
||||||
|
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../../docs/cn/build_and_install/rknpu2.md)
|
||||||
|
|
||||||
|
本目录下提供`infer.py`快速完成 ResNet50_vd 在RKNPU上部署的示例。执行如下脚本即可完成
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 下载部署示例代码
|
||||||
|
git clone https://github.com/PaddlePaddle/FastDeploy.git
|
||||||
|
cd FastDeploy/examples/vision/classification/paddleclas/rknpu2/python
|
||||||
|
|
||||||
|
# 下载图片
|
||||||
|
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
|
||||||
|
|
||||||
|
# 推理
|
||||||
|
python3 infer.py --model_file ./ResNet50_vd_infer/ResNet50_vd_infer_rk3588.rknn --config_file ResNet50_vd_infer/inference_cls.yaml --image ILSVRC2012_val_00000010.jpeg
|
||||||
|
|
||||||
|
# 运行完成后返回结果如下所示
|
||||||
|
ClassifyResult(
|
||||||
|
label_ids: 153,
|
||||||
|
scores: 0.684570,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
RKNPU上对模型的输入要求是使用NHWC格式,且图片归一化操作会在转RKNN模型时,内嵌到模型中,因此我们在使用FastDeploy部署时,
|
||||||
|
DisablePermute(C++)或`disable_permute(Python),在预处理阶段禁用数据格式的转换。
|
||||||
|
|
||||||
|
## 其它文档
|
||||||
|
- [ResNet50_vd C++部署](../cpp)
|
||||||
|
- [模型预测结果说明](../../../../../../docs/api/vision_results/)
|
||||||
|
- [转换ResNet50_vd RKNN模型文档](../README.md)
|
@@ -0,0 +1,50 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import fastdeploy as fd
|
||||||
|
import cv2
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arguments():
|
||||||
|
import argparse
|
||||||
|
import ast
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_file", required=True, help="Path of rknn model.")
|
||||||
|
parser.add_argument("--config_file", required=True, help="Path of config.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--image", type=str, required=True, help="Path of test image file.")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_arguments()
|
||||||
|
|
||||||
|
model_file = args.model_file
|
||||||
|
params_file = ""
|
||||||
|
config_file = args.config_file
|
||||||
|
# 配置runtime,加载模型
|
||||||
|
runtime_option = fd.RuntimeOption()
|
||||||
|
runtime_option.use_rknpu2()
|
||||||
|
model = fd.vision.classification.ResNet50vd(
|
||||||
|
model_file,
|
||||||
|
params_file,
|
||||||
|
config_file,
|
||||||
|
runtime_option=runtime_option,
|
||||||
|
model_format=fd.ModelFormat.RKNN)
|
||||||
|
# 禁用通道转换
|
||||||
|
model.preprocessor.disable_permute()
|
||||||
|
im = cv2.imread(args.image)
|
||||||
|
result = model.predict(im, topk=1)
|
||||||
|
print(result)
|
@@ -32,9 +32,10 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file,
|
|||||||
valid_ascend_backends = {Backend::LITE};
|
valid_ascend_backends = {Backend::LITE};
|
||||||
valid_kunlunxin_backends = {Backend::LITE};
|
valid_kunlunxin_backends = {Backend::LITE};
|
||||||
valid_ipu_backends = {Backend::PDINFER};
|
valid_ipu_backends = {Backend::PDINFER};
|
||||||
} else if (model_format == ModelFormat::ONNX) {
|
} else {
|
||||||
valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
|
valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
|
||||||
valid_gpu_backends = {Backend::ORT, Backend::TRT};
|
valid_gpu_backends = {Backend::ORT, Backend::TRT};
|
||||||
|
valid_rknpu_backends = {Backend::RKNPU2};
|
||||||
}
|
}
|
||||||
|
|
||||||
runtime_option = custom_option;
|
runtime_option = custom_option;
|
||||||
|
@@ -36,6 +36,12 @@ void BindPaddleClas(pybind11::module& m) {
|
|||||||
})
|
})
|
||||||
.def("use_gpu", [](vision::classification::PaddleClasPreprocessor& self, int gpu_id = -1) {
|
.def("use_gpu", [](vision::classification::PaddleClasPreprocessor& self, int gpu_id = -1) {
|
||||||
self.UseGpu(gpu_id);
|
self.UseGpu(gpu_id);
|
||||||
|
})
|
||||||
|
.def("disable_normalize", [](vision::classification::PaddleClasPreprocessor& self) {
|
||||||
|
self.DisableNormalize();
|
||||||
|
})
|
||||||
|
.def("disable_permute", [](vision::classification::PaddleClasPreprocessor& self) {
|
||||||
|
self.DisablePermute();
|
||||||
});
|
});
|
||||||
|
|
||||||
pybind11::class_<vision::classification::PaddleClasPostprocessor>(
|
pybind11::class_<vision::classification::PaddleClasPostprocessor>(
|
||||||
|
@@ -24,19 +24,19 @@ namespace vision {
|
|||||||
namespace classification {
|
namespace classification {
|
||||||
|
|
||||||
PaddleClasPreprocessor::PaddleClasPreprocessor(const std::string& config_file) {
|
PaddleClasPreprocessor::PaddleClasPreprocessor(const std::string& config_file) {
|
||||||
FDASSERT(BuildPreprocessPipelineFromConfig(config_file),
|
this->config_file_ = config_file;
|
||||||
|
FDASSERT(BuildPreprocessPipelineFromConfig(),
|
||||||
"Failed to create PaddleClasPreprocessor.");
|
"Failed to create PaddleClasPreprocessor.");
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig(
|
bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig() {
|
||||||
const std::string& config_file) {
|
|
||||||
processors_.clear();
|
processors_.clear();
|
||||||
YAML::Node cfg;
|
YAML::Node cfg;
|
||||||
try {
|
try {
|
||||||
cfg = YAML::LoadFile(config_file);
|
cfg = YAML::LoadFile(config_file_);
|
||||||
} catch (YAML::BadFile& e) {
|
} catch (YAML::BadFile& e) {
|
||||||
FDERROR << "Failed to load yaml file " << config_file
|
FDERROR << "Failed to load yaml file " << config_file_
|
||||||
<< ", maybe you should check this file." << std::endl;
|
<< ", maybe you should check this file." << std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -57,6 +57,7 @@ bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig(
|
|||||||
int height = op.begin()->second["size"].as<int>();
|
int height = op.begin()->second["size"].as<int>();
|
||||||
processors_.push_back(std::make_shared<CenterCrop>(width, height));
|
processors_.push_back(std::make_shared<CenterCrop>(width, height));
|
||||||
} else if (op_name == "NormalizeImage") {
|
} else if (op_name == "NormalizeImage") {
|
||||||
|
if (!disable_normalize) {
|
||||||
auto mean = op.begin()->second["mean"].as<std::vector<float>>();
|
auto mean = op.begin()->second["mean"].as<std::vector<float>>();
|
||||||
auto std = op.begin()->second["std"].as<std::vector<float>>();
|
auto std = op.begin()->second["std"].as<std::vector<float>>();
|
||||||
auto scale = op.begin()->second["scale"].as<float>();
|
auto scale = op.begin()->second["scale"].as<float>();
|
||||||
@@ -64,8 +65,11 @@ bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig(
|
|||||||
"Only support scale in Normalize be 0.00392157, means the pixel "
|
"Only support scale in Normalize be 0.00392157, means the pixel "
|
||||||
"is in range of [0, 255].");
|
"is in range of [0, 255].");
|
||||||
processors_.push_back(std::make_shared<Normalize>(mean, std));
|
processors_.push_back(std::make_shared<Normalize>(mean, std));
|
||||||
|
}
|
||||||
} else if (op_name == "ToCHWImage") {
|
} else if (op_name == "ToCHWImage") {
|
||||||
|
if (!disable_permute) {
|
||||||
processors_.push_back(std::make_shared<HWC2CHW>());
|
processors_.push_back(std::make_shared<HWC2CHW>());
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
|
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
@@ -78,6 +82,21 @@ bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PaddleClasPreprocessor::DisableNormalize() {
|
||||||
|
this->disable_normalize = true;
|
||||||
|
// the DisableNormalize function will be invalid if the configuration file is loaded during preprocessing
|
||||||
|
if (!BuildPreprocessPipelineFromConfig()) {
|
||||||
|
FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void PaddleClasPreprocessor::DisablePermute() {
|
||||||
|
this->disable_permute = true;
|
||||||
|
// the DisablePermute function will be invalid if the configuration file is loaded during preprocessing
|
||||||
|
if (!BuildPreprocessPipelineFromConfig()) {
|
||||||
|
FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void PaddleClasPreprocessor::UseGpu(int gpu_id) {
|
void PaddleClasPreprocessor::UseGpu(int gpu_id) {
|
||||||
#ifdef WITH_GPU
|
#ifdef WITH_GPU
|
||||||
use_cuda_ = true;
|
use_cuda_ = true;
|
||||||
|
@@ -46,13 +46,24 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor {
|
|||||||
|
|
||||||
bool WithGpu() { return use_cuda_; }
|
bool WithGpu() { return use_cuda_; }
|
||||||
|
|
||||||
|
/// This function will disable normalize in preprocessing step.
|
||||||
|
void DisableNormalize();
|
||||||
|
/// This function will disable hwc2chw in preprocessing step.
|
||||||
|
void DisablePermute();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool BuildPreprocessPipelineFromConfig(const std::string& config_file);
|
bool BuildPreprocessPipelineFromConfig();
|
||||||
std::vector<std::shared_ptr<Processor>> processors_;
|
std::vector<std::shared_ptr<Processor>> processors_;
|
||||||
bool initialized_ = false;
|
bool initialized_ = false;
|
||||||
bool use_cuda_ = false;
|
bool use_cuda_ = false;
|
||||||
// GPU device id
|
// GPU device id
|
||||||
int device_id_ = -1;
|
int device_id_ = -1;
|
||||||
|
// for recording the switch of hwc2chw
|
||||||
|
bool disable_permute = false;
|
||||||
|
// for recording the switch of normalize
|
||||||
|
bool disable_normalize = false;
|
||||||
|
// read config file
|
||||||
|
std::string config_file_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace classification
|
} // namespace classification
|
||||||
|
@@ -42,6 +42,18 @@ class PaddleClasPreprocessor:
|
|||||||
"""
|
"""
|
||||||
return self._preprocessor.use_gpu(gpu_id)
|
return self._preprocessor.use_gpu(gpu_id)
|
||||||
|
|
||||||
|
def disable_normalize(self):
|
||||||
|
"""
|
||||||
|
This function will disable normalize in preprocessing step.
|
||||||
|
"""
|
||||||
|
self._preprocessor.disable_normalize()
|
||||||
|
|
||||||
|
def disable_permute(self):
|
||||||
|
"""
|
||||||
|
This function will disable hwc2chw in preprocessing step.
|
||||||
|
"""
|
||||||
|
self._preprocessor.disable_permute()
|
||||||
|
|
||||||
|
|
||||||
class PaddleClasPostprocessor:
|
class PaddleClasPostprocessor:
|
||||||
def __init__(self, topk=1):
|
def __init__(self, topk=1):
|
||||||
@@ -78,8 +90,6 @@ class PaddleClasModel(FastDeployModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
super(PaddleClasModel, self).__init__(runtime_option)
|
super(PaddleClasModel, self).__init__(runtime_option)
|
||||||
|
|
||||||
assert model_format == ModelFormat.PADDLE, "PaddleClasModel only support model format of ModelFormat.PADDLE now."
|
|
||||||
self._model = C.vision.classification.PaddleClasModel(
|
self._model = C.vision.classification.PaddleClasModel(
|
||||||
model_file, params_file, config_file, self._runtime_option,
|
model_file, params_file, config_file, self._runtime_option,
|
||||||
model_format)
|
model_format)
|
||||||
|
10
tools/rknpu2/config/ResNet50_vd_infer_rknn.yaml
Normal file
10
tools/rknpu2/config/ResNet50_vd_infer_rknn.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
model_path: ./ResNet50_vd_infer.onnx
|
||||||
|
output_folder: ./
|
||||||
|
target_platform: RK3588
|
||||||
|
normalize:
|
||||||
|
mean: [[0, 0, 0]]
|
||||||
|
std: [[1, 1, 1]]
|
||||||
|
outputs: []
|
||||||
|
outputs_nodes: []
|
||||||
|
do_quantization: False
|
||||||
|
dataset:
|
Reference in New Issue
Block a user