[Model] Add Picodet RKNPU2 (#635)

* * 更新picodet cpp代码

* * 更新文档
* 更新picodet cpp example

* * 删除无用的debug代码
* 新增python example

* * 修改c++代码

* * 修改python代码

* * 修改postprocess代码

* 修复没有scale_factor导致的bug

* 修复错误

* 更正代码格式

* 更正代码格式
This commit is contained in:
Zheng_Bicheng
2022-11-21 13:44:34 +08:00
committed by GitHub
parent 5ca779ee32
commit 3e1fc69a0c
20 changed files with 340 additions and 195 deletions

View File

@@ -13,7 +13,9 @@ RKNPU部署模型前需要将Paddle模型转换成RKNN模型具体步骤如
## 模型转换example ## 模型转换example
下面以Picodet-npu为例子,教大家如何转换PaddleDetection模型到RKNN模型。 以下步骤均在Ubuntu电脑上完成请参考配置文档完成转换模型环境配置。下面以Picodet-s为例子,教大家如何转换PaddleDetection模型到RKNN模型。
### 导出ONNX模型
```bash ```bash
# 下载Paddle静态图模型并解压 # 下载Paddle静态图模型并解压
wget https://paddledet.bj.bcebos.com/deploy/Inference/picodet_s_416_coco_lcnet.tar wget https://paddledet.bj.bcebos.com/deploy/Inference/picodet_s_416_coco_lcnet.tar
@@ -26,12 +28,89 @@ paddle2onnx --model_dir picodet_s_416_coco_lcnet \
--save_file picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx \ --save_file picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx \
--enable_dev_version True --enable_dev_version True
# 固定shape
python -m paddle2onnx.optimize --input_model picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx \ python -m paddle2onnx.optimize --input_model picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx \
--output_model picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx \ --output_model picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx \
--input_shape_dict "{'image':[1,3,416,416]}" --input_shape_dict "{'image':[1,3,416,416]}"
```
### 编写模型导出配置文件
以转化RK3568的RKNN模型为例子我们需要编辑tools/rknpu2/config/RK3568/picodet_s_416_coco_lcnet.yaml来转换ONNX模型到RKNN模型。
**修改normalize参数**
如果你需要在NPU上执行normalize操作请根据你的模型配置normalize参数例如:
```yaml
model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx
output_folder: ./picodet_s_416_coco_lcnet
target_platform: RK3568
normalize:
mean: [[0.485,0.456,0.406],[0,0,0]]
std: [[0.229,0.224,0.225],[0.003921,0.003921]]
outputs: ['tmp_17','p2o.Concat.9']
```
**修改outputs参数**
由于Paddle2ONNX版本的不同转换模型的输出节点名称也有所不同请使用[Netron](https://netron.app)并找到以下蓝色方框标记的NonMaxSuppression节点红色方框的节点名称即为目标名称。
例如使用Netron可视化后得到以下图片:
![](https://user-images.githubusercontent.com/58363586/202728663-4af0b843-d012-4aeb-8a66-626b7b87ca69.png)
找到蓝色方框标记的NonMaxSuppression节点可以看到红色方框标记的两个节点名称为tmp_17和p2o.Concat.9,因此需要修改outputs参数修改后如下:
```yaml
model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx
output_folder: ./picodet_s_416_coco_lcnet
target_platform: RK3568
normalize: None
outputs: ['tmp_17','p2o.Concat.9']
```
### 转换模型
```bash
# ONNX模型转RKNN模型 # ONNX模型转RKNN模型
# 转换模型,模型将生成在picodet_s_320_coco_lcnet_non_postprocess目录下 # 转换模型,模型将生成在picodet_s_320_coco_lcnet_non_postprocess目录下
python tools/rknpu2/export.py --config_path tools/rknpu2/config/RK3588/picodet_s_416_coco_lcnet.yaml python tools/rknpu2/export.py --config_path tools/rknpu2/config/RK3568/picodet_s_416_coco_lcnet.yaml
```
### 修改模型运行时的配置文件
配置文件中,我们只需要修改**Preprocess**下的**Normalize**和**Permute**.
**删除Permute**
RKNPU只支持NHWC的输入格式因此需要删除Permute操作.删除后配置文件Precess部分后如下:
```yaml
Preprocess:
- interp: 2
keep_ratio: false
target_size:
- 416
- 416
type: Resize
- is_scale: true
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
type: NormalizeImage
```
**根据模型转换文件决定是否删除Normalize**
RKNPU支持使用NPU进行Normalize操作如果你在导出模型时配置了Normalize参数请删除**Normalize**.删除后配置文件Precess部分如下:
```yaml
Preprocess:
- interp: 2
keep_ratio: false
target_size:
- 416
- 416
type: Resize
``` ```
- [Python部署](./python) - [Python部署](./python)

View File

@@ -33,5 +33,5 @@ install(DIRECTORY ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/opencv/lib DESTIN
file(GLOB PADDLETOONNX_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/paddle2onnx/lib/*) file(GLOB PADDLETOONNX_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/paddle2onnx/lib/*)
install(PROGRAMS ${PADDLETOONNX_LIBS} DESTINATION lib) install(PROGRAMS ${PADDLETOONNX_LIBS} DESTINATION lib)
file(GLOB RKNPU2_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/rknpu2_runtime/RK3588/lib/*) file(GLOB RKNPU2_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/rknpu2_runtime/${RKNN2_TARGET_SOC}/lib/*)
install(PROGRAMS ${RKNPU2_LIBS} DESTINATION lib) install(PROGRAMS ${RKNPU2_LIBS} DESTINATION lib)

View File

@@ -62,7 +62,7 @@ make install
```bash ```bash
cd ./build/install cd ./build/install
./rknpu_test ./infer_picodet model/picodet_s_416_coco_lcnet images/000000014439.jpg
``` ```

View File

@@ -14,73 +14,53 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
#include "fastdeploy/vision.h" #include "fastdeploy/vision.h"
#include <sys/time.h>
double __get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }
void InferPicodet(const std::string& model_dir, const std::string& image_file);
void InferPicodet(const std::string& device = "cpu"); int main(int argc, char* argv[]) {
if (argc < 3) {
std::cout
<< "Usage: infer_demo path/to/model_dir path/to/image run_option, "
"e.g ./infer_model ./picodet_model_dir ./test.jpeg"
<< std::endl;
return -1;
}
InferPicodet(argv[1], argv[2]);
int main() {
InferPicodet("npu");
return 0; return 0;
} }
fastdeploy::RuntimeOption GetOption(const std::string& device) { void InferPicodet(const std::string& model_dir, const std::string& image_file) {
struct timeval start_time, stop_time;
auto model_file = model_dir + "/picodet_s_416_coco_lcnet_rk3568.rknn";
auto params_file = "";
auto config_file = model_dir + "/infer_cfg.yml";
auto option = fastdeploy::RuntimeOption(); auto option = fastdeploy::RuntimeOption();
if (device == "npu") {
option.UseRKNPU2(); option.UseRKNPU2();
} else {
option.UseCpu();
}
return option;
}
fastdeploy::ModelFormat GetFormat(const std::string& device) { auto format = fastdeploy::ModelFormat::RKNN;
auto format = fastdeploy::ModelFormat::ONNX;
if (device == "npu") {
format = fastdeploy::ModelFormat::RKNN;
} else {
format = fastdeploy::ModelFormat::ONNX;
}
return format;
}
std::string GetModelPath(std::string& model_path, const std::string& device) { auto model = fastdeploy::vision::detection::PicoDet(
if (device == "npu") {
model_path += "rknn";
} else {
model_path += "onnx";
}
return model_path;
}
void InferPicodet(const std::string &device) {
std::string model_file = "./model/picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet_rk3588.";
std::string params_file;
std::string config_file = "./model/picodet_s_416_coco_lcnet/infer_cfg.yml";
fastdeploy::RuntimeOption option = GetOption(device);
fastdeploy::ModelFormat format = GetFormat(device);
model_file = GetModelPath(model_file, device);
auto model = fastdeploy::vision::detection::RKPicoDet(
model_file, params_file, config_file,option,format); model_file, params_file, config_file,option,format);
if (!model.Initialized()) { model.GetPostprocessor().ApplyDecodeAndNMS();
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto image_file = "./images/000000014439.jpg";
auto im = cv::imread(image_file); auto im = cv::imread(image_file);
fastdeploy::vision::DetectionResult res; fastdeploy::vision::DetectionResult res;
clock_t start = clock(); gettimeofday(&start_time, NULL);
if (!model.Predict(&im, &res)) { if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl; std::cerr << "Failed to predict." << std::endl;
return; return;
} }
clock_t end = clock(); gettimeofday(&stop_time, NULL);
auto dur = static_cast<double>(end - start); printf("infer use %f ms\n", (__get_us(stop_time) - __get_us(start_time)) / 1000);
printf("picodet_npu use time:%f\n", (dur / CLOCKS_PER_SEC));
std::cout << res.Str() << std::endl; std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::VisDetection(im, res,0.5); auto vis_im = fastdeploy::vision::VisDetection(im, res,0.5);
cv::imwrite("picodet_npu_result.jpg", vis_im); cv::imwrite("picodet_result.jpg", vis_im);
std::cout << "Visualized result saved in ./picodet_npu_result.jpg" << std::endl; std::cout << "Visualized result saved in ./picodet_result.jpg" << std::endl;
} }

View File

@@ -15,11 +15,11 @@ cd FastDeploy/examples/vision/detection/paddledetection/rknpu2/python
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
# copy model # copy model
cp -r ./picodet_s_416_coco_npu /path/to/FastDeploy/examples/vision/detection/rknpu2detection/paddledetection/python cp -r ./picodet_s_416_coco_lcnet /path/to/FastDeploy/examples/vision/detection/rknpu2detection/paddledetection/python
# 推理 # 推理
python3 infer.py --model_file ./picodet_s_416_coco_npu/picodet_s_416_coco_npu_3588.rknn \ python3 infer.py --model_file ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet_rk3568.rknn \
--config_file ./picodet_s_416_coco_npu/infer_cfg.yml \ --config_file ./picodet_s_416_coco_lcnet/infer_cfg.yml \
--image 000000014439.jpg --image 000000014439.jpg
``` ```

View File

@@ -28,32 +28,32 @@ def parse_arguments():
return parser.parse_args() return parser.parse_args()
def build_option(args): if __name__ == "__main__":
option = fd.RuntimeOption() args = parse_arguments()
option.use_rknpu2()
return option
model_file = args.model_file
params_file = ""
config_file = args.config_file
args = parse_arguments() # 配置runtime加载模型
runtime_option = fd.RuntimeOption()
runtime_option.use_rknpu2()
# 配置runtime加载模型 model = fd.vision.detection.PicoDet(
runtime_option = build_option(args)
model_file = args.model_file
params_file = ""
config_file = args.config_file
model = fd.vision.detection.RKPicoDet(
model_file, model_file,
params_file, params_file,
config_file, config_file,
runtime_option=runtime_option, runtime_option=runtime_option,
model_format=fd.ModelFormat.RKNN) model_format=fd.ModelFormat.RKNN)
# 预测图片分割结果 model.postprocessor.apply_decode_and_nms()
im = cv2.imread(args.image)
result = model.predict(im.copy())
print(result)
# 可视化结果 # 预测图片分割结果
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5) im = cv2.imread(args.image)
cv2.imwrite("visualized_result.jpg", vis_im) result = model.predict(im.copy())
print("Visualized result save in ./visualized_result.jpg") print(result)
# 可视化结果
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg")

View File

@@ -6,10 +6,12 @@ namespace fastdeploy {
namespace vision { namespace vision {
namespace detection { namespace detection {
PPDetBase::PPDetBase(const std::string& model_file, const std::string& params_file, PPDetBase::PPDetBase(const std::string& model_file,
const std::string& params_file,
const std::string& config_file, const std::string& config_file,
const RuntimeOption& custom_option, const RuntimeOption& custom_option,
const ModelFormat& model_format) : preprocessor_(config_file) { const ModelFormat& model_format)
: preprocessor_(config_file) {
runtime_option = custom_option; runtime_option = custom_option;
runtime_option.model_format = model_format; runtime_option.model_format = model_format;
runtime_option.model_file = model_file; runtime_option.model_file = model_file;
@@ -37,7 +39,8 @@ bool PPDetBase::Predict(const cv::Mat& im, DetectionResult* result) {
return true; return true;
} }
bool PPDetBase::BatchPredict(const std::vector<cv::Mat>& imgs, std::vector<DetectionResult>* results) { bool PPDetBase::BatchPredict(const std::vector<cv::Mat>& imgs,
std::vector<DetectionResult>* results) {
std::vector<FDMat> fd_images = WrapMat(imgs); std::vector<FDMat> fd_images = WrapMat(imgs);
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) { if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
FDERROR << "Failed to preprocess the input image." << std::endl; FDERROR << "Failed to preprocess the input image." << std::endl;
@@ -46,8 +49,13 @@ bool PPDetBase::BatchPredict(const std::vector<cv::Mat>& imgs, std::vector<Detec
reused_input_tensors_[0].name = "image"; reused_input_tensors_[0].name = "image";
reused_input_tensors_[1].name = "scale_factor"; reused_input_tensors_[1].name = "scale_factor";
reused_input_tensors_[2].name = "im_shape"; reused_input_tensors_[2].name = "im_shape";
// Some models don't need im_shape as input
if (NumInputsOfRuntime() == 2) { if(postprocessor_.DecodeAndNMSApplied()){
postprocessor_.SetScaleFactor(static_cast<float*>(reused_input_tensors_[1].Data()));
}
// Some models don't need scale_factor and im_shape as input
while (reused_input_tensors_.size() != NumInputsOfRuntime()) {
reused_input_tensors_.pop_back(); reused_input_tensors_.pop_back();
} }
@@ -57,7 +65,8 @@ bool PPDetBase::BatchPredict(const std::vector<cv::Mat>& imgs, std::vector<Detec
} }
if (!postprocessor_.Run(reused_output_tensors_, results)) { if (!postprocessor_.Run(reused_output_tensors_, results)) {
FDERROR << "Failed to postprocess the inference results by runtime." << std::endl; FDERROR << "Failed to postprocess the inference results by runtime."
<< std::endl;
return false; return false;
} }
return true; return true;

View File

@@ -38,6 +38,7 @@ class FASTDEPLOY_DECL PicoDet : public PPDetBase {
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, valid_cpu_backends = {Backend::OPENVINO, Backend::ORT,
Backend::PDINFER, Backend::LITE}; Backend::PDINFER, Backend::LITE};
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
valid_rknpu_backends = {Backend::RKNPU2};
initialized = Initialize(); initialized = Initialize();
} }

View File

@@ -19,10 +19,12 @@ namespace fastdeploy {
namespace vision { namespace vision {
namespace detection { namespace detection {
bool PaddleDetPostprocessor::ProcessMask(const FDTensor& tensor, std::vector<DetectionResult>* results) { bool PaddleDetPostprocessor::ProcessMask(
const FDTensor& tensor, std::vector<DetectionResult>* results) {
auto shape = tensor.Shape(); auto shape = tensor.Shape();
if (tensor.Dtype() != FDDataType::INT32) { if (tensor.Dtype() != FDDataType::INT32) {
FDERROR << "The data type of out mask tensor should be INT32, but now it's " << tensor.Dtype() << std::endl; FDERROR << "The data type of out mask tensor should be INT32, but now it's "
<< tensor.Dtype() << std::endl;
return false; return false;
} }
int64_t out_mask_h = shape[1]; int64_t out_mask_h = shape[1];
@@ -46,7 +48,8 @@ bool PaddleDetPostprocessor::ProcessMask(const FDTensor& tensor, std::vector<Det
(*results)[i].masks[j].shape = {keep_mask_h, keep_mask_w}; (*results)[i].masks[j].shape = {keep_mask_h, keep_mask_w};
const int32_t* current_ptr = data + index * out_mask_numel; const int32_t* current_ptr = data + index * out_mask_numel;
int32_t* keep_mask_ptr = reinterpret_cast<int32_t*>((*results)[i].masks[j].Data()); int32_t* keep_mask_ptr =
reinterpret_cast<int32_t*>((*results)[i].masks[j].Data());
for (int row = y1; row < y2; ++row) { for (int row = y1; row < y2; ++row) {
size_t keep_nbytes_in_col = keep_mask_w * sizeof(int32_t); size_t keep_nbytes_in_col = keep_mask_w * sizeof(int32_t);
const int32_t* out_row_start_ptr = current_ptr + row * out_mask_w + x1; const int32_t* out_row_start_ptr = current_ptr + row * out_mask_w + x1;
@@ -59,7 +62,20 @@ bool PaddleDetPostprocessor::ProcessMask(const FDTensor& tensor, std::vector<Det
return true; return true;
} }
bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors, std::vector<DetectionResult>* results) { bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors,
std::vector<DetectionResult>* results) {
if (DecodeAndNMSApplied()) {
FDASSERT(tensors.size() == 2,
"While postprocessing with ApplyDecodeAndNMS, "
"there should be 2 outputs for this model, but now it's %zu.",
tensors.size());
FDASSERT(tensors[0].shape.size() == 3,
"While postprocessing with ApplyDecodeAndNMS, "
"the rank of the first outputs should be 3, but now it's %zu",
tensors[0].shape.size());
return ProcessUnDecodeResults(tensors, results);
}
if (tensors[0].shape[0] == 0) { if (tensors[0].shape[0] == 0) {
// No detected boxes // No detected boxes
return true; return true;
@@ -69,13 +85,13 @@ bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors, std::vect
std::vector<int> num_boxes(tensors[1].shape[0]); std::vector<int> num_boxes(tensors[1].shape[0]);
int total_num_boxes = 0; int total_num_boxes = 0;
if (tensors[1].dtype == FDDataType::INT32) { if (tensors[1].dtype == FDDataType::INT32) {
const int32_t* data = static_cast<const int32_t*>(tensors[1].CpuData()); const auto* data = static_cast<const int32_t*>(tensors[1].CpuData());
for (size_t i = 0; i < tensors[1].shape[0]; ++i) { for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
num_boxes[i] = static_cast<int>(data[i]); num_boxes[i] = static_cast<int>(data[i]);
total_num_boxes += num_boxes[i]; total_num_boxes += num_boxes[i];
} }
} else if (tensors[1].dtype == FDDataType::INT64) { } else if (tensors[1].dtype == FDDataType::INT64) {
const int64_t* data = static_cast<const int64_t*>(tensors[1].CpuData()); const auto* data = static_cast<const int64_t*>(tensors[1].CpuData());
for (size_t i = 0; i < tensors[1].shape[0]; ++i) { for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
num_boxes[i] = static_cast<int>(data[i]); num_boxes[i] = static_cast<int>(data[i]);
} }
@@ -83,33 +99,37 @@ bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors, std::vect
// Special case for TensorRT, it has fixed output shape of NMS // Special case for TensorRT, it has fixed output shape of NMS
// So there's invalid boxes in its' output boxes // So there's invalid boxes in its' output boxes
int num_output_boxes = tensors[0].Shape()[0]; int num_output_boxes = static_cast<int>(tensors[0].Shape()[0]);
bool contain_invalid_boxes = false; bool contain_invalid_boxes = false;
if (total_num_boxes != num_output_boxes) { if (total_num_boxes != num_output_boxes) {
if (num_output_boxes % num_boxes.size() == 0) { if (num_output_boxes % num_boxes.size() == 0) {
contain_invalid_boxes = true; contain_invalid_boxes = true;
} else { } else {
FDERROR << "Cannot handle the output data for this model, unexpected situation." << std::endl; FDERROR << "Cannot handle the output data for this model, unexpected "
"situation."
<< std::endl;
return false; return false;
} }
} }
// Get boxes for each input image // Get boxes for each input image
results->resize(num_boxes.size()); results->resize(num_boxes.size());
const float* box_data = static_cast<const float*>(tensors[0].CpuData()); const auto* box_data = static_cast<const float*>(tensors[0].CpuData());
int offset = 0; int offset = 0;
for (size_t i = 0; i < num_boxes.size(); ++i) { for (size_t i = 0; i < num_boxes.size(); ++i) {
const float* ptr = box_data + offset; const float* ptr = box_data + offset;
(*results)[i].Reserve(num_boxes[i]); (*results)[i].Reserve(num_boxes[i]);
for (size_t j = 0; j < num_boxes[i]; ++j) { for (size_t j = 0; j < num_boxes[i]; ++j) {
(*results)[i].label_ids.push_back(static_cast<int32_t>(round(ptr[j * 6]))); (*results)[i].label_ids.push_back(
static_cast<int32_t>(round(ptr[j * 6])));
(*results)[i].scores.push_back(ptr[j * 6 + 1]); (*results)[i].scores.push_back(ptr[j * 6 + 1]);
(*results)[i].boxes.emplace_back(std::array<float, 4>({ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]})); (*results)[i].boxes.emplace_back(std::array<float, 4>(
{ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]}));
} }
if (contain_invalid_boxes) { if (contain_invalid_boxes) {
offset += (num_output_boxes * 6 / num_boxes.size()); offset += static_cast<int>(num_output_boxes * 6 / num_boxes.size());
} else { } else {
offset += (num_boxes[i] * 6); offset += static_cast<int>(num_boxes[i] * 6);
} }
} }
@@ -119,7 +139,10 @@ bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors, std::vect
} }
if (tensors[2].Shape()[0] != num_output_boxes) { if (tensors[2].Shape()[0] != num_output_boxes) {
FDERROR << "The first dimension of output mask tensor:" << tensors[2].Shape()[0] << " is not equal to the first dimension of output boxes tensor:" << num_output_boxes << "." << std::endl; FDERROR << "The first dimension of output mask tensor:"
<< tensors[2].Shape()[0]
<< " is not equal to the first dimension of output boxes tensor:"
<< num_output_boxes << "." << std::endl;
return false; return false;
} }
@@ -127,6 +150,80 @@ bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors, std::vect
return ProcessMask(tensors[2], results); return ProcessMask(tensors[2], results);
} }
void PaddleDetPostprocessor::ApplyDecodeAndNMS() {
apply_decode_and_nms_ = true;
}
bool PaddleDetPostprocessor::ProcessUnDecodeResults(
const std::vector<FDTensor>& tensors,
std::vector<DetectionResult>* results) {
if (tensors.size() != 2) {
return false;
}
int boxes_index = 0;
int scores_index = 1;
if (tensors[0].shape[1] == tensors[1].shape[2]) {
boxes_index = 0;
scores_index = 1;
} else if (tensors[0].shape[2] == tensors[1].shape[1]) {
boxes_index = 1;
scores_index = 0;
} else {
FDERROR << "The shape of boxes and scores should be [batch, boxes_num, "
"4], [batch, classes_num, boxes_num]"
<< std::endl;
return false;
}
backend::MultiClassNMS nms;
nms.background_label = -1;
nms.keep_top_k = 100;
nms.nms_eta = 1.0;
nms.nms_threshold = 0.5;
nms.score_threshold = 0.3;
nms.nms_top_k = 1000;
nms.normalized = true;
nms.Compute(static_cast<const float*>(tensors[boxes_index].Data()),
static_cast<const float*>(tensors[scores_index].Data()),
tensors[boxes_index].shape, tensors[scores_index].shape);
auto num_boxes = nms.out_num_rois_data;
auto box_data = static_cast<const float*>(nms.out_box_data.data());
// Get boxes for each input image
results->resize(num_boxes.size());
int offset = 0;
for (size_t i = 0; i < num_boxes.size(); ++i) {
const float* ptr = box_data + offset;
(*results)[i].Reserve(num_boxes[i]);
for (size_t j = 0; j < num_boxes[i]; ++j) {
(*results)[i].label_ids.push_back(
static_cast<int32_t>(round(ptr[j * 6])));
(*results)[i].scores.push_back(ptr[j * 6 + 1]);
(*results)[i].boxes.emplace_back(std::array<float, 4>(
{ptr[j * 6 + 2] / GetScaleFactor()[1],
ptr[j * 6 + 3] / GetScaleFactor()[0],
ptr[j * 6 + 4] / GetScaleFactor()[1],
ptr[j * 6 + 5] / GetScaleFactor()[0]}));
}
offset += (num_boxes[i] * 6);
}
return true;
}
std::vector<float> PaddleDetPostprocessor::GetScaleFactor(){
return scale_factor_;
}
void PaddleDetPostprocessor::SetScaleFactor(float* scale_factor_value){
for (int i = 0; i < scale_factor_.size(); ++i) {
scale_factor_[i] = scale_factor_value[i];
}
}
bool PaddleDetPostprocessor::DecodeAndNMSApplied() {
return apply_decode_and_nms_;
}
} // namespace detection } // namespace detection
} // namespace vision } // namespace vision
} // namespace fastdeploy } // namespace fastdeploy

View File

@@ -33,10 +33,27 @@ class FASTDEPLOY_DECL PaddleDetPostprocessor {
*/ */
bool Run(const std::vector<FDTensor>& tensors, bool Run(const std::vector<FDTensor>& tensors,
std::vector<DetectionResult>* result); std::vector<DetectionResult>* result);
/// Apply box decoding and nms step for the outputs for the model.This is
/// only available for those model exported without box decoding and nms.
void ApplyDecodeAndNMS();
bool DecodeAndNMSApplied();
/// Set scale_factor_ value.This is only available for those model exported
/// without box decoding and nms.
void SetScaleFactor(float* scale_factor_value);
private: private:
// Process mask tensor for MaskRCNN // Process mask tensor for MaskRCNN
bool ProcessMask(const FDTensor& tensor, bool ProcessMask(const FDTensor& tensor,
std::vector<DetectionResult>* results); std::vector<DetectionResult>* results);
bool apply_decode_and_nms_ = false;
std::vector<float> scale_factor_{1.0, 1.0};
std::vector<float> GetScaleFactor();
bool ProcessUnDecodeResults(const std::vector<FDTensor>& tensors,
std::vector<DetectionResult>* results);
}; };
} // namespace detection } // namespace detection

