[Add Model]Add RKPicodet (#495)

* 11-02/14:35
* 新增输入数据format错误判断
* 优化推理过程,减少内存分配次数
* 支持多输入rknn模型
* rknn模型输出shape为三维时,输出将被强制对齐为4纬。现在将直接抹除rknn补充的shape,方便部分对输出shape进行判断的模型进行正确的后处理。

* 11-03/17:25
* 支持导出多输入RKNN模型
* 更新各种文档
* ppseg改用Fastdeploy中的模型进行转换

* 11-03/17:25
* 新增开源头

* 11-03/21:48
* 删除无用debug代码,补充注释

* 11-04/01:00
* 新增rkpicodet代码

* 11-04/13:13
* 提交编译缺少的文件

* 11-04/14:03
* 更新安装文档

* 11-04/14:21
* 更新picodet_s配置文件

* 11-04/14:21
* 更新picodet自适应输出结果

* 11-04/14:21
* 更新文档

* * 更新配置文件

* * 修正配置文件

* * 添加缺失的python文件

* * 修正文档

* * 修正代码格式问题0

* * 按照要求修改

* * 按照要求修改

* * 按照要求修改

* * 按照要求修改

* * 按照要求修改

* test
This commit is contained in:
Zheng_Bicheng
2022-11-06 17:29:00 +08:00
committed by GitHub
parent 295af8f467
commit 6408af263a
19 changed files with 694 additions and 3 deletions

View File

@@ -22,8 +22,8 @@ model_path: ./portrait_pp_humansegv2_lite_256x144_pretrained.onnx
output_folder: ./
target_platform: RK3588
normalize:
mean: [0.5,0.5,0.5]
std: [0.5,0.5,0.5]
mean: [[0.5,0.5,0.5]]
std: [[0.5,0.5,0.5]]
outputs: None
```

View File

@@ -0,0 +1,38 @@
# PaddleDetection RKNPU2部署示例
## 支持模型列表
目前FastDeploy支持如下模型的部署
- [PicoDet系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet)
## 准备PaddleDetection部署模型以及转换模型
RKNPU部署模型前需要将Paddle模型转换成RKNN模型具体步骤如下:
* Paddle动态图模型转换为ONNX模型请参考[PaddleDetection导出模型](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/EXPORT_MODEL.md)
,注意在转换时请设置**export.nms=True**.
* ONNX模型转换RKNN模型的过程请参考[转换文档](../../../../../docs/cn/faq/rknpu2/export.md)进行转换。
## 模型转换example
下面以Picodet-npu为例子,教大家如何转换PaddleDetection模型到RKNN模型。
```bash
## 下载Paddle静态图模型并解压
wget https://bj.bcebos.com/fastdeploy/models/rknn2/picodet_s_416_coco_npu.zip
unzip -qo picodet_s_416_coco_npu.zip
# 静态图转ONNX模型注意这里的save_file请和压缩包名对齐
paddle2onnx --model_dir picodet_s_416_coco_npu \
--model_filename model.pdmodel \
--params_filename model.pdiparams \
--save_file picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx \
--enable_dev_version True
python -m paddle2onnx.optimize --input_model picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx \
--output_model picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx \
--input_shape_dict "{'image':[1,3,416,416]}"
# ONNX模型转RKNN模型
# 转换模型,模型将生成在picodet_s_320_coco_lcnet_non_postprocess目录下
python tools/rknpu2/export.py --config_path tools/rknpu2/config/RK3588/picodet_s_416_coco_npu.yaml
```
- [Python部署](./python)
- [视觉模型预测结果](../../../../../docs/api/vision_results/)

View File

@@ -0,0 +1,37 @@
CMAKE_MINIMUM_REQUIRED(VERSION 3.10)
project(rknpu2_test)
set(CMAKE_CXX_STANDARD 14)
# 指定下载解压后的fastdeploy库路径
set(FASTDEPLOY_INSTALL_DIR "thirdpartys/fastdeploy-0.0.3")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeployConfig.cmake)
include_directories(${FastDeploy_INCLUDE_DIRS})
add_executable(infer_picodet infer_picodet.cc)
target_link_libraries(infer_picodet ${FastDeploy_LIBS})
set(CMAKE_INSTALL_PREFIX ${CMAKE_SOURCE_DIR}/build/install)
install(TARGETS infer_picodet DESTINATION ./)
install(DIRECTORY model DESTINATION ./)
install(DIRECTORY images DESTINATION ./)
file(GLOB FASTDEPLOY_LIBS ${FASTDEPLOY_INSTALL_DIR}/lib/*)
message("${FASTDEPLOY_LIBS}")
install(PROGRAMS ${FASTDEPLOY_LIBS} DESTINATION lib)
file(GLOB ONNXRUNTIME_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/onnxruntime/lib/*)
install(PROGRAMS ${ONNXRUNTIME_LIBS} DESTINATION lib)
install(DIRECTORY ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/opencv/lib DESTINATION ./)
file(GLOB PADDLETOONNX_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/paddle2onnx/lib/*)
install(PROGRAMS ${PADDLETOONNX_LIBS} DESTINATION lib)
file(GLOB RKNPU2_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/rknpu2_runtime/RK3588/lib/*)
install(PROGRAMS ${RKNPU2_LIBS} DESTINATION lib)

View File

@@ -0,0 +1,71 @@
# PaddleDetection C++部署示例
本目录下提供`infer_xxxxx.cc`快速完成PPDetection模型在Rockchip板子上上通过二代NPU加速部署的示例。
在部署前,需确认以下两个步骤:
1. 软硬件环境满足要求
2. 根据开发环境下载预编译部署库或者从头编译FastDeploy仓库
以上步骤请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)实现
## 生成基本目录文件
该例程由以下几个部分组成
```text
.
├── CMakeLists.txt
├── build # 编译文件夹
├── image # 存放图片的文件夹
├── infer_cpu_npu.cc
├── infer_cpu_npu.h
├── main.cc
├── model # 存放模型文件的文件夹
└── thirdpartys # 存放sdk的文件夹
```
首先需要先生成目录结构
```bash
mkdir build
mkdir images
mkdir model
mkdir thirdpartys
```
## 编译
### 编译并拷贝SDK到thirdpartys文件夹
请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)仓库编译SDK编译完成后将在build目录下生成
fastdeploy-0.0.3目录请移动它至thirdpartys目录下.
### 拷贝模型文件以及配置文件至model文件夹
在Paddle动态图模型 -> Paddle静态图模型 -> ONNX模型的过程中将生成ONNX文件以及对应的yaml配置文件请将配置文件存放到model文件夹内。
转换为RKNN后的模型文件也需要拷贝至model。
### 准备测试图片至image文件夹
```bash
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
cp 000000014439.jpg ./images
```
### 编译example
```bash
cd build
cmake ..
make -j8
make install
```
## 运行例程
```bash
cd ./build/install
./rknpu_test
```
- [模型介绍](../../)
- [Python部署](../python)
- [视觉模型预测结果](../../../../../../docs/api/vision_results/)

View File

@@ -0,0 +1,86 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include "fastdeploy/vision.h"
void InferPicodet(const std::string& device = "cpu");
int main() {
InferPicodet("npu");
return 0;
}
fastdeploy::RuntimeOption GetOption(const std::string& device) {
auto option = fastdeploy::RuntimeOption();
if (device == "npu") {
option.UseRKNPU2();
} else {
option.UseCpu();
}
return option;
}
fastdeploy::ModelFormat GetFormat(const std::string& device) {
auto format = fastdeploy::ModelFormat::ONNX;
if (device == "npu") {
format = fastdeploy::ModelFormat::RKNN;
} else {
format = fastdeploy::ModelFormat::ONNX;
}
return format;
}
std::string GetModelPath(std::string& model_path, const std::string& device) {
if (device == "npu") {
model_path += "rknn";
} else {
model_path += "onnx";
}
return model_path;
}
void InferPicodet(const std::string &device) {
std::string model_file = "./model/picodet_s_416_coco_npu/picodet_s_416_coco_npu_rk3588.";
std::string params_file;
std::string config_file = "./model/picodet_s_416_coco_npu/infer_cfg.yml";
fastdeploy::RuntimeOption option = GetOption(device);
fastdeploy::ModelFormat format = GetFormat(device);
model_file = GetModelPath(model_file, device);
auto model = fastdeploy::vision::detection::RKPicoDet(
model_file, params_file, config_file,option,format);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto image_file = "./images/000000014439.jpg";
auto im = cv::imread(image_file);
fastdeploy::vision::DetectionResult res;
clock_t start = clock();
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
clock_t end = clock();
auto dur = static_cast<double>(end - start);
printf("picodet_npu use time:%f\n", (dur / CLOCKS_PER_SEC));
std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::VisDetection(im, res,0.5);
cv::imwrite("picodet_npu_result.jpg", vis_im);
std::cout << "Visualized result saved in ./picodet_npu_result.jpg" << std::endl;
}

View File

@@ -0,0 +1,35 @@
# PaddleDetection Python部署示例
在部署前,需确认以下两个步骤
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../../docs/cn/build_and_install/rknpu2.md)
本目录下提供`infer.py`快速完成Picodet在RKNPU上部署的示例。执行如下脚本即可完成
```bash
# 下载部署示例代码
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/detection/paddledetection/rknpu2/python
# 下载图片
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
# copy model
cp -r ./picodet_s_416_coco_npu /path/to/FastDeploy/examples/vision/detection/rknpu2detection/paddledetection/python
# 推理
python3 infer.py --model_file ./picodet_s_416_coco_npu/picodet_s_416_coco_npu_3588.rknn \
--config_file ./picodet_s_416_coco_npu/infer_cfg.yml \
--image 000000014439.jpg
```
## 注意事项
RKNPU上对模型的输入要求是使用NHWC格式且图片归一化操作会在转RKNN模型时内嵌到模型中因此我们在使用FastDeploy部署时
需要先调用DisableNormalizePermute(C++)或`disable_normalize_permute(Python),在预处理阶段禁用归一化以及数据格式的转换。
## 其它文档
- [PaddleDetection 模型介绍](..)
- [PaddleDetection C++部署](../cpp)
- [模型预测结果说明](../../../../../../docs/api/vision_results/)
- [转换PaddleDetection RKNN模型文档](../README.md)

View File

@@ -0,0 +1,59 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_file", required=True, help="Path of rknn model.")
parser.add_argument("--config_file", required=True, help="Path of config.")
parser.add_argument(
"--image", type=str, required=True, help="Path of test image file.")
return parser.parse_args()
def build_option(args):
option = fd.RuntimeOption()
option.use_rknpu2()
return option
args = parse_arguments()
# 配置runtime加载模型
runtime_option = build_option(args)
model_file = args.model_file
params_file = ""
config_file = args.config_file
model = fd.vision.detection.RKPicoDet(
model_file,
params_file,
config_file,
runtime_option=runtime_option,
model_format=fd.ModelFormat.RKNN)
# 预测图片分割结果
im = cv2.imread(args.image)
result = model.predict(im.copy())
print(result)
# 可视化结果
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg")

View File

@@ -29,6 +29,7 @@
#include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
#include "fastdeploy/vision/detection/contrib/yolox.h"
#include "fastdeploy/vision/detection/ppdet/model.h"
#include "fastdeploy/vision/detection/contrib/rknpu2/model.h"
#include "fastdeploy/vision/facedet/contrib/retinaface.h"
#include "fastdeploy/vision/facedet/contrib/scrfd.h"
#include "fastdeploy/vision/facedet/contrib/ultraface.h"

View File

@@ -0,0 +1,16 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h"

View File

@@ -0,0 +1,29 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindRKDet(pybind11::module& m) {
pybind11::class_<vision::detection::RKPicoDet, FastDeployModel>(m, "RKPicoDet")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>())
.def("predict",
[](vision::detection::RKPicoDet& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res);
return res;
});
}
} // namespace fastdeploy

