diff --git a/docs/cn/faq/rknpu2/rknpu2.md b/docs/cn/faq/rknpu2/rknpu2.md index bb0650860..d9e811f8d 100644 --- a/docs/cn/faq/rknpu2/rknpu2.md +++ b/docs/cn/faq/rknpu2/rknpu2.md @@ -23,6 +23,7 @@ ONNX模型不能直接调用RK芯片中的NPU进行运算,需要把ONNX模型 | Segmentation | PP-HumanSegV2Lite | portrait | 133/43 | | Segmentation | PP-HumanSegV2Lite | human | 133/43 | | Face Detection | SCRFD | SCRFD-2.5G-kps-640 | 108/42 | +| Classification | ResNet | ResNet50_vd | -/92 | ## RKNPU2 Backend推理使用教程 diff --git a/examples/vision/classification/paddleclas/rknpu2/README.md b/examples/vision/classification/paddleclas/rknpu2/README.md new file mode 100644 index 000000000..bd4305dc0 --- /dev/null +++ b/examples/vision/classification/paddleclas/rknpu2/README.md @@ -0,0 +1,57 @@ +# PaddleClas 模型RKNPU2部署 + +## 转换模型 +下面以 ResNet50_vd为例子,教大家如何转换分类模型到RKNN模型。 + +```bash +# 安装 paddle2onnx +pip install paddle2onnx + +# 下载ResNet50_vd模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz +tar -xvf ResNet50_vd_infer.tgz + +# 静态图转ONNX模型,注意,这里的save_file请和压缩包名对齐 +paddle2onnx --model_dir ResNet50_vd_infer \ + --model_filename inference.pdmodel \ + --params_filename inference.pdiparams \ + --save_file ResNet50_vd_infer/ResNet50_vd_infer.onnx \ + --enable_dev_version True \ + --opset_version 12 \ + --enable_onnx_checker True + +# 固定shape,注意这里的inputs得对应netron.app展示的 inputs 的 name,有可能是image 或者 x +python -m paddle2onnx.optimize --input_model ResNet50_vd_infer/ResNet50_vd_infer.onnx \ + --output_model ResNet50_vd_infer/ResNet50_vd_infer.onnx \ + --input_shape_dict "{'inputs':[1,3,224,224]}" +``` + + ### 编写模型导出配置文件 +以转化RK3588的RKNN模型为例子,我们需要编辑tools/rknpu2/config/ResNet50_vd_infer_rknn.yaml,来转换ONNX模型到RKNN模型。 + +默认的 mean=0, std=1是在内存做normalize,如果你需要在NPU上执行normalize操作,请根据你的模型配置normalize参数,例如: +```yaml +model_path: ./ResNet50_vd_infer.onnx +output_folder: ./ +target_platform: RK3588 +normalize: + mean: [[0.485,0.456,0.406]] + std: [[0.229,0.224,0.225]] +outputs: [] +outputs_nodes: [] +do_quantization: False +dataset: +``` + + +# ONNX模型转RKNN模型 +```shell +python tools/rknpu2/export.py \ + --config_path tools/rknpu2/config/ResNet50_vd_infer_rknn.yaml \ + --target_platform rk3588 +``` + +## 其他链接 +- [Cpp部署](./cpp) +- [Python部署](./python) +- [视觉模型预测结果](../../../../../docs/api/vision_results/) \ No newline at end of file diff --git a/examples/vision/classification/paddleclas/rknpu2/cpp/CMakeLists.txt b/examples/vision/classification/paddleclas/rknpu2/cpp/CMakeLists.txt new file mode 100644 index 000000000..6cfd9bf05 --- /dev/null +++ b/examples/vision/classification/paddleclas/rknpu2/cpp/CMakeLists.txt @@ -0,0 +1,37 @@ +CMAKE_MINIMUM_REQUIRED(VERSION 3.10) +project(rknpu_test) + +set(CMAKE_CXX_STANDARD 14) + +# 指定下载解压后的fastdeploy库路径 +set(FASTDEPLOY_INSTALL_DIR "thirdpartys/fastdeploy-0.0.3") + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeployConfig.cmake) +include_directories(${FastDeploy_INCLUDE_DIRS}) +add_executable(rknpu_test infer.cc) +target_link_libraries(rknpu_test + ${FastDeploy_LIBS} + ) + + +set(CMAKE_INSTALL_PREFIX ${CMAKE_SOURCE_DIR}/build/install) + +install(TARGETS rknpu_test DESTINATION ./) + +install(DIRECTORY ppclas_model_dir DESTINATION ./) +install(DIRECTORY images DESTINATION ./) + +file(GLOB FASTDEPLOY_LIBS ${FASTDEPLOY_INSTALL_DIR}/lib/*) +message("${FASTDEPLOY_LIBS}") +install(PROGRAMS ${FASTDEPLOY_LIBS} DESTINATION lib) + +file(GLOB ONNXRUNTIME_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/onnxruntime/lib/*) +install(PROGRAMS ${ONNXRUNTIME_LIBS} DESTINATION lib) + +install(DIRECTORY ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/opencv/lib DESTINATION ./) + +file(GLOB PADDLETOONNX_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/paddle2onnx/lib/*) +install(PROGRAMS ${PADDLETOONNX_LIBS} DESTINATION lib) + +file(GLOB RKNPU2_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/rknpu2_runtime/RK3588/lib/*) +install(PROGRAMS ${RKNPU2_LIBS} DESTINATION lib) \ No newline at end of file diff --git a/examples/vision/classification/paddleclas/rknpu2/cpp/README.md b/examples/vision/classification/paddleclas/rknpu2/cpp/README.md new file mode 100644 index 000000000..1e1883486 --- /dev/null +++ b/examples/vision/classification/paddleclas/rknpu2/cpp/README.md @@ -0,0 +1,78 @@ +# PaddleClas C++部署示例 + +本目录下用于展示 ResNet50_vd 模型在RKNPU2上的部署,以下的部署过程以 ResNet50_vd 为例子。 + +在部署前,需确认以下两个步骤: + +1. 软硬件环境满足要求 +2. 根据开发环境,下载预编译部署库或者从头编译FastDeploy仓库 + +以上步骤请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)实现 + +## 生成基本目录文件 + +该例程由以下几个部分组成 +```text +. +├── CMakeLists.txt +├── build # 编译文件夹 +├── images # 存放图片的文件夹 +├── infer.cc +├── ppclas_model_dir # 存放模型文件的文件夹 +└── thirdpartys # 存放sdk的文件夹 +``` + +首先需要先生成目录结构 +```bash +mkdir build +mkdir images +mkdir ppclas_model_dir +mkdir thirdpartys +``` + +## 编译 + +### 编译并拷贝SDK到thirdpartys文件夹 + +请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)仓库编译SDK,编译完成后,将在build目录下生成 +fastdeploy-0.0.3目录,请移动它至thirdpartys目录下. + +### 拷贝模型文件,以及配置文件至model文件夹 +在Paddle动态图模型 -> Paddle静态图模型 -> ONNX模型的过程中,将生成ONNX文件以及对应的yaml配置文件,请将配置文件存放到model文件夹内。 +转换为RKNN后的模型文件也需要拷贝至model,转换方案: ([ResNet50_vd RKNN模型](../README.md))。 + +### 准备测试图片至image文件夹 +```bash +wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg +``` + +### 编译example + +```bash +cd build +cmake .. +make -j8 +make install +``` + +## 运行例程 + +```bash +cd ./build/install +./rknpu_test ./ppclas_model_dir ./images/ILSVRC2012_val_00000010.jpeg +``` + +## 运行结果展示 +ClassifyResult( +label_ids: 153, +scores: 0.684570, +) + +## 注意事项 +RKNPU上对模型的输入要求是使用NHWC格式,且图片归一化操作会在转RKNN模型时,内嵌到模型中,因此我们在使用FastDeploy部署时, +DisablePermute(C++)或`disable_permute(Python),在预处理阶段禁用数据格式的转换。 + +## 其它文档 +- [ResNet50_vd Python 部署](../python) +- [模型预测结果说明](../../../../../../docs/api/vision_results/) +- [转换ResNet50_vd RKNN模型文档](../README.md) \ No newline at end of file diff --git a/examples/vision/classification/paddleclas/rknpu2/cpp/infer.cc b/examples/vision/classification/paddleclas/rknpu2/cpp/infer.cc new file mode 100755 index 000000000..fdc84dcd5 --- /dev/null +++ b/examples/vision/classification/paddleclas/rknpu2/cpp/infer.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +void RKNPU2Infer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + "/ResNet50_vd_infer_rk3588.rknn"; + auto params_file = ""; + auto config_file = model_dir + "/inference_cls.yaml"; + + auto option = fastdeploy::RuntimeOption(); + option.UseRKNPU2(); + + auto format = fastdeploy::ModelFormat::RKNN; + + auto model = fastdeploy::vision::classification::PaddleClasModel( + model_file, params_file, config_file,option,format); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + model.GetPreprocessor().DisablePermute(); + fastdeploy::TimeCounter tc; + tc.Start(); + auto im = cv::imread(image_file); + fastdeploy::vision::ClassifyResult res; + if (!model.Predict(im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + // print res + std::cout << res.Str() << std::endl; + tc.End(); + tc.PrintInfo("PPClas in RKNPU2"); +} + +int main(int argc, char* argv[]) { + if (argc < 3) { + std::cout + << "Usage: rknpu_test path/to/model_dir path/to/image run_option, " + "e.g ./rknpu_test ./ppclas_model_dir ./images/ILSVRC2012_val_00000010.jpeg" + << std::endl; + return -1; + } + RKNPU2Infer(argv[1], argv[2]); + return 0; +} diff --git a/examples/vision/classification/paddleclas/rknpu2/python/README.md b/examples/vision/classification/paddleclas/rknpu2/python/README.md new file mode 100644 index 000000000..b85bb81f7 --- /dev/null +++ b/examples/vision/classification/paddleclas/rknpu2/python/README.md @@ -0,0 +1,35 @@ +# PaddleClas Python部署示例 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../../docs/cn/build_and_install/rknpu2.md) + +本目录下提供`infer.py`快速完成 ResNet50_vd 在RKNPU上部署的示例。执行如下脚本即可完成 + +```bash +# 下载部署示例代码 +git clone https://github.com/PaddlePaddle/FastDeploy.git +cd FastDeploy/examples/vision/classification/paddleclas/rknpu2/python + +# 下载图片 +wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg + +# 推理 +python3 infer.py --model_file ./ResNet50_vd_infer/ResNet50_vd_infer_rk3588.rknn --config_file ResNet50_vd_infer/inference_cls.yaml --image ILSVRC2012_val_00000010.jpeg + +# 运行完成后返回结果如下所示 +ClassifyResult( +label_ids: 153, +scores: 0.684570, +) +``` + + +## 注意事项 +RKNPU上对模型的输入要求是使用NHWC格式,且图片归一化操作会在转RKNN模型时,内嵌到模型中,因此我们在使用FastDeploy部署时, +DisablePermute(C++)或`disable_permute(Python),在预处理阶段禁用数据格式的转换。 + +## 其它文档 +- [ResNet50_vd C++部署](../cpp) +- [模型预测结果说明](../../../../../../docs/api/vision_results/) +- [转换ResNet50_vd RKNN模型文档](../README.md) \ No newline at end of file diff --git a/examples/vision/classification/paddleclas/rknpu2/python/infer.py b/examples/vision/classification/paddleclas/rknpu2/python/infer.py new file mode 100644 index 000000000..92dd92c2b --- /dev/null +++ b/examples/vision/classification/paddleclas/rknpu2/python/infer.py @@ -0,0 +1,50 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import fastdeploy as fd +import cv2 +import os + + +def parse_arguments(): + import argparse + import ast + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_file", required=True, help="Path of rknn model.") + parser.add_argument("--config_file", required=True, help="Path of config.") + parser.add_argument( + "--image", type=str, required=True, help="Path of test image file.") + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + + model_file = args.model_file + params_file = "" + config_file = args.config_file + # 配置runtime,加载模型 + runtime_option = fd.RuntimeOption() + runtime_option.use_rknpu2() + model = fd.vision.classification.ResNet50vd( + model_file, + params_file, + config_file, + runtime_option=runtime_option, + model_format=fd.ModelFormat.RKNN) + # 禁用通道转换 + model.preprocessor.disable_permute() + im = cv2.imread(args.image) + result = model.predict(im, topk=1) + print(result) diff --git a/fastdeploy/vision/classification/ppcls/model.cc b/fastdeploy/vision/classification/ppcls/model.cc index 9d691b80b..a9b5b46f0 100755 --- a/fastdeploy/vision/classification/ppcls/model.cc +++ b/fastdeploy/vision/classification/ppcls/model.cc @@ -32,9 +32,10 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file, valid_ascend_backends = {Backend::LITE}; valid_kunlunxin_backends = {Backend::LITE}; valid_ipu_backends = {Backend::PDINFER}; - } else if (model_format == ModelFormat::ONNX) { + } else { valid_cpu_backends = {Backend::ORT, Backend::OPENVINO}; valid_gpu_backends = {Backend::ORT, Backend::TRT}; + valid_rknpu_backends = {Backend::RKNPU2}; } runtime_option = custom_option; diff --git a/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc b/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc index 1873e73e5..b776d5c45 100644 --- a/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc +++ b/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc @@ -36,6 +36,12 @@ void BindPaddleClas(pybind11::module& m) { }) .def("use_gpu", [](vision::classification::PaddleClasPreprocessor& self, int gpu_id = -1) { self.UseGpu(gpu_id); + }) + .def("disable_normalize", [](vision::classification::PaddleClasPreprocessor& self) { + self.DisableNormalize(); + }) + .def("disable_permute", [](vision::classification::PaddleClasPreprocessor& self) { + self.DisablePermute(); }); pybind11::class_( diff --git a/fastdeploy/vision/classification/ppcls/preprocessor.cc b/fastdeploy/vision/classification/ppcls/preprocessor.cc index bdc21ad1e..ef8d8f20e 100644 --- a/fastdeploy/vision/classification/ppcls/preprocessor.cc +++ b/fastdeploy/vision/classification/ppcls/preprocessor.cc @@ -24,19 +24,19 @@ namespace vision { namespace classification { PaddleClasPreprocessor::PaddleClasPreprocessor(const std::string& config_file) { - FDASSERT(BuildPreprocessPipelineFromConfig(config_file), + this->config_file_ = config_file; + FDASSERT(BuildPreprocessPipelineFromConfig(), "Failed to create PaddleClasPreprocessor."); initialized_ = true; } -bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig( - const std::string& config_file) { +bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig() { processors_.clear(); YAML::Node cfg; try { - cfg = YAML::LoadFile(config_file); + cfg = YAML::LoadFile(config_file_); } catch (YAML::BadFile& e) { - FDERROR << "Failed to load yaml file " << config_file + FDERROR << "Failed to load yaml file " << config_file_ << ", maybe you should check this file." << std::endl; return false; } @@ -57,15 +57,19 @@ bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig( int height = op.begin()->second["size"].as(); processors_.push_back(std::make_shared(width, height)); } else if (op_name == "NormalizeImage") { - auto mean = op.begin()->second["mean"].as>(); - auto std = op.begin()->second["std"].as>(); - auto scale = op.begin()->second["scale"].as(); - FDASSERT((scale - 0.00392157) < 1e-06 && (scale - 0.00392157) > -1e-06, - "Only support scale in Normalize be 0.00392157, means the pixel " - "is in range of [0, 255]."); - processors_.push_back(std::make_shared(mean, std)); + if (!disable_normalize) { + auto mean = op.begin()->second["mean"].as>(); + auto std = op.begin()->second["std"].as>(); + auto scale = op.begin()->second["scale"].as(); + FDASSERT((scale - 0.00392157) < 1e-06 && (scale - 0.00392157) > -1e-06, + "Only support scale in Normalize be 0.00392157, means the pixel " + "is in range of [0, 255]."); + processors_.push_back(std::make_shared(mean, std)); + } } else if (op_name == "ToCHWImage") { - processors_.push_back(std::make_shared()); + if (!disable_permute) { + processors_.push_back(std::make_shared()); + } } else { FDERROR << "Unexcepted preprocess operator: " << op_name << "." << std::endl; @@ -78,6 +82,21 @@ bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig( return true; } +void PaddleClasPreprocessor::DisableNormalize() { + this->disable_normalize = true; + // the DisableNormalize function will be invalid if the configuration file is loaded during preprocessing + if (!BuildPreprocessPipelineFromConfig()) { + FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl; + } +} +void PaddleClasPreprocessor::DisablePermute() { + this->disable_permute = true; + // the DisablePermute function will be invalid if the configuration file is loaded during preprocessing + if (!BuildPreprocessPipelineFromConfig()) { + FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl; + } +} + void PaddleClasPreprocessor::UseGpu(int gpu_id) { #ifdef WITH_GPU use_cuda_ = true; diff --git a/fastdeploy/vision/classification/ppcls/preprocessor.h b/fastdeploy/vision/classification/ppcls/preprocessor.h index 54c5e669d..2162ac095 100644 --- a/fastdeploy/vision/classification/ppcls/preprocessor.h +++ b/fastdeploy/vision/classification/ppcls/preprocessor.h @@ -46,13 +46,24 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor { bool WithGpu() { return use_cuda_; } + /// This function will disable normalize in preprocessing step. + void DisableNormalize(); + /// This function will disable hwc2chw in preprocessing step. + void DisablePermute(); + private: - bool BuildPreprocessPipelineFromConfig(const std::string& config_file); + bool BuildPreprocessPipelineFromConfig(); std::vector> processors_; bool initialized_ = false; bool use_cuda_ = false; // GPU device id int device_id_ = -1; + // for recording the switch of hwc2chw + bool disable_permute = false; + // for recording the switch of normalize + bool disable_normalize = false; + // read config file + std::string config_file_; }; } // namespace classification diff --git a/python/fastdeploy/vision/classification/ppcls/__init__.py b/python/fastdeploy/vision/classification/ppcls/__init__.py index 91fa66c4a..b88c4361f 100644 --- a/python/fastdeploy/vision/classification/ppcls/__init__.py +++ b/python/fastdeploy/vision/classification/ppcls/__init__.py @@ -42,6 +42,18 @@ class PaddleClasPreprocessor: """ return self._preprocessor.use_gpu(gpu_id) + def disable_normalize(self): + """ + This function will disable normalize in preprocessing step. + """ + self._preprocessor.disable_normalize() + + def disable_permute(self): + """ + This function will disable hwc2chw in preprocessing step. + """ + self._preprocessor.disable_permute() + class PaddleClasPostprocessor: def __init__(self, topk=1): @@ -78,8 +90,6 @@ class PaddleClasModel(FastDeployModel): """ super(PaddleClasModel, self).__init__(runtime_option) - - assert model_format == ModelFormat.PADDLE, "PaddleClasModel only support model format of ModelFormat.PADDLE now." self._model = C.vision.classification.PaddleClasModel( model_file, params_file, config_file, self._runtime_option, model_format) diff --git a/tools/rknpu2/config/ResNet50_vd_infer_rknn.yaml b/tools/rknpu2/config/ResNet50_vd_infer_rknn.yaml new file mode 100644 index 000000000..d47075090 --- /dev/null +++ b/tools/rknpu2/config/ResNet50_vd_infer_rknn.yaml @@ -0,0 +1,10 @@ +model_path: ./ResNet50_vd_infer.onnx +output_folder: ./ +target_platform: RK3588 +normalize: + mean: [[0, 0, 0]] + std: [[1, 1, 1]] +outputs: [] +outputs_nodes: [] +do_quantization: False +dataset: \ No newline at end of file