[Add Model]Add RKPicodet (#495)

* 11-02/14:35
* 新增输入数据format错误判断
* 优化推理过程,减少内存分配次数
* 支持多输入rknn模型
* rknn模型输出shape为三维时,输出将被强制对齐为4纬。现在将直接抹除rknn补充的shape,方便部分对输出shape进行判断的模型进行正确的后处理。

* 11-03/17:25
* 支持导出多输入RKNN模型
* 更新各种文档
* ppseg改用Fastdeploy中的模型进行转换

* 11-03/17:25
* 新增开源头

* 11-03/21:48
* 删除无用debug代码,补充注释

* 11-04/01:00
* 新增rkpicodet代码

* 11-04/13:13
* 提交编译缺少的文件

* 11-04/14:03
* 更新安装文档

* 11-04/14:21
* 更新picodet_s配置文件

* 11-04/14:21
* 更新picodet自适应输出结果

* 11-04/14:21
* 更新文档

* * 更新配置文件

* * 修正配置文件

* * 添加缺失的python文件

* * 修正文档

* * 修正代码格式问题0

* * 按照要求修改

* * 按照要求修改

* * 按照要求修改

* * 按照要求修改

* * 按照要求修改

* test
This commit is contained in:
Zheng_Bicheng
2022-11-06 17:29:00 +08:00
committed by GitHub
parent 295af8f467
commit 6408af263a
19 changed files with 694 additions and 3 deletions

View File

@@ -22,8 +22,8 @@ model_path: ./portrait_pp_humansegv2_lite_256x144_pretrained.onnx
output_folder: ./ output_folder: ./
target_platform: RK3588 target_platform: RK3588
normalize: normalize:
mean: [0.5,0.5,0.5] mean: [[0.5,0.5,0.5]]
std: [0.5,0.5,0.5] std: [[0.5,0.5,0.5]]
outputs: None outputs: None
``` ```

View File

@@ -0,0 +1,38 @@
# PaddleDetection RKNPU2部署示例
## 支持模型列表
目前FastDeploy支持如下模型的部署
- [PicoDet系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet)
## 准备PaddleDetection部署模型以及转换模型
RKNPU部署模型前需要将Paddle模型转换成RKNN模型具体步骤如下:
* Paddle动态图模型转换为ONNX模型请参考[PaddleDetection导出模型](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/EXPORT_MODEL.md)
,注意在转换时请设置**export.nms=True**.
* ONNX模型转换RKNN模型的过程请参考[转换文档](../../../../../docs/cn/faq/rknpu2/export.md)进行转换。
## 模型转换example
下面以Picodet-npu为例子,教大家如何转换PaddleDetection模型到RKNN模型。
```bash
## 下载Paddle静态图模型并解压
wget https://bj.bcebos.com/fastdeploy/models/rknn2/picodet_s_416_coco_npu.zip
unzip -qo picodet_s_416_coco_npu.zip
# 静态图转ONNX模型注意这里的save_file请和压缩包名对齐
paddle2onnx --model_dir picodet_s_416_coco_npu \
--model_filename model.pdmodel \
--params_filename model.pdiparams \
--save_file picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx \
--enable_dev_version True
python -m paddle2onnx.optimize --input_model picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx \
--output_model picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx \
--input_shape_dict "{'image':[1,3,416,416]}"
# ONNX模型转RKNN模型
# 转换模型,模型将生成在picodet_s_320_coco_lcnet_non_postprocess目录下
python tools/rknpu2/export.py --config_path tools/rknpu2/config/RK3588/picodet_s_416_coco_npu.yaml
```
- [Python部署](./python)
- [视觉模型预测结果](../../../../../docs/api/vision_results/)

View File

@@ -0,0 +1,37 @@
CMAKE_MINIMUM_REQUIRED(VERSION 3.10)
project(rknpu2_test)
set(CMAKE_CXX_STANDARD 14)
# 指定下载解压后的fastdeploy库路径
set(FASTDEPLOY_INSTALL_DIR "thirdpartys/fastdeploy-0.0.3")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeployConfig.cmake)
include_directories(${FastDeploy_INCLUDE_DIRS})
add_executable(infer_picodet infer_picodet.cc)
target_link_libraries(infer_picodet ${FastDeploy_LIBS})
set(CMAKE_INSTALL_PREFIX ${CMAKE_SOURCE_DIR}/build/install)
install(TARGETS infer_picodet DESTINATION ./)
install(DIRECTORY model DESTINATION ./)
install(DIRECTORY images DESTINATION ./)
file(GLOB FASTDEPLOY_LIBS ${FASTDEPLOY_INSTALL_DIR}/lib/*)
message("${FASTDEPLOY_LIBS}")
install(PROGRAMS ${FASTDEPLOY_LIBS} DESTINATION lib)
file(GLOB ONNXRUNTIME_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/onnxruntime/lib/*)
install(PROGRAMS ${ONNXRUNTIME_LIBS} DESTINATION lib)
install(DIRECTORY ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/opencv/lib DESTINATION ./)
file(GLOB PADDLETOONNX_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/paddle2onnx/lib/*)
install(PROGRAMS ${PADDLETOONNX_LIBS} DESTINATION lib)
file(GLOB RKNPU2_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/rknpu2_runtime/RK3588/lib/*)
install(PROGRAMS ${RKNPU2_LIBS} DESTINATION lib)

View File

@@ -0,0 +1,71 @@
# PaddleDetection C++部署示例
本目录下提供`infer_xxxxx.cc`快速完成PPDetection模型在Rockchip板子上上通过二代NPU加速部署的示例。
在部署前,需确认以下两个步骤:
1. 软硬件环境满足要求
2. 根据开发环境下载预编译部署库或者从头编译FastDeploy仓库
以上步骤请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)实现
## 生成基本目录文件
该例程由以下几个部分组成
```text
.
├── CMakeLists.txt
├── build # 编译文件夹
├── image # 存放图片的文件夹
├── infer_cpu_npu.cc
├── infer_cpu_npu.h
├── main.cc
├── model # 存放模型文件的文件夹
└── thirdpartys # 存放sdk的文件夹
```
首先需要先生成目录结构
```bash
mkdir build
mkdir images
mkdir model
mkdir thirdpartys
```
## 编译
### 编译并拷贝SDK到thirdpartys文件夹
请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)仓库编译SDK编译完成后将在build目录下生成
fastdeploy-0.0.3目录请移动它至thirdpartys目录下.
### 拷贝模型文件以及配置文件至model文件夹
在Paddle动态图模型 -> Paddle静态图模型 -> ONNX模型的过程中将生成ONNX文件以及对应的yaml配置文件请将配置文件存放到model文件夹内。
转换为RKNN后的模型文件也需要拷贝至model。
### 准备测试图片至image文件夹
```bash
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
cp 000000014439.jpg ./images
```
### 编译example
```bash
cd build
cmake ..
make -j8
make install
```
## 运行例程
```bash
cd ./build/install
./rknpu_test
```
- [模型介绍](../../)
- [Python部署](../python)
- [视觉模型预测结果](../../../../../../docs/api/vision_results/)

View File

@@ -0,0 +1,86 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include "fastdeploy/vision.h"
void InferPicodet(const std::string& device = "cpu");
int main() {
InferPicodet("npu");
return 0;
}
fastdeploy::RuntimeOption GetOption(const std::string& device) {
auto option = fastdeploy::RuntimeOption();
if (device == "npu") {
option.UseRKNPU2();
} else {
option.UseCpu();
}
return option;
}
fastdeploy::ModelFormat GetFormat(const std::string& device) {
auto format = fastdeploy::ModelFormat::ONNX;
if (device == "npu") {
format = fastdeploy::ModelFormat::RKNN;
} else {
format = fastdeploy::ModelFormat::ONNX;
}
return format;
}
std::string GetModelPath(std::string& model_path, const std::string& device) {
if (device == "npu") {
model_path += "rknn";
} else {
model_path += "onnx";
}
return model_path;
}
void InferPicodet(const std::string &device) {
std::string model_file = "./model/picodet_s_416_coco_npu/picodet_s_416_coco_npu_rk3588.";
std::string params_file;
std::string config_file = "./model/picodet_s_416_coco_npu/infer_cfg.yml";
fastdeploy::RuntimeOption option = GetOption(device);
fastdeploy::ModelFormat format = GetFormat(device);
model_file = GetModelPath(model_file, device);
auto model = fastdeploy::vision::detection::RKPicoDet(
model_file, params_file, config_file,option,format);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto image_file = "./images/000000014439.jpg";
auto im = cv::imread(image_file);
fastdeploy::vision::DetectionResult res;
clock_t start = clock();
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
clock_t end = clock();
auto dur = static_cast<double>(end - start);
printf("picodet_npu use time:%f\n", (dur / CLOCKS_PER_SEC));
std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::VisDetection(im, res,0.5);
cv::imwrite("picodet_npu_result.jpg", vis_im);
std::cout << "Visualized result saved in ./picodet_npu_result.jpg" << std::endl;
}

View File

@@ -0,0 +1,35 @@
# PaddleDetection Python部署示例
在部署前,需确认以下两个步骤
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../../docs/cn/build_and_install/rknpu2.md)
本目录下提供`infer.py`快速完成Picodet在RKNPU上部署的示例。执行如下脚本即可完成
```bash
# 下载部署示例代码
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/detection/paddledetection/rknpu2/python
# 下载图片
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
# copy model
cp -r ./picodet_s_416_coco_npu /path/to/FastDeploy/examples/vision/detection/rknpu2detection/paddledetection/python
# 推理
python3 infer.py --model_file ./picodet_s_416_coco_npu/picodet_s_416_coco_npu_3588.rknn \
--config_file ./picodet_s_416_coco_npu/infer_cfg.yml \
--image 000000014439.jpg
```
## 注意事项
RKNPU上对模型的输入要求是使用NHWC格式且图片归一化操作会在转RKNN模型时内嵌到模型中因此我们在使用FastDeploy部署时
需要先调用DisableNormalizePermute(C++)或`disable_normalize_permute(Python),在预处理阶段禁用归一化以及数据格式的转换。
## 其它文档
- [PaddleDetection 模型介绍](..)
- [PaddleDetection C++部署](../cpp)
- [模型预测结果说明](../../../../../../docs/api/vision_results/)
- [转换PaddleDetection RKNN模型文档](../README.md)

View File

@@ -0,0 +1,59 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_file", required=True, help="Path of rknn model.")
parser.add_argument("--config_file", required=True, help="Path of config.")
parser.add_argument(
"--image", type=str, required=True, help="Path of test image file.")
return parser.parse_args()
def build_option(args):
option = fd.RuntimeOption()
option.use_rknpu2()
return option
args = parse_arguments()
# 配置runtime加载模型
runtime_option = build_option(args)
model_file = args.model_file
params_file = ""
config_file = args.config_file
model = fd.vision.detection.RKPicoDet(
model_file,
params_file,
config_file,
runtime_option=runtime_option,
model_format=fd.ModelFormat.RKNN)
# 预测图片分割结果
im = cv2.imread(args.image)
result = model.predict(im.copy())
print(result)
# 可视化结果
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg")

View File

@@ -29,6 +29,7 @@
#include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h" #include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
#include "fastdeploy/vision/detection/contrib/yolox.h" #include "fastdeploy/vision/detection/contrib/yolox.h"
#include "fastdeploy/vision/detection/ppdet/model.h" #include "fastdeploy/vision/detection/ppdet/model.h"
#include "fastdeploy/vision/detection/contrib/rknpu2/model.h"
#include "fastdeploy/vision/facedet/contrib/retinaface.h" #include "fastdeploy/vision/facedet/contrib/retinaface.h"
#include "fastdeploy/vision/facedet/contrib/scrfd.h" #include "fastdeploy/vision/facedet/contrib/scrfd.h"
#include "fastdeploy/vision/facedet/contrib/ultraface.h" #include "fastdeploy/vision/facedet/contrib/ultraface.h"

View File

@@ -0,0 +1,16 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h"

View File

@@ -0,0 +1,29 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindRKDet(pybind11::module& m) {
pybind11::class_<vision::detection::RKPicoDet, FastDeployModel>(m, "RKPicoDet")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>())
.def("predict",
[](vision::detection::RKPicoDet& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res);
return res;
});
}
} // namespace fastdeploy

View File

@@ -0,0 +1,201 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h"
#include "yaml-cpp/yaml.h"
namespace fastdeploy {
namespace vision {
namespace detection {
RKPicoDet::RKPicoDet(const std::string& model_file,
const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) {
config_file_ = config_file;
valid_cpu_backends = {Backend::ORT};
valid_rknpu_backends = {Backend::RKNPU2};
if ((model_format == ModelFormat::RKNN) ||
(model_format == ModelFormat::ONNX)) {
has_nms_ = false;
}
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
// NMS parameters come from RKPicoDet_s_nms
background_label = -1;
keep_top_k = 100;
nms_eta = 1;
nms_threshold = 0.5;
nms_top_k = 1000;
normalized = true;
score_threshold = 0.3;
initialized = Initialize();
}
bool RKPicoDet::Initialize() {
if (!BuildPreprocessPipelineFromConfig()) {
FDERROR << "Failed to build preprocess pipeline from configuration file."
<< std::endl;
return false;
}
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
bool RKPicoDet::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
int origin_w = mat->Width();
int origin_h = mat->Height();
for (size_t i = 0; i < processors_.size(); ++i) {
if (!(*(processors_[i].get()))(mat)) {
FDERROR << "Failed to process image data in " << processors_[i]->Name()
<< "." << std::endl;
return false;
}
}
Cast::Run(mat, "float");
scale_factor.resize(2);
scale_factor[0] = mat->Height() * 1.0 / origin_h;
scale_factor[1] = mat->Width() * 1.0 / origin_w;
outputs->resize(1);
(*outputs)[0].name = InputInfoOfRuntime(0).name;
mat->ShareWithTensor(&((*outputs)[0]));
// reshape to [1, c, h, w]
(*outputs)[0].shape.insert((*outputs)[0].shape.begin(), 1);
return true;
}
bool RKPicoDet::BuildPreprocessPipelineFromConfig() {
processors_.clear();
YAML::Node cfg;
try {
cfg = YAML::LoadFile(config_file_);
} catch (YAML::BadFile& e) {
FDERROR << "Failed to load yaml file " << config_file_
<< ", maybe you should check this file." << std::endl;
return false;
}
processors_.push_back(std::make_shared<BGR2RGB>());
for (const auto& op : cfg["Preprocess"]) {
std::string op_name = op["type"].as<std::string>();
if (op_name == "NormalizeImage") {
continue;
} else if (op_name == "Resize") {
bool keep_ratio = op["keep_ratio"].as<bool>();
auto target_size = op["target_size"].as<std::vector<int>>();
int interp = op["interp"].as<int>();
FDASSERT(target_size.size() == 2,
"Require size of target_size be 2, but now it's %lu.",
target_size.size());
if (!keep_ratio) {
int width = target_size[1];
int height = target_size[0];
processors_.push_back(
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
} else {
int min_target_size = std::min(target_size[0], target_size[1]);
int max_target_size = std::max(target_size[0], target_size[1]);
std::vector<int> max_size;
if (max_target_size > 0) {
max_size.push_back(max_target_size);
max_size.push_back(max_target_size);
}
processors_.push_back(std::make_shared<ResizeByShort>(
min_target_size, interp, true, max_size));
}
} else if (op_name == "Permute") {
continue;
} else if (op_name == "Pad") {
auto size = op["size"].as<std::vector<int>>();
auto value = op["fill_value"].as<std::vector<float>>();
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(
std::make_shared<PadToSize>(size[1], size[0], value));
} else if (op_name == "PadStride") {
auto stride = op["stride"].as<int>();
processors_.push_back(
std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
} else {
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
<< std::endl;
return false;
}
}
return true;
}
bool RKPicoDet::Postprocess(std::vector<FDTensor>& infer_result,
DetectionResult* result) {
FDASSERT(infer_result[1].shape[0] == 1,
"Only support batch = 1 in FastDeploy now.");
if (!has_nms_) {
int boxes_index = 0;
int scores_index = 1;
if (infer_result[0].shape[1] == infer_result[1].shape[2]) {
boxes_index = 0;
scores_index = 1;
} else if (infer_result[0].shape[2] == infer_result[1].shape[1]) {
boxes_index = 1;
scores_index = 0;
} else {
FDERROR << "The shape of boxes and scores should be [batch, boxes_num, "
"4], [batch, classes_num, boxes_num]"
<< std::endl;
return false;
}
backend::MultiClassNMS nms;
nms.background_label = background_label;
nms.keep_top_k = keep_top_k;
nms.nms_eta = nms_eta;
nms.nms_threshold = nms_threshold;
nms.score_threshold = score_threshold;
nms.nms_top_k = nms_top_k;
nms.normalized = normalized;
nms.Compute(static_cast<float*>(infer_result[boxes_index].Data()),
static_cast<float*>(infer_result[scores_index].Data()),
infer_result[boxes_index].shape,
infer_result[scores_index].shape);
if (nms.out_num_rois_data[0] > 0) {
result->Reserve(nms.out_num_rois_data[0]);
}
for (size_t i = 0; i < nms.out_num_rois_data[0]; ++i) {
result->label_ids.push_back(nms.out_box_data[i * 6]);
result->scores.push_back(nms.out_box_data[i * 6 + 1]);
result->boxes.emplace_back(
std::array<float, 4>{nms.out_box_data[i * 6 + 2] / scale_factor[1],
nms.out_box_data[i * 6 + 3] / scale_factor[0],
nms.out_box_data[i * 6 + 4] / scale_factor[1],
nms.out_box_data[i * 6 + 5] / scale_factor[0]});
}
} else {
FDERROR << "Picodet in Backend::RKNPU2 don't support NMS" << std::endl;
}
return true;
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,46 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
namespace fastdeploy {
namespace vision {
namespace detection {
class FASTDEPLOY_DECL RKPicoDet : public PPYOLOE {
public:
RKPicoDet(const std::string& model_file,
const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::RKNN);
virtual std::string ModelName() const { return "RKPicoDet"; }
protected:
/// Build the preprocess pipeline from the loaded model
virtual bool BuildPreprocessPipelineFromConfig();
/// Preprocess an input image, and set the preprocessed results to `outputs`
virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
/// Postprocess the inferenced results, and set the final result to `result`
virtual bool Postprocess(std::vector<FDTensor>& infer_result,
DetectionResult* result);
virtual bool Initialize();
private:
std::vector<float> scale_factor{1.0, 1.0};
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -27,6 +27,8 @@ void BindNanoDetPlus(pybind11::module& m);
void BindPPDet(pybind11::module& m); void BindPPDet(pybind11::module& m);
void BindYOLOv7End2EndTRT(pybind11::module& m); void BindYOLOv7End2EndTRT(pybind11::module& m);
void BindYOLOv7End2EndORT(pybind11::module& m); void BindYOLOv7End2EndORT(pybind11::module& m);
void BindRKDet(pybind11::module& m);
void BindDetection(pybind11::module& m) { void BindDetection(pybind11::module& m) {
auto detection_module = auto detection_module =
@@ -42,5 +44,6 @@ void BindDetection(pybind11::module& m) {
BindNanoDetPlus(detection_module); BindNanoDetPlus(detection_module);
BindYOLOv7End2EndTRT(detection_module); BindYOLOv7End2EndTRT(detection_module);
BindYOLOv7End2EndORT(detection_module); BindYOLOv7End2EndORT(detection_module);
BindRKDet(detection_module);
} }
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -24,3 +24,4 @@ from .contrib.yolov6 import YOLOv6
from .contrib.yolov7end2end_trt import YOLOv7End2EndTRT from .contrib.yolov7end2end_trt import YOLOv7End2EndTRT
from .contrib.yolov7end2end_ort import YOLOv7End2EndORT from .contrib.yolov7end2end_ort import YOLOv7End2EndORT
from .ppdet import PPYOLOE, PPYOLO, PPYOLOv2, PaddleYOLOX, PicoDet, FasterRCNN, YOLOv3, MaskRCNN from .ppdet import PPYOLOE, PPYOLO, PPYOLOv2, PaddleYOLOX, PicoDet, FasterRCNN, YOLOv3, MaskRCNN
from .rknpu2 import RKPicoDet

View File

@@ -0,0 +1,44 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from typing import Union, List
import logging
from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
from .. import PPYOLOE
class RKPicoDet(PPYOLOE):
def __init__(self,
model_file,
params_file,
config_file,
runtime_option=None,
model_format=ModelFormat.RKNN):
"""Load a PicoDet model exported by PaddleDetection.
:param model_file: (str)Path of model file, e.g picodet/model.pdmodel
:param params_file: (str)Path of parameters file, e.g picodet/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
:param config_file: (str)Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model
"""
super(PPYOLOE, self).__init__(runtime_option)
assert model_format == ModelFormat.RKNN, "RKPicoDet model only support model format of ModelFormat.RKNN now."
self._model = C.vision.detection.RKPicoDet(
model_file, params_file, config_file, self._runtime_option,
model_format)
assert self.initialized, "RKPicoDet model initialize failed."

View File

@@ -0,0 +1,7 @@
model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx
output_folder: ./picodet_s_416_coco_lcnet
target_platform: RK3568
normalize:
mean: [[0.485,0.456,0.406]]
std: [[0.229,0.224,0.225]]
outputs: ['tmp_16','p2o.Concat.9']

View File

@@ -0,0 +1,5 @@
model_path: ./picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx
output_folder: ./picodet_s_416_coco_npu
target_platform: RK3568
normalize: None
outputs: ['tmp_16','p2o.Concat.17']

View File

@@ -0,0 +1,7 @@
model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx
output_folder: ./picodet_s_416_coco_lcnet
target_platform: RK3588
normalize:
mean: [[0.485,0.456,0.406]]
std: [[0.229,0.224,0.225]]
outputs: ['tmp_16','p2o.Concat.9']

View File

@@ -0,0 +1,5 @@
model_path: ./picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx
output_folder: ./picodet_s_416_coco_npu
target_platform: RK3588
normalize: None
outputs: ['tmp_16','p2o.Concat.17']