View File

@@ -0,0 +1,201 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/detection/contrib/rknpu2/rkpicodet.h"
#include "yaml-cpp/yaml.h"
namespace fastdeploy {
namespace vision {
namespace detection {
RKPicoDet::RKPicoDet(const std::string& model_file,
const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) {
config_file_ = config_file;
valid_cpu_backends = {Backend::ORT};
valid_rknpu_backends = {Backend::RKNPU2};
if ((model_format == ModelFormat::RKNN) ||
(model_format == ModelFormat::ONNX)) {
has_nms_ = false;
}
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
// NMS parameters come from RKPicoDet_s_nms
background_label = -1;
keep_top_k = 100;
nms_eta = 1;
nms_threshold = 0.5;
nms_top_k = 1000;
normalized = true;
score_threshold = 0.3;
initialized = Initialize();
}
bool RKPicoDet::Initialize() {
if (!BuildPreprocessPipelineFromConfig()) {
FDERROR << "Failed to build preprocess pipeline from configuration file."
<< std::endl;
return false;
}
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
bool RKPicoDet::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
int origin_w = mat->Width();
int origin_h = mat->Height();
for (size_t i = 0; i < processors_.size(); ++i) {
if (!(*(processors_[i].get()))(mat)) {
FDERROR << "Failed to process image data in " << processors_[i]->Name()
<< "." << std::endl;
return false;
}
}
Cast::Run(mat, "float");
scale_factor.resize(2);
scale_factor[0] = mat->Height() * 1.0 / origin_h;
scale_factor[1] = mat->Width() * 1.0 / origin_w;
outputs->resize(1);
(*outputs)[0].name = InputInfoOfRuntime(0).name;
mat->ShareWithTensor(&((*outputs)[0]));
// reshape to [1, c, h, w]
(*outputs)[0].shape.insert((*outputs)[0].shape.begin(), 1);
return true;
}
bool RKPicoDet::BuildPreprocessPipelineFromConfig() {
processors_.clear();
YAML::Node cfg;
try {
cfg = YAML::LoadFile(config_file_);
} catch (YAML::BadFile& e) {
FDERROR << "Failed to load yaml file " << config_file_
<< ", maybe you should check this file." << std::endl;
return false;
}
processors_.push_back(std::make_shared<BGR2RGB>());
for (const auto& op : cfg["Preprocess"]) {
std::string op_name = op["type"].as<std::string>();
if (op_name == "NormalizeImage") {
continue;
} else if (op_name == "Resize") {
bool keep_ratio = op["keep_ratio"].as<bool>();
auto target_size = op["target_size"].as<std::vector<int>>();
int interp = op["interp"].as<int>();
FDASSERT(target_size.size() == 2,
"Require size of target_size be 2, but now it's %lu.",
target_size.size());
if (!keep_ratio) {
int width = target_size[1];
int height = target_size[0];
processors_.push_back(
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
} else {
int min_target_size = std::min(target_size[0], target_size[1]);
int max_target_size = std::max(target_size[0], target_size[1]);
std::vector<int> max_size;
if (max_target_size > 0) {
max_size.push_back(max_target_size);
max_size.push_back(max_target_size);
}
processors_.push_back(std::make_shared<ResizeByShort>(
min_target_size, interp, true, max_size));
}
} else if (op_name == "Permute") {
continue;
} else if (op_name == "Pad") {
auto size = op["size"].as<std::vector<int>>();
auto value = op["fill_value"].as<std::vector<float>>();
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(
std::make_shared<PadToSize>(size[1], size[0], value));
} else if (op_name == "PadStride") {
auto stride = op["stride"].as<int>();
processors_.push_back(
std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
} else {
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
<< std::endl;
return false;
}
}
return true;
}
bool RKPicoDet::Postprocess(std::vector<FDTensor>& infer_result,
DetectionResult* result) {
FDASSERT(infer_result[1].shape[0] == 1,
"Only support batch = 1 in FastDeploy now.");
if (!has_nms_) {
int boxes_index = 0;
int scores_index = 1;
if (infer_result[0].shape[1] == infer_result[1].shape[2]) {
boxes_index = 0;
scores_index = 1;
} else if (infer_result[0].shape[2] == infer_result[1].shape[1]) {
boxes_index = 1;
scores_index = 0;
} else {
FDERROR << "The shape of boxes and scores should be [batch, boxes_num, "
"4], [batch, classes_num, boxes_num]"
<< std::endl;
return false;
}
backend::MultiClassNMS nms;
nms.background_label = background_label;
nms.keep_top_k = keep_top_k;
nms.nms_eta = nms_eta;
nms.nms_threshold = nms_threshold;
nms.score_threshold = score_threshold;
nms.nms_top_k = nms_top_k;
nms.normalized = normalized;
nms.Compute(static_cast<float*>(infer_result[boxes_index].Data()),
static_cast<float*>(infer_result[scores_index].Data()),
infer_result[boxes_index].shape,
infer_result[scores_index].shape);
if (nms.out_num_rois_data[0] > 0) {
result->Reserve(nms.out_num_rois_data[0]);
}
for (size_t i = 0; i < nms.out_num_rois_data[0]; ++i) {
result->label_ids.push_back(nms.out_box_data[i * 6]);
result->scores.push_back(nms.out_box_data[i * 6 + 1]);
result->boxes.emplace_back(
std::array<float, 4>{nms.out_box_data[i * 6 + 2] / scale_factor[1],
nms.out_box_data[i * 6 + 3] / scale_factor[0],
nms.out_box_data[i * 6 + 4] / scale_factor[1],
nms.out_box_data[i * 6 + 5] / scale_factor[0]});
}
} else {
FDERROR << "Picodet in Backend::RKNPU2 don't support NMS" << std::endl;
}
return true;
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,46 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/detection/ppdet/ppyoloe.h"
namespace fastdeploy {
namespace vision {
namespace detection {
class FASTDEPLOY_DECL RKPicoDet : public PPYOLOE {
public:
RKPicoDet(const std::string& model_file,
const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::RKNN);
virtual std::string ModelName() const { return "RKPicoDet"; }
protected:
/// Build the preprocess pipeline from the loaded model
virtual bool BuildPreprocessPipelineFromConfig();
/// Preprocess an input image, and set the preprocessed results to `outputs`
virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
/// Postprocess the inferenced results, and set the final result to `result`
virtual bool Postprocess(std::vector<FDTensor>& infer_result,
DetectionResult* result);
virtual bool Initialize();
private:
std::vector<float> scale_factor{1.0, 1.0};
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -27,6 +27,8 @@ void BindNanoDetPlus(pybind11::module& m);
void BindPPDet(pybind11::module& m);
void BindYOLOv7End2EndTRT(pybind11::module& m);
void BindYOLOv7End2EndORT(pybind11::module& m);
void BindRKDet(pybind11::module& m);
void BindDetection(pybind11::module& m) {
auto detection_module =
@@ -42,5 +44,6 @@ void BindDetection(pybind11::module& m) {
BindNanoDetPlus(detection_module);
BindYOLOv7End2EndTRT(detection_module);
BindYOLOv7End2EndORT(detection_module);
BindRKDet(detection_module);
}
} // namespace fastdeploy

View File

@@ -24,3 +24,4 @@ from .contrib.yolov6 import YOLOv6
from .contrib.yolov7end2end_trt import YOLOv7End2EndTRT
from .contrib.yolov7end2end_ort import YOLOv7End2EndORT
from .ppdet import PPYOLOE, PPYOLO, PPYOLOv2, PaddleYOLOX, PicoDet, FasterRCNN, YOLOv3, MaskRCNN
from .rknpu2 import RKPicoDet

View File

@@ -0,0 +1,44 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from typing import Union, List
import logging
from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
from .. import PPYOLOE
class RKPicoDet(PPYOLOE):
def __init__(self,
model_file,
params_file,
config_file,
runtime_option=None,
model_format=ModelFormat.RKNN):
"""Load a PicoDet model exported by PaddleDetection.
:param model_file: (str)Path of model file, e.g picodet/model.pdmodel
:param params_file: (str)Path of parameters file, e.g picodet/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
:param config_file: (str)Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model
"""
super(PPYOLOE, self).__init__(runtime_option)
assert model_format == ModelFormat.RKNN, "RKPicoDet model only support model format of ModelFormat.RKNN now."
self._model = C.vision.detection.RKPicoDet(
model_file, params_file, config_file, self._runtime_option,
model_format)
assert self.initialized, "RKPicoDet model initialize failed."

View File

@@ -0,0 +1,7 @@
model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx
output_folder: ./picodet_s_416_coco_lcnet
target_platform: RK3568
normalize:
mean: [[0.485,0.456,0.406]]
std: [[0.229,0.224,0.225]]
outputs: ['tmp_16','p2o.Concat.9']

View File

@@ -0,0 +1,5 @@
model_path: ./picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx
output_folder: ./picodet_s_416_coco_npu
target_platform: RK3568
normalize: None
outputs: ['tmp_16','p2o.Concat.17']

View File

@@ -0,0 +1,7 @@
model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx
output_folder: ./picodet_s_416_coco_lcnet
target_platform: RK3588
normalize:
mean: [[0.485,0.456,0.406]]
std: [[0.229,0.224,0.225]]
outputs: ['tmp_16','p2o.Concat.9']

View File

@@ -0,0 +1,5 @@
model_path: ./picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx
output_folder: ./picodet_s_416_coco_npu
target_platform: RK3588
normalize: None
outputs: ['tmp_16','p2o.Concat.17']