diff --git a/docs/cn/faq/rknpu2/export.md b/docs/cn/faq/rknpu2/export.md index 9399c78d5..6992506cf 100644 --- a/docs/cn/faq/rknpu2/export.md +++ b/docs/cn/faq/rknpu2/export.md @@ -22,8 +22,8 @@ model_path: ./portrait_pp_humansegv2_lite_256x144_pretrained.onnx output_folder: ./ target_platform: RK3588 normalize: - mean: [0.5,0.5,0.5] - std: [0.5,0.5,0.5] + mean: [[0.5,0.5,0.5]] + std: [[0.5,0.5,0.5]] outputs: None ``` @@ -45,4 +45,4 @@ python tools/export.py --config_path=./config.yaml ## 模型导出要注意的事项 -* 请不要导出带softmax和argmax的模型,这两个算子存在bug,请在外部进行运算 \ No newline at end of file +* 请不要导出带softmax和argmax的模型,这两个算子存在bug,请在外部进行运算 diff --git a/examples/vision/detection/paddledetection/rknpu2/README.md b/examples/vision/detection/paddledetection/rknpu2/README.md new file mode 100644 index 000000000..32eff20a6 --- /dev/null +++ b/examples/vision/detection/paddledetection/rknpu2/README.md @@ -0,0 +1,38 @@ +# PaddleDetection RKNPU2部署示例 + +## 支持模型列表 + +目前FastDeploy支持如下模型的部署 +- [PicoDet系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet) + +## 准备PaddleDetection部署模型以及转换模型 +RKNPU部署模型前需要将Paddle模型转换成RKNN模型,具体步骤如下: +* Paddle动态图模型转换为ONNX模型,请参考[PaddleDetection导出模型](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/EXPORT_MODEL.md) + ,注意在转换时请设置**export.nms=True**. +* ONNX模型转换RKNN模型的过程,请参考[转换文档](../../../../../docs/cn/faq/rknpu2/export.md)进行转换。 + + +## 模型转换example +下面以Picodet-npu为例子,教大家如何转换PaddleDetection模型到RKNN模型。 +```bash +## 下载Paddle静态图模型并解压 +wget https://bj.bcebos.com/fastdeploy/models/rknn2/picodet_s_416_coco_npu.zip +unzip -qo picodet_s_416_coco_npu.zip + +# 静态图转ONNX模型,注意,这里的save_file请和压缩包名对齐 +paddle2onnx --model_dir picodet_s_416_coco_npu \ + --model_filename model.pdmodel \ + --params_filename model.pdiparams \ + --save_file picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx \ + --enable_dev_version True + +python -m paddle2onnx.optimize --input_model picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx \ + --output_model picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx \ + --input_shape_dict "{'image':[1,3,416,416]}" +# ONNX模型转RKNN模型 +# 转换模型,模型将生成在picodet_s_320_coco_lcnet_non_postprocess目录下 +python tools/rknpu2/export.py --config_path tools/rknpu2/config/RK3588/picodet_s_416_coco_npu.yaml +``` + +- [Python部署](./python) +- [视觉模型预测结果](../../../../../docs/api/vision_results/) diff --git a/examples/vision/detection/paddledetection/rknpu2/cpp/CMakeLists.txt b/examples/vision/detection/paddledetection/rknpu2/cpp/CMakeLists.txt new file mode 100644 index 000000000..b4eca78ec --- /dev/null +++ b/examples/vision/detection/paddledetection/rknpu2/cpp/CMakeLists.txt @@ -0,0 +1,37 @@ +CMAKE_MINIMUM_REQUIRED(VERSION 3.10) +project(rknpu2_test) + +set(CMAKE_CXX_STANDARD 14) + +# 指定下载解压后的fastdeploy库路径 +set(FASTDEPLOY_INSTALL_DIR "thirdpartys/fastdeploy-0.0.3") + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeployConfig.cmake) +include_directories(${FastDeploy_INCLUDE_DIRS}) + +add_executable(infer_picodet infer_picodet.cc) +target_link_libraries(infer_picodet ${FastDeploy_LIBS}) + + + +set(CMAKE_INSTALL_PREFIX ${CMAKE_SOURCE_DIR}/build/install) + +install(TARGETS infer_picodet DESTINATION ./) + +install(DIRECTORY model DESTINATION ./) +install(DIRECTORY images DESTINATION ./) + +file(GLOB FASTDEPLOY_LIBS ${FASTDEPLOY_INSTALL_DIR}/lib/*) +message("${FASTDEPLOY_LIBS}") +install(PROGRAMS ${FASTDEPLOY_LIBS} DESTINATION lib) + +file(GLOB ONNXRUNTIME_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/onnxruntime/lib/*) +install(PROGRAMS ${ONNXRUNTIME_LIBS} DESTINATION lib) + +install(DIRECTORY ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/opencv/lib DESTINATION ./) + +file(GLOB PADDLETOONNX_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/paddle2onnx/lib/*) +install(PROGRAMS ${PADDLETOONNX_LIBS} DESTINATION lib) + +file(GLOB RKNPU2_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/rknpu2_runtime/RK3588/lib/*) +install(PROGRAMS ${RKNPU2_LIBS} DESTINATION lib) diff --git a/examples/vision/detection/paddledetection/rknpu2/cpp/README.md b/examples/vision/detection/paddledetection/rknpu2/cpp/README.md new file mode 100644 index 000000000..d0b131971 --- /dev/null +++ b/examples/vision/detection/paddledetection/rknpu2/cpp/README.md @@ -0,0 +1,71 @@ +# PaddleDetection C++部署示例 + +本目录下提供`infer_xxxxx.cc`快速完成PPDetection模型在Rockchip板子上上通过二代NPU加速部署的示例。 + +在部署前,需确认以下两个步骤: + +1. 软硬件环境满足要求 +2. 根据开发环境,下载预编译部署库或者从头编译FastDeploy仓库 + +以上步骤请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)实现 + +## 生成基本目录文件 + +该例程由以下几个部分组成 +```text +. +├── CMakeLists.txt +├── build # 编译文件夹 +├── image # 存放图片的文件夹 +├── infer_cpu_npu.cc +├── infer_cpu_npu.h +├── main.cc +├── model # 存放模型文件的文件夹 +└── thirdpartys # 存放sdk的文件夹 +``` + +首先需要先生成目录结构 +```bash +mkdir build +mkdir images +mkdir model +mkdir thirdpartys +``` + +## 编译 + +### 编译并拷贝SDK到thirdpartys文件夹 + +请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)仓库编译SDK,编译完成后,将在build目录下生成 +fastdeploy-0.0.3目录,请移动它至thirdpartys目录下. + +### 拷贝模型文件,以及配置文件至model文件夹 +在Paddle动态图模型 -> Paddle静态图模型 -> ONNX模型的过程中,将生成ONNX文件以及对应的yaml配置文件,请将配置文件存放到model文件夹内。 +转换为RKNN后的模型文件也需要拷贝至model。 + +### 准备测试图片至image文件夹 +```bash +wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg +cp 000000014439.jpg ./images +``` + +### 编译example + +```bash +cd build +cmake .. +make -j8 +make install +``` + +## 运行例程 + +```bash +cd ./build/install +./rknpu_test +``` + + +- [模型介绍](../../) +- [Python部署](../python) +- [视觉模型预测结果](../../../../../../docs/api/vision_results/) diff --git a/examples/vision/detection/paddledetection/rknpu2/cpp/infer_picodet.cc b/examples/vision/detection/paddledetection/rknpu2/cpp/infer_picodet.cc new file mode 100644 index 000000000..297fa52e5 --- /dev/null +++ b/examples/vision/detection/paddledetection/rknpu2/cpp/infer_picodet.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include "fastdeploy/vision.h" + +void InferPicodet(const std::string& device = "cpu"); + +int main() { + InferPicodet("npu"); + return 0; +} + +fastdeploy::RuntimeOption GetOption(const std::string& device) { + auto option = fastdeploy::RuntimeOption(); + if (device == "npu") { + option.UseRKNPU2(); + } else { + option.UseCpu(); + } + return option; +} + +fastdeploy::ModelFormat GetFormat(const std::string& device) { + auto format = fastdeploy::ModelFormat::ONNX; + if (device == "npu") { + format = fastdeploy::ModelFormat::RKNN; + } else { + format = fastdeploy::ModelFormat::ONNX; + } + return format; +} + +std::string GetModelPath(std::string& model_path, const std::string& device) { + if (device == "npu") { + model_path += "rknn"; + } else { + model_path += "onnx"; + } + return model_path; +} + +void InferPicodet(const std::string &device) { + std::string model_file = "./model/picodet_s_416_coco_npu/picodet_s_416_coco_npu_rk3588."; + std::string params_file; + std::string config_file = "./model/picodet_s_416_coco_npu/infer_cfg.yml"; + + fastdeploy::RuntimeOption option = GetOption(device); + fastdeploy::ModelFormat format = GetFormat(device); + model_file = GetModelPath(model_file, device); + auto model = fastdeploy::vision::detection::RKPicoDet( + model_file, params_file, config_file,option,format); + + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + auto image_file = "./images/000000014439.jpg"; + auto im = cv::imread(image_file); + + fastdeploy::vision::DetectionResult res; + clock_t start = clock(); + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + clock_t end = clock(); + auto dur = static_cast(end - start); + printf("picodet_npu use time:%f\n", (dur / CLOCKS_PER_SEC)); + + std::cout << res.Str() << std::endl; + auto vis_im = fastdeploy::vision::VisDetection(im, res,0.5); + cv::imwrite("picodet_npu_result.jpg", vis_im); + std::cout << "Visualized result saved in ./picodet_npu_result.jpg" << std::endl; +} \ No newline at end of file diff --git a/examples/vision/detection/paddledetection/rknpu2/python/README.md b/examples/vision/detection/paddledetection/rknpu2/python/README.md new file mode 100644 index 000000000..23b13cd3b --- /dev/null +++ b/examples/vision/detection/paddledetection/rknpu2/python/README.md @@ -0,0 +1,35 @@ +# PaddleDetection Python部署示例 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../../docs/cn/build_and_install/rknpu2.md) + +本目录下提供`infer.py`快速完成Picodet在RKNPU上部署的示例。执行如下脚本即可完成 + +```bash +# 下载部署示例代码 +git clone https://github.com/PaddlePaddle/FastDeploy.git +cd FastDeploy/examples/vision/detection/paddledetection/rknpu2/python + +# 下载图片 +wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg + +# copy model +cp -r ./picodet_s_416_coco_npu /path/to/FastDeploy/examples/vision/detection/rknpu2detection/paddledetection/python + +# 推理 +python3 infer.py --model_file ./picodet_s_416_coco_npu/picodet_s_416_coco_npu_3588.rknn \ + --config_file ./picodet_s_416_coco_npu/infer_cfg.yml \ + --image 000000014439.jpg +``` + + +## 注意事项 +RKNPU上对模型的输入要求是使用NHWC格式,且图片归一化操作会在转RKNN模型时,内嵌到模型中,因此我们在使用FastDeploy部署时, +需要先调用DisableNormalizePermute(C++)或`disable_normalize_permute(Python),在预处理阶段禁用归一化以及数据格式的转换。 +## 其它文档 + +- [PaddleDetection 模型介绍](..) +- [PaddleDetection C++部署](../cpp) +- [模型预测结果说明](../../../../../../docs/api/vision_results/) +- [转换PaddleDetection RKNN模型文档](../README.md) diff --git a/examples/vision/detection/paddledetection/rknpu2/python/infer.py b/examples/vision/detection/paddledetection/rknpu2/python/infer.py new file mode 100644 index 000000000..ae2d8796a --- /dev/null +++ b/examples/vision/detection/paddledetection/rknpu2/python/infer.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import fastdeploy as fd +import cv2 +import os + + +def parse_arguments(): + import argparse + import ast + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_file", required=True, help="Path of rknn model.") + parser.add_argument("--config_file", required=True, help="Path of config.") + parser.add_argument( + "--image", type=str, required=True, help="Path of test image file.") + return parser.parse_args() + + +def build_option(args): + option = fd.RuntimeOption() + option.use_rknpu2() + return option + + +args = parse_arguments() + +# 配置runtime,加载模型 +runtime_option = build_option(args) +model_file = args.model_file +params_file = "" +config_file = args.config_file +model = fd.vision.detection.RKPicoDet( + model_file, + params_file, + config_file, + runtime_option=runtime_option, + model_format=fd.ModelFormat.RKNN) + +# 预测图片分割结果 +im = cv2.imread(args.image) +result = model.predict(im.copy()) +print(result) + +# 可视化结果 +vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5) +cv2.imwrite("visualized_result.jpg", vis_im) +print("Visualized result save in ./visualized_result.jpg") diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h index d9ceb5dda..44054ee93 100755 --- a/fastdeploy/vision.h +++ b/fastdeploy/vision.h @@ -29,6 +29,7 @@ #include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h" #include "fastdeploy/vision/detection/contrib/yolox.h" #include "fastdeploy/vision/detection/ppdet/model.h" +#include "fastdeploy/vision/detection/contrib/rknpu2/model.h" #include "fastdeploy/vision/facedet/contrib/retinaface.h" #include "fastdeploy/vision/facedet/contrib/scrfd.h" #include "fastdeploy/vision/facedet/contrib/ultraface.h" diff --git a/fastdeploy/vision/detection/contrib/rknpu2/model.h b/fastdeploy/vision/detection/contrib/rknpu2/model.h new file mode 100644 index 000000000..f0f8616ee --- /dev/null +++ b/fastdeploy/vision/detection/contrib/rknpu2/model.h @@ -0,0 +1,16 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h" diff --git a/fastdeploy/vision/detection/contrib/rknpu2/rkdet_pybind.cc b/fastdeploy/vision/detection/contrib/rknpu2/rkdet_pybind.cc new file mode 100644 index 000000000..6482ea675 --- /dev/null +++ b/fastdeploy/vision/detection/contrib/rknpu2/rkdet_pybind.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindRKDet(pybind11::module& m) { + pybind11::class_(m, "RKPicoDet") + .def(pybind11::init()) + .def("predict", + [](vision::detection::RKPicoDet& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + vision::DetectionResult res; + self.Predict(&mat, &res); + return res; + }); +} +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.cc b/fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.cc new file mode 100644 index 000000000..926214d86 --- /dev/null +++ b/fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.cc @@ -0,0 +1,201 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h" +#include "yaml-cpp/yaml.h" +namespace fastdeploy { +namespace vision { +namespace detection { + +RKPicoDet::RKPicoDet(const std::string& model_file, + const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option, + const ModelFormat& model_format) { + config_file_ = config_file; + valid_cpu_backends = {Backend::ORT}; + valid_rknpu_backends = {Backend::RKNPU2}; + if ((model_format == ModelFormat::RKNN) || + (model_format == ModelFormat::ONNX)) { + has_nms_ = false; + } + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + + // NMS parameters come from RKPicoDet_s_nms + background_label = -1; + keep_top_k = 100; + nms_eta = 1; + nms_threshold = 0.5; + nms_top_k = 1000; + normalized = true; + score_threshold = 0.3; + initialized = Initialize(); +} + +bool RKPicoDet::Initialize() { + if (!BuildPreprocessPipelineFromConfig()) { + FDERROR << "Failed to build preprocess pipeline from configuration file." + << std::endl; + return false; + } + if (!InitRuntime()) { + FDERROR << "Failed to initialize fastdeploy backend." << std::endl; + return false; + } + return true; +} + +bool RKPicoDet::Preprocess(Mat* mat, std::vector* outputs) { + int origin_w = mat->Width(); + int origin_h = mat->Height(); + for (size_t i = 0; i < processors_.size(); ++i) { + if (!(*(processors_[i].get()))(mat)) { + FDERROR << "Failed to process image data in " << processors_[i]->Name() + << "." << std::endl; + return false; + } + } + + Cast::Run(mat, "float"); + + scale_factor.resize(2); + scale_factor[0] = mat->Height() * 1.0 / origin_h; + scale_factor[1] = mat->Width() * 1.0 / origin_w; + + outputs->resize(1); + (*outputs)[0].name = InputInfoOfRuntime(0).name; + mat->ShareWithTensor(&((*outputs)[0])); + // reshape to [1, c, h, w] + (*outputs)[0].shape.insert((*outputs)[0].shape.begin(), 1); + return true; +} + +bool RKPicoDet::BuildPreprocessPipelineFromConfig() { + processors_.clear(); + YAML::Node cfg; + try { + cfg = YAML::LoadFile(config_file_); + } catch (YAML::BadFile& e) { + FDERROR << "Failed to load yaml file " << config_file_ + << ", maybe you should check this file." << std::endl; + return false; + } + + processors_.push_back(std::make_shared()); + + for (const auto& op : cfg["Preprocess"]) { + std::string op_name = op["type"].as(); + if (op_name == "NormalizeImage") { + continue; + } else if (op_name == "Resize") { + bool keep_ratio = op["keep_ratio"].as(); + auto target_size = op["target_size"].as>(); + int interp = op["interp"].as(); + FDASSERT(target_size.size() == 2, + "Require size of target_size be 2, but now it's %lu.", + target_size.size()); + if (!keep_ratio) { + int width = target_size[1]; + int height = target_size[0]; + processors_.push_back( + std::make_shared(width, height, -1.0, -1.0, interp, false)); + } else { + int min_target_size = std::min(target_size[0], target_size[1]); + int max_target_size = std::max(target_size[0], target_size[1]); + std::vector max_size; + if (max_target_size > 0) { + max_size.push_back(max_target_size); + max_size.push_back(max_target_size); + } + processors_.push_back(std::make_shared( + min_target_size, interp, true, max_size)); + } + } else if (op_name == "Permute") { + continue; + } else if (op_name == "Pad") { + auto size = op["size"].as>(); + auto value = op["fill_value"].as>(); + processors_.push_back(std::make_shared("float")); + processors_.push_back( + std::make_shared(size[1], size[0], value)); + } else if (op_name == "PadStride") { + auto stride = op["stride"].as(); + processors_.push_back( + std::make_shared(stride, std::vector(3, 0))); + } else { + FDERROR << "Unexcepted preprocess operator: " << op_name << "." + << std::endl; + return false; + } + } + return true; +} + +bool RKPicoDet::Postprocess(std::vector& infer_result, + DetectionResult* result) { + FDASSERT(infer_result[1].shape[0] == 1, + "Only support batch = 1 in FastDeploy now."); + + if (!has_nms_) { + int boxes_index = 0; + int scores_index = 1; + if (infer_result[0].shape[1] == infer_result[1].shape[2]) { + boxes_index = 0; + scores_index = 1; + } else if (infer_result[0].shape[2] == infer_result[1].shape[1]) { + boxes_index = 1; + scores_index = 0; + } else { + FDERROR << "The shape of boxes and scores should be [batch, boxes_num, " + "4], [batch, classes_num, boxes_num]" + << std::endl; + return false; + } + + backend::MultiClassNMS nms; + nms.background_label = background_label; + nms.keep_top_k = keep_top_k; + nms.nms_eta = nms_eta; + nms.nms_threshold = nms_threshold; + nms.score_threshold = score_threshold; + nms.nms_top_k = nms_top_k; + nms.normalized = normalized; + nms.Compute(static_cast(infer_result[boxes_index].Data()), + static_cast(infer_result[scores_index].Data()), + infer_result[boxes_index].shape, + infer_result[scores_index].shape); + if (nms.out_num_rois_data[0] > 0) { + result->Reserve(nms.out_num_rois_data[0]); + } + for (size_t i = 0; i < nms.out_num_rois_data[0]; ++i) { + result->label_ids.push_back(nms.out_box_data[i * 6]); + result->scores.push_back(nms.out_box_data[i * 6 + 1]); + result->boxes.emplace_back( + std::array{nms.out_box_data[i * 6 + 2] / scale_factor[1], + nms.out_box_data[i * 6 + 3] / scale_factor[0], + nms.out_box_data[i * 6 + 4] / scale_factor[1], + nms.out_box_data[i * 6 + 5] / scale_factor[0]}); + } + } else { + FDERROR << "Picodet in Backend::RKNPU2 don't support NMS" << std::endl; + } + return true; +} + +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h b/fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h new file mode 100644 index 000000000..dbb48c16d --- /dev/null +++ b/fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/vision/detection/ppdet/ppyoloe.h" + +namespace fastdeploy { +namespace vision { +namespace detection { +class FASTDEPLOY_DECL RKPicoDet : public PPYOLOE { + public: + RKPicoDet(const std::string& model_file, + const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::RKNN); + + virtual std::string ModelName() const { return "RKPicoDet"; } + + protected: + /// Build the preprocess pipeline from the loaded model + virtual bool BuildPreprocessPipelineFromConfig(); + /// Preprocess an input image, and set the preprocessed results to `outputs` + virtual bool Preprocess(Mat* mat, std::vector* outputs); + + /// Postprocess the inferenced results, and set the final result to `result` + virtual bool Postprocess(std::vector& infer_result, + DetectionResult* result); + virtual bool Initialize(); + private: + std::vector scale_factor{1.0, 1.0}; +}; +} // namespace detection +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/detection/detection_pybind.cc b/fastdeploy/vision/detection/detection_pybind.cc index b3a7a6ad9..f55bf68bf 100644 --- a/fastdeploy/vision/detection/detection_pybind.cc +++ b/fastdeploy/vision/detection/detection_pybind.cc @@ -27,6 +27,8 @@ void BindNanoDetPlus(pybind11::module& m); void BindPPDet(pybind11::module& m); void BindYOLOv7End2EndTRT(pybind11::module& m); void BindYOLOv7End2EndORT(pybind11::module& m); +void BindRKDet(pybind11::module& m); + void BindDetection(pybind11::module& m) { auto detection_module = @@ -42,5 +44,6 @@ void BindDetection(pybind11::module& m) { BindNanoDetPlus(detection_module); BindYOLOv7End2EndTRT(detection_module); BindYOLOv7End2EndORT(detection_module); + BindRKDet(detection_module); } } // namespace fastdeploy diff --git a/python/fastdeploy/vision/detection/__init__.py b/python/fastdeploy/vision/detection/__init__.py index 89441f7a2..a4fe4c035 100644 --- a/python/fastdeploy/vision/detection/__init__.py +++ b/python/fastdeploy/vision/detection/__init__.py @@ -24,3 +24,4 @@ from .contrib.yolov6 import YOLOv6 from .contrib.yolov7end2end_trt import YOLOv7End2EndTRT from .contrib.yolov7end2end_ort import YOLOv7End2EndORT from .ppdet import PPYOLOE, PPYOLO, PPYOLOv2, PaddleYOLOX, PicoDet, FasterRCNN, YOLOv3, MaskRCNN +from .rknpu2 import RKPicoDet diff --git a/python/fastdeploy/vision/detection/rknpu2/__init__.py b/python/fastdeploy/vision/detection/rknpu2/__init__.py new file mode 100644 index 000000000..57fcecc64 --- /dev/null +++ b/python/fastdeploy/vision/detection/rknpu2/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from typing import Union, List +import logging +from .... import FastDeployModel, ModelFormat +from .... import c_lib_wrap as C +from .. import PPYOLOE + + +class RKPicoDet(PPYOLOE): + def __init__(self, + model_file, + params_file, + config_file, + runtime_option=None, + model_format=ModelFormat.RKNN): + """Load a PicoDet model exported by PaddleDetection. + + :param model_file: (str)Path of model file, e.g picodet/model.pdmodel + :param params_file: (str)Path of parameters file, e.g picodet/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string + :param config_file: (str)Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml + :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU + :param model_format: (fastdeploy.ModelForamt)Model format of the loaded model + """ + + super(PPYOLOE, self).__init__(runtime_option) + + assert model_format == ModelFormat.RKNN, "RKPicoDet model only support model format of ModelFormat.RKNN now." + self._model = C.vision.detection.RKPicoDet( + model_file, params_file, config_file, self._runtime_option, + model_format) + assert self.initialized, "RKPicoDet model initialize failed." diff --git a/tools/rknpu2/config/RK3568/picodet_s_416_coco_lcnet.yaml b/tools/rknpu2/config/RK3568/picodet_s_416_coco_lcnet.yaml new file mode 100644 index 000000000..985489163 --- /dev/null +++ b/tools/rknpu2/config/RK3568/picodet_s_416_coco_lcnet.yaml @@ -0,0 +1,7 @@ +model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx +output_folder: ./picodet_s_416_coco_lcnet +target_platform: RK3568 +normalize: + mean: [[0.485,0.456,0.406]] + std: [[0.229,0.224,0.225]] +outputs: ['tmp_16','p2o.Concat.9'] diff --git a/tools/rknpu2/config/RK3568/picodet_s_416_coco_npu.yaml b/tools/rknpu2/config/RK3568/picodet_s_416_coco_npu.yaml new file mode 100644 index 000000000..723acc8b5 --- /dev/null +++ b/tools/rknpu2/config/RK3568/picodet_s_416_coco_npu.yaml @@ -0,0 +1,5 @@ +model_path: ./picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx +output_folder: ./picodet_s_416_coco_npu +target_platform: RK3568 +normalize: None +outputs: ['tmp_16','p2o.Concat.17'] diff --git a/tools/rknpu2/config/RK3588/picodet_s_416_coco_lcnet.yaml b/tools/rknpu2/config/RK3588/picodet_s_416_coco_lcnet.yaml new file mode 100644 index 000000000..6110e8c0f --- /dev/null +++ b/tools/rknpu2/config/RK3588/picodet_s_416_coco_lcnet.yaml @@ -0,0 +1,7 @@ +model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx +output_folder: ./picodet_s_416_coco_lcnet +target_platform: RK3588 +normalize: + mean: [[0.485,0.456,0.406]] + std: [[0.229,0.224,0.225]] +outputs: ['tmp_16','p2o.Concat.9'] diff --git a/tools/rknpu2/config/RK3588/picodet_s_416_coco_npu.yaml b/tools/rknpu2/config/RK3588/picodet_s_416_coco_npu.yaml new file mode 100644 index 000000000..356fcfad8 --- /dev/null +++ b/tools/rknpu2/config/RK3588/picodet_s_416_coco_npu.yaml @@ -0,0 +1,5 @@ +model_path: ./picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx +output_folder: ./picodet_s_416_coco_npu +target_platform: RK3588 +normalize: None +outputs: ['tmp_16','p2o.Concat.17']