View File

@@ -43,6 +43,10 @@ void BindPPDet(pybind11::module& m) {
} }
return results; return results;
}) })
.def("apply_decode_and_nms",
[](vision::detection::PaddleDetPostprocessor& self){
self.ApplyDecodeAndNMS();
})
.def("run", [](vision::detection::PaddleDetPostprocessor& self, std::vector<pybind11::array>& input_array) { .def("run", [](vision::detection::PaddleDetPostprocessor& self, std::vector<pybind11::array>& input_array) {
std::vector<vision::DetectionResult> results; std::vector<vision::DetectionResult> results;
std::vector<FDTensor> inputs; std::vector<FDTensor> inputs;

View File

@@ -22,11 +22,13 @@ namespace vision {
namespace detection { namespace detection {
PaddleDetPreprocessor::PaddleDetPreprocessor(const std::string& config_file) { PaddleDetPreprocessor::PaddleDetPreprocessor(const std::string& config_file) {
FDASSERT(BuildPreprocessPipelineFromConfig(config_file), "Failed to create PaddleDetPreprocessor."); FDASSERT(BuildPreprocessPipelineFromConfig(config_file),
"Failed to create PaddleDetPreprocessor.");
initialized_ = true; initialized_ = true;
} }
bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig(const std::string& config_file) { bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig(
const std::string& config_file) {
processors_.clear(); processors_.clear();
YAML::Node cfg; YAML::Node cfg;
try { try {
@@ -106,8 +108,6 @@ bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig(const std::string&
// permute = cast<float> + HWC2CHW // permute = cast<float> + HWC2CHW
processors_.push_back(std::make_shared<Cast>("float")); processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(std::make_shared<HWC2CHW>()); processors_.push_back(std::make_shared<HWC2CHW>());
} else {
processors_.push_back(std::make_shared<HWC2CHW>());
} }
// Fusion will improve performance // Fusion will improve performance
@@ -116,13 +116,15 @@ bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig(const std::string&
return true; return true;
} }
bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) { bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images,
std::vector<FDTensor>* outputs) {
if (!initialized_) { if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl; FDERROR << "The preprocessor is not initialized." << std::endl;
return false; return false;
} }
if (images->size() == 0) { if (images->size() == 0) {
FDERROR << "The size of input images should be greater than 0." << std::endl; FDERROR << "The size of input images should be greater than 0."
<< std::endl;
return false; return false;
} }
@@ -140,7 +142,8 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor
// All the tensor will pad to the max size to compose a batched tensor // All the tensor will pad to the max size to compose a batched tensor
std::vector<int> max_hw({-1, -1}); std::vector<int> max_hw({-1, -1});
float* scale_factor_ptr = reinterpret_cast<float*>((*outputs)[1].MutableData()); float* scale_factor_ptr =
reinterpret_cast<float*>((*outputs)[1].MutableData());
float* im_shape_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData()); float* im_shape_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
for (size_t i = 0; i < images->size(); ++i) { for (size_t i = 0; i < images->size(); ++i) {
int origin_w = (*images)[i].Width(); int origin_w = (*images)[i].Width();
@@ -149,7 +152,8 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor
scale_factor_ptr[2 * i + 1] = 1.0; scale_factor_ptr[2 * i + 1] = 1.0;
for (size_t j = 0; j < processors_.size(); ++j) { for (size_t j = 0; j < processors_.size(); ++j) {
if (!(*(processors_[j].get()))(&((*images)[i]))) { if (!(*(processors_[j].get()))(&((*images)[i]))) {
FDERROR << "Failed to processs image:" << i << " in " << processors_[i]->Name() << "." << std::endl; FDERROR << "Failed to processs image:" << i << " in "
<< processors_[i]->Name() << "." << std::endl;
return false; return false;
} }
if (processors_[j]->Name().find("Resize") != std::string::npos) { if (processors_[j]->Name().find("Resize") != std::string::npos) {
@@ -174,7 +178,10 @@ bool PaddleDetPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor
// if the size of image less than max_hw, pad to max_hw // if the size of image less than max_hw, pad to max_hw
FDTensor tensor; FDTensor tensor;
(*images)[i].ShareWithTensor(&tensor); (*images)[i].ShareWithTensor(&tensor);
function::Pad(tensor, &(im_tensors[i]), {0, 0, max_hw[0] - (*images)[i].Height(), max_hw[1] - (*images)[i].Width()}, 0); function::Pad(tensor, &(im_tensors[i]),
{0, 0, max_hw[0] - (*images)[i].Height(),
max_hw[1] - (*images)[i].Width()},
0);
} else { } else {
// No need pad // No need pad
(*images)[i].ShareWithTensor(&(im_tensors[i])); (*images)[i].ShareWithTensor(&(im_tensors[i]));

View File

@@ -52,6 +52,11 @@ class PaddleDetPostprocessor:
""" """
return self._postprocessor.run(runtime_results) return self._postprocessor.run(runtime_results)
def apply_decode_and_nms(self):
"""This function will enable decode and nms in postprocess step.
"""
return self._postprocessor.apply_decode_and_nms()
class PPYOLOE(FastDeployModel): class PPYOLOE(FastDeployModel):
def __init__(self, def __init__(self,
@@ -70,7 +75,6 @@ class PPYOLOE(FastDeployModel):
""" """
super(PPYOLOE, self).__init__(runtime_option) super(PPYOLOE, self).__init__(runtime_option)
assert model_format == ModelFormat.PADDLE, "PPYOLOE model only support model format of ModelFormat.Paddle now."
self._model = C.vision.detection.PPYOLOE( self._model = C.vision.detection.PPYOLOE(
model_file, params_file, config_file, self._runtime_option, model_file, params_file, config_file, self._runtime_option,
model_format) model_format)
@@ -179,7 +183,6 @@ class PicoDet(PPYOLOE):
super(PPYOLOE, self).__init__(runtime_option) super(PPYOLOE, self).__init__(runtime_option)
assert model_format == ModelFormat.PADDLE, "PicoDet model only support model format of ModelFormat.Paddle now."
self._model = C.vision.detection.PicoDet( self._model = C.vision.detection.PicoDet(
model_file, params_file, config_file, self._runtime_option, model_file, params_file, config_file, self._runtime_option,
model_format) model_format)

View File

@@ -1,44 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from typing import Union, List
import logging
from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
from .. import PPYOLOE
class RKPicoDet(PPYOLOE):
def __init__(self,
model_file,
params_file,
config_file,
runtime_option=None,
model_format=ModelFormat.RKNN):
"""Load a PicoDet model exported by PaddleDetection.
:param model_file: (str)Path of model file, e.g picodet/model.pdmodel
:param params_file: (str)Path of parameters file, e.g picodet/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
:param config_file: (str)Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model
"""
super(PPYOLOE, self).__init__(runtime_option)
assert model_format == ModelFormat.RKNN, "RKPicoDet model only support model format of ModelFormat.RKNN now."
self._model = C.vision.detection.RKPicoDet(
model_file, params_file, config_file, self._runtime_option,
model_format)
assert self.initialized, "RKPicoDet model initialize failed."

View File

@@ -4,3 +4,4 @@ tqdm
numpy numpy
opencv-python opencv-python
fd-auto-compress>=0.0.1 fd-auto-compress>=0.0.1
pyyaml

View File

@@ -61,7 +61,8 @@ setup_configs["ENABLE_OPENVINO_BACKEND"] = os.getenv("ENABLE_OPENVINO_BACKEND",
"OFF") "OFF")
setup_configs["ENABLE_PADDLE_BACKEND"] = os.getenv("ENABLE_PADDLE_BACKEND", setup_configs["ENABLE_PADDLE_BACKEND"] = os.getenv("ENABLE_PADDLE_BACKEND",
"OFF") "OFF")
setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND", "OFF") setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND",
"OFF")
setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF") setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF")
setup_configs["ENABLE_LITE_BACKEND"] = os.getenv("ENABLE_LITE_BACKEND", "OFF") setup_configs["ENABLE_LITE_BACKEND"] = os.getenv("ENABLE_LITE_BACKEND", "OFF")
setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF") setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF")
@@ -71,13 +72,15 @@ setup_configs["WITH_GPU"] = os.getenv("WITH_GPU", "OFF")
setup_configs["WITH_IPU"] = os.getenv("WITH_IPU", "OFF") setup_configs["WITH_IPU"] = os.getenv("WITH_IPU", "OFF")
setup_configs["BUILD_ON_JETSON"] = os.getenv("BUILD_ON_JETSON", "OFF") setup_configs["BUILD_ON_JETSON"] = os.getenv("BUILD_ON_JETSON", "OFF")
setup_configs["TRT_DIRECTORY"] = os.getenv("TRT_DIRECTORY", "UNDEFINED") setup_configs["TRT_DIRECTORY"] = os.getenv("TRT_DIRECTORY", "UNDEFINED")
setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY", "/usr/local/cuda") setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY",
"/usr/local/cuda")
setup_configs["LIBRARY_NAME"] = PACKAGE_NAME setup_configs["LIBRARY_NAME"] = PACKAGE_NAME
setup_configs["PY_LIBRARY_NAME"] = PACKAGE_NAME + "_main" setup_configs["PY_LIBRARY_NAME"] = PACKAGE_NAME + "_main"
setup_configs["OPENCV_DIRECTORY"] = os.getenv("OPENCV_DIRECTORY", "") setup_configs["OPENCV_DIRECTORY"] = os.getenv("OPENCV_DIRECTORY", "")
setup_configs["ORT_DIRECTORY"] = os.getenv("ORT_DIRECTORY", "") setup_configs["ORT_DIRECTORY"] = os.getenv("ORT_DIRECTORY", "")
setup_configs["RKNN2_TARGET_SOC"] = os.getenv("RKNN2_TARGET_SOC", "") setup_configs["RKNN2_TARGET_SOC"] = os.getenv("RKNN2_TARGET_SOC", "")
if setup_configs["RKNN2_TARGET_SOC"] != "":
REQUIRED_PACKAGES = REQUIRED_PACKAGES.replace("opencv-python", "")
if setup_configs["WITH_GPU"] == "ON" or setup_configs[ if setup_configs["WITH_GPU"] == "ON" or setup_configs[
"BUILD_ON_JETSON"] == "ON": "BUILD_ON_JETSON"] == "ON":
@@ -105,7 +108,8 @@ extras_require = {}
# Default value is set to TRUE\1 to keep the settings same as the current ones. # Default value is set to TRUE\1 to keep the settings same as the current ones.
# However going forward the recomemded way to is to set this to False\0 # However going forward the recomemded way to is to set this to False\0
USE_MSVC_STATIC_RUNTIME = bool(os.getenv('USE_MSVC_STATIC_RUNTIME', '1') == '1') USE_MSVC_STATIC_RUNTIME = bool(
os.getenv('USE_MSVC_STATIC_RUNTIME', '1') == '1')
ONNX_NAMESPACE = os.getenv('ONNX_NAMESPACE', 'paddle2onnx') ONNX_NAMESPACE = os.getenv('ONNX_NAMESPACE', 'paddle2onnx')
################################################################################ ################################################################################
# Version # Version
@@ -135,7 +139,8 @@ assert CMAKE, 'Could not find "cmake" executable!'
@contextmanager @contextmanager
def cd(path): def cd(path):
if not os.path.isabs(path): if not os.path.isabs(path):
raise RuntimeError('Can only cd to absolute path, got: {}'.format(path)) raise RuntimeError('Can only cd to absolute path, got: {}'.format(
path))
orig_path = os.getcwd() orig_path = os.getcwd()
os.chdir(path) os.chdir(path)
try: try:

View File

@@ -1,7 +1,5 @@
model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx
output_folder: ./picodet_s_416_coco_lcnet output_folder: ./picodet_s_416_coco_lcnet
target_platform: RK3568 target_platform: RK3568
normalize: normalize: None
mean: [[0.485,0.456,0.406]] outputs: ['tmp_17','p2o.Concat.9']
std: [[0.229,0.224,0.225]]
outputs: ['tmp_16','p2o.Concat.9']

View File

@@ -1,5 +0,0 @@
model_path: ./picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx
output_folder: ./picodet_s_416_coco_npu
target_platform: RK3568
normalize: None
outputs: ['tmp_16','p2o.Concat.17']

View File

@@ -1,7 +1,5 @@
model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx model_path: ./picodet_s_416_coco_lcnet/picodet_s_416_coco_lcnet.onnx
output_folder: ./picodet_s_416_coco_lcnet output_folder: ./picodet_s_416_coco_lcnet
target_platform: RK3588 target_platform: RK3588
normalize: normalize: None
mean: [[0.485,0.456,0.406]]
std: [[0.229,0.224,0.225]]
outputs: ['tmp_16','p2o.Concat.9'] outputs: ['tmp_16','p2o.Concat.9']

View File

@@ -1,5 +0,0 @@
model_path: ./picodet_s_416_coco_npu/picodet_s_416_coco_npu.onnx
output_folder: ./picodet_s_416_coco_npu
target_platform: RK3588
normalize: None
outputs: ['tmp_16','p2o.Concat.17']