[Hackthon_4th 177] Support PP-YOLOE-R with BM1684 (#1809)

* first draft

* add robx iou

* add benchmark for ppyoloe_r

* remove trash code

* fix bugs

* add pybind nms rotated option

* add missing head file

* fix bug

* fix bug2

* fix shape bug

---------

Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
This commit is contained in:
thunder95
2023-04-21 10:48:05 +08:00
committed by GitHub
parent f3d44785c4
commit 51be3fea78
31 changed files with 1389 additions and 6 deletions

View File

@@ -22,6 +22,8 @@ add_executable(benchmark_ppmatting ${PROJECT_SOURCE_DIR}/benchmark_ppmatting.cc)
add_executable(benchmark_ppocr_det ${PROJECT_SOURCE_DIR}/benchmark_ppocr_det.cc)
add_executable(benchmark_ppocr_cls ${PROJECT_SOURCE_DIR}/benchmark_ppocr_cls.cc)
add_executable(benchmark_ppocr_rec ${PROJECT_SOURCE_DIR}/benchmark_ppocr_rec.cc)
add_executable(benchmark_ppyoloe_r ${PROJECT_SOURCE_DIR}/benchmark_ppyoloe_r.cc)
add_executable(benchmark_ppyoloe_r_sophgo ${PROJECT_SOURCE_DIR}/benchmark_ppyoloe_r_sophgo.cc)
add_executable(benchmark_ppyolo ${PROJECT_SOURCE_DIR}/benchmark_ppyolo.cc)
add_executable(benchmark_yolov3 ${PROJECT_SOURCE_DIR}/benchmark_yolov3.cc)
add_executable(benchmark_fasterrcnn ${PROJECT_SOURCE_DIR}/benchmark_fasterrcnn.cc)
@@ -44,6 +46,8 @@ if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
target_link_libraries(benchmark_ppyolov8 ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_ppyolox ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_ppyoloe ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_ppyoloe_r ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_ppyoloe_r_sophgo ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_picodet ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_ppcls ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_ppseg ${FASTDEPLOY_LIBS} gflags pthread)
@@ -72,6 +76,8 @@ else()
target_link_libraries(benchmark_ppyolov8 ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppyolox ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppyoloe ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppyoloe_r ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppyoloe_r_sophgo ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_picodet ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppcls ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppseg ${FASTDEPLOY_LIBS} gflags)

View File

@@ -0,0 +1,60 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include "flags.h"
#include "macros.h"
#include "option.h"
namespace vision = fastdeploy::vision;
namespace benchmark = fastdeploy::benchmark;
DEFINE_bool(no_nms, false, "Whether the model contains nms.");
// Benchmark driver for the PP-YOLOE-R detection model.
//
// When both ENABLE_BENCHMARK and ENABLE_VISION are defined it builds a
// RuntimeOption from the command line, resolves the model / params / config
// file names inside FLAGS_model, constructs a PPYOLOER model, profiles
// Predict() on the image given by FLAGS_image and writes a visualization to
// ./vis_result.jpg. Without those macros the body is compiled out and main
// simply returns 0.
//
// Returns -1 when CreateRuntimeOption or UpdateModelResourceName fails,
// 0 otherwise.
int main(int argc, char* argv[]) {
#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION)
// Initialization
auto option = fastdeploy::RuntimeOption();
if (!CreateRuntimeOption(&option, argc, argv, true)) {
return -1;
}
// NOTE(review): the result of cv::imread is not checked here; an unreadable
// FLAGS_image would be handed to Predict() as-is.
auto im = cv::imread(FLAGS_image);
std::unordered_map<std::string, std::string> config_info;
benchmark::ResultManager::LoadBenchmarkConfig(FLAGS_config_path,
&config_info);
std::string model_name, params_name, config_name;
// Default format; UpdateModelResourceName receives a pointer to it and may
// overwrite it depending on the backend named in config_info.
auto model_format = fastdeploy::ModelFormat::PADDLE;
if (!UpdateModelResourceName(&model_name, &params_name, &config_name,
&model_format, config_info)) {
return -1;
}
// Resource names are relative to the model directory passed via FLAGS_model.
auto model_file = FLAGS_model + sep + model_name;
auto params_file = FLAGS_model + sep + params_name;
auto config_file = FLAGS_model + sep + config_name;
auto model_ppyoloe_r = vision::detection::PPYOLOER(
model_file, params_file, config_file, option, model_format);
vision::DetectionResult res;
// Run profiling
// NOTE(review): BENCHMARK_MODEL is a macro defined elsewhere; it presumably
// relies on locals declared above, so keep their names and order intact.
BENCHMARK_MODEL(model_ppyoloe_r, model_ppyoloe_r.Predict(im, &res))
// Draw the detections from the last Predict() call and save them to disk.
auto vis_im = vision::VisDetection(im, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
#endif
return 0;
}

View File

@@ -0,0 +1,61 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include "flags.h"
#include "macros.h"
#include "option.h"
namespace vision = fastdeploy::vision;
namespace benchmark = fastdeploy::benchmark;
DEFINE_bool(no_nms, false, "Whether the model contains nms.");
// Benchmark driver for the PP-YOLOE-R detection model with the SOPHGO model
// format as the starting default.
//
// When both ENABLE_BENCHMARK and ENABLE_VISION are defined it builds a
// RuntimeOption from the command line, resolves the model / params / config
// file names inside FLAGS_model, constructs a PPYOLOER model, profiles
// Predict() on the image given by FLAGS_image and writes a visualization to
// ./vis_result.jpg. Without those macros the body is compiled out and main
// simply returns 0.
//
// Returns -1 when CreateRuntimeOption or UpdateModelResourceName fails,
// 0 otherwise.
int main(int argc, char* argv[]) {
#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION)
// Initialization
auto option = fastdeploy::RuntimeOption();
if (!CreateRuntimeOption(&option, argc, argv, true)) {
return -1;
}
// NOTE(review): the result of cv::imread is not checked here; an unreadable
// FLAGS_image would be handed to Predict() as-is.
auto im = cv::imread(FLAGS_image);
std::unordered_map<std::string, std::string> config_info;
benchmark::ResultManager::LoadBenchmarkConfig(FLAGS_config_path,
&config_info);
std::string model_name, params_name, config_name;
// Starts as SOPHGO; UpdateModelResourceName receives a pointer to it and may
// overwrite it depending on the backend named in config_info.
auto model_format = fastdeploy::ModelFormat::SOPHGO;
if (!UpdateModelResourceName(&model_name, &params_name, &config_name,
&model_format, config_info)) {
return -1;
}
// Resource names are relative to the model directory passed via FLAGS_model.
auto model_file = FLAGS_model + sep + model_name;
auto params_file = FLAGS_model + sep + params_name;
auto config_file = FLAGS_model + sep + config_name;
auto model_ppyoloe_r = vision::detection::PPYOLOER(
model_file, params_file, config_file, option, model_format);
vision::DetectionResult res;
// Run profiling
// NOTE(review): BENCHMARK_MODEL is a macro defined elsewhere; it presumably
// relies on locals declared above, so keep their names and order intact.
BENCHMARK_MODEL(model_ppyoloe_r, model_ppyoloe_r.Predict(im, &res))
// Draw the detections from the last Predict() call and save them to disk.
auto vis_im = vision::VisDetection(im, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
#endif
return 0;
}

View File

@@ -158,5 +158,14 @@ static bool UpdateModelResourceName(
return false;
}
}
if (config_info["backend"] == "sophgo") {
*model_format = fastdeploy::ModelFormat::SOPHGO;
if (!GetModelResoucesNameFromDir(FLAGS_model, model_name, "bmodel")) {
std::cout << "Can not find sophgo model resources." << std::endl;
return false;
}
}
return true;
}

View File

@@ -103,6 +103,9 @@ static bool CreateRuntimeOption(fastdeploy::RuntimeOption* option,
if (config_info["use_fp16"] == "true") {
option->paddle_lite_option.enable_fp16 = true;
}
} else if (config_info["backend"] == "sophgo") {
option->UseSophgo();
option->UseSophgoBackend();
} else if (config_info["backend"] == "default") {
PrintBenchmarkInfo(config_info);
return true;

View File

@@ -50,6 +50,7 @@ typedef struct FD_C_OneDimMask {
typedef struct FD_C_DetectionResult {
FD_C_TwoDimArrayFloat boxes;
FD_C_TwoDimArrayFloat rotated_boxes;
FD_C_OneDimArrayFloat scores;
FD_C_OneDimArrayInt32 label_ids;
FD_C_OneDimMask masks;

View File

@@ -135,6 +135,7 @@ public struct FD_OneDimMask {
[StructLayout(LayoutKind.Sequential)]
public struct FD_DetectionResult {
public FD_TwoDimArrayFloat boxes;
public FD_TwoDimArrayFloat rotated_boxes;
public FD_OneDimArrayFloat scores;
public FD_OneDimArrayInt32 label_ids;
public FD_OneDimMask masks;

View File

@@ -10,6 +10,7 @@ English | [简体中文](README_CN.md)
Now FastDeploy supports the deployment of the following models
- [PP-YOLOE(including PP-YOLOE+) models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)
- [PP-YOLOE-R models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/rotate/ppyoloe_r)
- [PicoDet models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet)
- [PP-YOLO models(including v2)](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyolo)
- [YOLOv3 models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/yolov3)
@@ -43,7 +44,7 @@ Before deployment, PaddleDetection needs to be exported into the deployment mode
## Download Pre-trained Model
For developers' testing, models exported by PaddleDetection are provided below. Developers can download them directly.
For developers' testing, models exported by PaddleDetection are provided below. Developers can download them directly.
The accuracy metric is from model descriptions in PaddleDetection. Refer to them for details.

View File

@@ -10,6 +10,7 @@
目前FastDeploy支持如下模型的部署
- [PP-YOLOE(含PP-YOLOE+)系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)
- [PP-YOLOE-R系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/rotate/ppyoloe_r)
- [PicoDet系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet)
- [PP-YOLO系列模型(含v2)](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyolo)
- [YOLOv3系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/yolov3)

View File

@@ -15,6 +15,9 @@ target_link_libraries(infer_faster_rcnn_demo ${FASTDEPLOY_LIBS})
add_executable(infer_ppyoloe_demo ${PROJECT_SOURCE_DIR}/infer_ppyoloe.cc)
target_link_libraries(infer_ppyoloe_demo ${FASTDEPLOY_LIBS})
add_executable(infer_ppyoloe_r_demo ${PROJECT_SOURCE_DIR}/infer_ppyoloe_r.cc)
target_link_libraries(infer_ppyoloe_r_demo ${FASTDEPLOY_LIBS})
add_executable(infer_picodet_demo ${PROJECT_SOURCE_DIR}/infer_picodet.cc)
target_link_libraries(infer_picodet_demo ${FASTDEPLOY_LIBS})

View File

@@ -0,0 +1,98 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
// Runs PP-YOLOE-R on the CPU for a single image.
//
// model_dir  : directory containing model.pdmodel, model.pdiparams and
//              infer_cfg.yml.
// image_file : path of the image to run detection on.
//
// Prints the detection result to stdout and writes the visualization to
// ./vis_result.jpg. On any failure (model initialization, unreadable image,
// prediction) an error is printed to stderr and the function returns early.
void CpuInfer(const std::string& model_dir, const std::string& image_file) {
  auto model_file = model_dir + sep + "model.pdmodel";
  auto params_file = model_dir + sep + "model.pdiparams";
  auto config_file = model_dir + sep + "infer_cfg.yml";

  auto option = fastdeploy::RuntimeOption();
  option.UseCpu();
  auto model = fastdeploy::vision::detection::PPYOLOER(model_file, params_file,
                                                       config_file, option);
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize." << std::endl;
    return;
  }

  auto im = cv::imread(image_file);
  // cv::imread returns an empty Mat when the file is missing or cannot be
  // decoded; stop here instead of feeding an empty image to the model.
  if (im.empty()) {
    std::cerr << "Failed to read image: " << image_file << std::endl;
    return;
  }
  std::cout << im.cols << " vs " << im.rows << std::endl;

  fastdeploy::vision::DetectionResult res;
  if (!model.Predict(im, &res)) {
    std::cerr << "Failed to predict." << std::endl;
    return;
  }
  std::cout << res.Str() << std::endl;

  // Only detections scoring at least 0.5 are drawn.
  auto vis_im = fastdeploy::vision::VisDetection(im, res, 0.5);
  cv::imwrite("vis_result.jpg", vis_im);
  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
// Runs PP-YOLOE-R on the GPU for a single image.
//
// model_dir  : directory containing model.pdmodel, model.pdiparams and
//              infer_cfg.yml.
// image_file : path of the image to run detection on.
//
// Prints the detection result to stdout and writes the visualization to
// ./vis_result.jpg. On any failure (model initialization, unreadable image,
// prediction) an error is printed to stderr and the function returns early.
void GpuInfer(const std::string& model_dir, const std::string& image_file) {
  auto model_file = model_dir + sep + "model.pdmodel";
  auto params_file = model_dir + sep + "model.pdiparams";
  auto config_file = model_dir + sep + "infer_cfg.yml";

  auto option = fastdeploy::RuntimeOption();
  option.UseGpu();
  auto model = fastdeploy::vision::detection::PPYOLOER(model_file, params_file,
                                                       config_file, option);
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize." << std::endl;
    return;
  }

  const cv::Mat im = cv::imread(image_file);
  // cv::imread returns an empty Mat when the file is missing or cannot be
  // decoded; stop here instead of feeding an empty image to the model.
  if (im.empty()) {
    std::cerr << "Failed to read image: " << image_file << std::endl;
    return;
  }

  fastdeploy::vision::DetectionResult res;
  if (!model.Predict(im, &res)) {
    std::cerr << "Failed to predict." << std::endl;
    return;
  }
  std::cout << res.Str() << std::endl;

  // Only detections scoring at least 0.1 are drawn.
  auto vis_im = fastdeploy::vision::VisDetection(im, res, 0.1);
  cv::imwrite("vis_result.jpg", vis_im);
  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
// Entry point of the PP-YOLOE-R example.
//
// Expects three arguments: model directory, image path and run_option
// (0 = CPU, 1 = GPU). Returns -1 when arguments are missing or run_option is
// not one of the supported values, 0 otherwise.
int main(int argc, char* argv[]) {
  if (argc < 4) {
    std::cout
        << "Usage: infer_ppyoloe_r path/to/model_dir path/to/image run_option, "
           "e.g ./infer_ppyoloe_r ./ppyoloe_model_dir ./test.jpeg 0"
        << std::endl;
    // Only the options actually dispatched below are advertised.
    std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
                 "with gpu."
              << std::endl;
    return -1;
  }

  const int run_option = std::atoi(argv[3]);
  if (run_option == 0) {
    CpuInfer(argv[1], argv[2]);
  } else if (run_option == 1) {
    GpuInfer(argv[1], argv[2]);
  } else {
    // Report unknown options instead of silently doing nothing.
    std::cerr << "Unsupported run_option: " << argv[3]
              << ", expected 0 (cpu) or 1 (gpu)." << std::endl;
    return -1;
  }
  return 0;
}

View File

@@ -0,0 +1,78 @@
import cv2
import os
import fastdeploy as fd
def parse_arguments():
    """Parse the command-line options of the PP-YOLOE-R example.

    Returns:
        argparse.Namespace with the attributes:
            model_dir: PaddleDetection model directory, or None.
            image: path of the test image, or None.
            device: inference device name (default 'cpu').
            use_trt: value of --use_trt evaluated with ast.literal_eval
                (default False).
    """
    import argparse
    import ast
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        default=None,
        help="Path of PaddleDetection model directory")
    parser.add_argument(
        "--image", default=None, help="Path of test image file.")
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Type of inference device, support 'cpu', 'gpu', 'kunlunxin' "
        "or 'ascend'.")
    parser.add_argument(
        "--use_trt",
        type=ast.literal_eval,
        default=False,
        help="Whether to use tensorrt.")
    return parser.parse_args()
def build_option(args):
    """Build a fd.RuntimeOption from the parsed command-line arguments.

    The device name is matched case-insensitively against 'kunlunxin',
    'ascend' and 'gpu'; any other value leaves the option at its default.
    The TensorRT backend is requested whenever args.use_trt is truthy.
    """
    option = fd.RuntimeOption()

    device = args.device.lower()
    if device == "kunlunxin":
        option.use_kunlunxin()
    elif device == "ascend":
        option.use_ascend()
    elif device == "gpu":
        option.use_gpu()

    if args.use_trt:
        option.use_trt_backend()
    return option
args = parse_arguments()

if args.model_dir is None:
    # NOTE(review): this default model name looks copied from the PP-YOLOE
    # example while the model built below is PPYOLOER -- confirm that a
    # matching PP-YOLOE-R model is available for download.
    model_dir = fd.download_model(name='ppyoloe_crn_l_300e_coco')
else:
    model_dir = args.model_dir

model_file = os.path.join(model_dir, "model.pdmodel")
params_file = os.path.join(model_dir, "model.pdiparams")
config_file = os.path.join(model_dir, "infer_cfg.yml")

# Configure the runtime and load the model.
runtime_option = build_option(args)
print(args)
model = fd.vision.detection.PPYOLOER(
    model_file, params_file, config_file, runtime_option=runtime_option)

# Run detection on the image.
if args.image is None:
    image = fd.utils.get_detection_test_image()
else:
    image = args.image
im = cv2.imread(image)
result = model.predict(im)
print(result)

# Visualize the detection result.
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg")

View File

@@ -6,6 +6,7 @@
- [PP-YOLOE系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)
- [PicoDet系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet)
- [YOLOV8系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4)
- [PP-YOLOE-R系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/rotate/ppyoloe_r)
## 准备PP-YOLOE YOLOV8或者PicoDet部署模型以及转换模型

View File

@@ -13,9 +13,11 @@ include_directories(${FASTDEPLOY_INCS})
include_directories(${FastDeploy_INCLUDE_DIRS})
add_executable(infer_ppyoloe ${PROJECT_SOURCE_DIR}/infer_ppyoloe.cc)
add_executable(infer_ppyoloe_r ${PROJECT_SOURCE_DIR}/infer_ppyoloe_r.cc)
add_executable(infer_picodet ${PROJECT_SOURCE_DIR}/infer_picodet.cc)
add_executable(infer_yolov8 ${PROJECT_SOURCE_DIR}/infer_yolov8.cc)
# 添加FastDeploy库依赖
target_link_libraries(infer_ppyoloe ${FASTDEPLOY_LIBS})
target_link_libraries(infer_ppyoloe_r ${FASTDEPLOY_LIBS})
target_link_libraries(infer_picodet ${FASTDEPLOY_LIBS})
target_link_libraries(infer_yolov8 ${FASTDEPLOY_LIBS})

View File

@@ -1,6 +1,6 @@
# PaddleDetection C++部署示例
本目录下提供`infer_ppyoloe.cc`,`infer_yolov8.cc``infer_picodet.cc`快速完成PP-YOLOE模型,YOLOV8模型和PicoDet模型在SOPHGO BM1684x板子上加速部署的示例。
本目录下提供`infer_ppyoloe.cc`,`infer_ppyoloe_r.cc`,`infer_yolov8.cc`和`infer_picodet.cc`快速完成PP-YOLOE模型,PP-YOLOE-R模型,YOLOV8模型和PicoDet模型在SOPHGO BM1684x板子上加速部署的示例。
在部署前,需确认以下两个步骤:
@@ -18,6 +18,7 @@
├── build # 编译文件夹
├── image # 存放图片的文件夹
├── infer_ppyoloe.cc
├── infer_ppyoloe_r.cc
├── infer_picodet.cc
├── infer_yolov8.cc
└── model # 存放模型文件的文件夹
@@ -37,6 +38,9 @@
```bash
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
cp 000000014439.jpg ./images
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/P0861__1.0__1154___824.png
cp P0861__1.0__1154___824.png ./images
```
### 编译example
@@ -53,6 +57,9 @@ make
#ppyoloe推理示例
./infer_ppyoloe model images/000000014439.jpg
#ppyoloe_r推理示例
./infer_ppyoloe_r model images/P0861__1.0__1154___824.png
#picodet推理示例
./infer_picodet model images/000000014439.jpg
```

View File

@@ -0,0 +1,58 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sys/time.h>
#include <iostream>
#include <string>
#include "fastdeploy/vision.h"
// Runs PP-YOLOE-R with the SOPHGO backend for a single image.
//
// model_dir  : directory containing
//              ppyoloe_r_crn_s_3x_dota_1684x_f32.bmodel and infer_cfg.yml.
// image_file : path of the image to run detection on.
//
// Prints the detection result to stdout and writes the visualization to
// ./infer_sophgo.jpg. On any failure (model initialization, unreadable image,
// prediction) an error is printed to stderr and the function returns early.
void SophgoInfer(const std::string& model_dir, const std::string& image_file) {
  auto model_file = model_dir + "/ppyoloe_r_crn_s_3x_dota_1684x_f32.bmodel";
  // An empty params path is passed for the SOPHGO model format.
  auto params_file = "";
  auto config_file = model_dir + "/infer_cfg.yml";

  auto option = fastdeploy::RuntimeOption();
  option.UseSophgo();
  option.UseSophgoBackend();

  auto format = fastdeploy::ModelFormat::SOPHGO;
  auto model = fastdeploy::vision::detection::PPYOLOER(
      model_file, params_file, config_file, option, format);
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize." << std::endl;
    return;
  }

  auto im = cv::imread(image_file);
  // cv::imread returns an empty Mat when the file is missing or cannot be
  // decoded; stop here instead of feeding an empty image to the model.
  if (im.empty()) {
    std::cerr << "Failed to read image: " << image_file << std::endl;
    return;
  }

  fastdeploy::vision::DetectionResult res;
  if (!model.Predict(&im, &res)) {
    std::cerr << "Failed to predict." << std::endl;
    return;
  }
  std::cout << res.Str() << std::endl;

  // Only detections scoring at least 0.1 are drawn.
  auto vis_im = fastdeploy::vision::VisDetection(im, res, 0.1);
  cv::imwrite("infer_sophgo.jpg", vis_im);
  std::cout << "Visualized result saved in ./infer_sophgo.jpg" << std::endl;
}
// Entry point of the SOPHGO PP-YOLOE-R example.
//
// Expects two arguments: model directory and image path. Returns -1 when
// arguments are missing, 0 otherwise.
int main(int argc, char* argv[]) {
  if (argc < 3) {
    // This program takes no run_option; the usage text lists only the two
    // arguments that are actually read.
    std::cout << "Usage: infer_ppyoloe_r path/to/model_dir path/to/image, "
                 "e.g ./infer_ppyoloe_r ./model ./test.png"
              << std::endl;
    return -1;
  }
  SophgoInfer(argv[1], argv[2]);
  return 0;
}

View File

@@ -13,6 +13,7 @@ cd FastDeploy/examples/vision/detection/paddledetection/sophgo/python
# 下载图片
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/P0861__1.0__1154___824.png
# 推理
# ppyoloe推理示例
@@ -33,12 +34,19 @@ python3 infer_picodet.py --auto False --pp_detect_path '' --model_file model/ppy
python3 infer_yolov8.py --model_file model/yolov8s_s_300e_coco_1684x_f32.bmodel --config_file model/infer_cfg.yml --image ./000000014439.jpg
# 运行完成后返回结果如下所示
可视化结果存储在sophgo_result.jpg中
# ppyoloe_r推理示例
# 指定--auto True自动完成模型准备、转换和推理需要指定PaddleDetection路径
python3 infer_ppyoloe_r.py --model_file model/ppyoloe_r_crn_s_3x_dota_1684x_f32.bmodel --image P0861__1.0__1154___824.png --config_file model/infer_cfg.yml
可视化结果存储在sophgo_result_ppyoloe_r.jpg中
```
## 其它文档
- [PP-YOLOE C++部署](../cpp)
- [PicoDet C++部署](../cpp)
- [YOLOV8 C++部署](../cpp)
- [PP-YOLOE-R C++部署](../cpp)
- [转换PicoDet SOPHGO模型文档](../README.md)
- [转换PP-YOLOE SOPHGO模型文档](../README.md)
- [转换YOLOV8 SOPHGO模型文档](../README.md)
- [转换PP-YOLOE-R SOPHGO模型文档](../README.md)

View File

@@ -0,0 +1,159 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
from subprocess import run
from prepare_npz import prepare
def parse_arguments():
    """Parse the command-line options of the SOPHGO PP-YOLOE-R example.

    --model_file, --config_file and --image are mandatory unless --auto is
    given; in --auto mode the script derives all three paths itself, so they
    may be omitted.

    Returns:
        argparse.Namespace with the attributes auto, pp_detect_path,
        model_file, config_file and image.

    Raises:
        SystemExit: (via parser.error) when a mandatory option is missing.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--auto",
        action="store_true",
        help="Auto download, convert, compile and infer if True")
    parser.add_argument(
        "--pp_detect_path",
        default='/workspace/PaddleDetection',
        help="Path of PaddleDetection folder")
    parser.add_argument(
        "--model_file", default=None, help="Path of sophgo model.")
    parser.add_argument("--config_file", default=None, help="Path of config.")
    parser.add_argument(
        "--image", type=str, default=None, help="Path of test image file.")
    args = parser.parse_args()

    if not args.auto:
        # Without --auto nothing can stand in for these paths.
        candidates = (("--model_file", args.model_file),
                      ("--config_file", args.config_file),
                      ("--image", args.image))
        missing = [flag for flag, value in candidates if value is None]
        if missing:
            parser.error("the following arguments are required: " +
                         ", ".join(missing))
    return args
def export_model(args):
    """Export the PP-YOLOE-R inference model with PaddleDetection.

    Runs tools/export_model.py inside args.pp_detect_path and copies the
    exported ppyoloe_r_crn_s_3x_dota directory back into the directory the
    script was started from, where paddle2onnx() expects to find it.

    Args:
        args: parsed arguments; only args.pp_detect_path is read.
    """
    PPDetection_path = args.pp_detect_path
    # NOTE(review): '-output_dir=' is spelled with a single dash; confirm it
    # against the export_model.py command line.
    export_str = ('python3 tools/export_model.py '
                  '-c configs/rotate/ppyoloe_r/ppyoloe_r_crn_s_3x_dota.yml '
                  '-output_dir=output_inference '
                  '-o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_s_3x_dota.pdparams')
    cur_path = os.getcwd()
    os.chdir(PPDetection_path)
    try:
        print(export_str)
        run(export_str, shell=True)
        # Copy the directory named after the exported config; later steps
        # read ./ppyoloe_r_crn_s_3x_dota in the starting directory.
        cp_str = 'cp -r ./output_inference/ppyoloe_r_crn_s_3x_dota ' + cur_path
        print(cp_str)
        run(cp_str, shell=True)
    finally:
        # Always return to the starting directory, even if a step raised.
        os.chdir(cur_path)
def paddle2onnx():
    """Convert the exported Paddle model in ./ppyoloe_r_crn_s_3x_dota to
    ppyoloe_r_crn_s_3x_dota.onnx by invoking the paddle2onnx command."""
    command = ('paddle2onnx --model_dir ppyoloe_r_crn_s_3x_dota '
               '--model_filename model.pdmodel '
               '--params_filename model.pdiparams '
               '--save_file ppyoloe_r_crn_s_3x_dota.onnx '
               '--enable_dev_version True')
    print(command)
    run(command, shell=True)
def mlir_prepare():
    """Create the ./ppyoloe_r working directory for the MLIR conversion.

    Copies the regression dataset and image folders located next to the
    model zoo, plus the ONNX model, into ./ppyoloe_r and creates
    ./ppyoloe_r/workspace.

    Raises:
        RuntimeError: if the MODEL_ZOO_PATH environment variable is not set.
    """
    mlir_path = os.getenv("MODEL_ZOO_PATH")
    if not mlir_path:
        raise RuntimeError(
            "Environment variable MODEL_ZOO_PATH is not set; it is needed to "
            "locate the regression data.")
    # Drop the last 13 characters to get the project root.
    # NOTE(review): presumably strips a trailing '/../model-zoo' -- confirm
    # against the environment setup script.
    mlir_path = mlir_path[:-13]
    regression_path = os.path.join(mlir_path, 'regression')
    # 'mkdir -p' keeps the step re-runnable when the directories exist.
    commands = [
        'mkdir -p ppyoloe_r',
        'cp -rf ' + os.path.join(regression_path, 'dataset/COCO2017/') +
        ' ./ppyoloe_r',
        'cp -rf ' + os.path.join(regression_path, 'image/') + ' ./ppyoloe_r',
        'cp ppyoloe_r_crn_s_3x_dota.onnx ./ppyoloe_r',
        'mkdir -p ./ppyoloe_r/workspace',
    ]
    for command in commands:
        print(command)
        run(command, shell=True)
def image_prepare():
    """Fetch the demo image if it is not present, build inputs.npz from it
    via prepare() and copy the archive into ./ppyoloe_r.

    NOTE(review): prepare() is called with [640, 640] while onnx2mlir()
    declares a 1024x1024 input shape -- confirm the intended size.
    """
    image_name = 'P0861__1.0__1154___824.png'
    download_cmd = 'wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/P0861__1.0__1154___824.png'
    already_present = os.path.exists(image_name)
    if not already_present:
        print(download_cmd)
        run(download_cmd, shell=True)

    prepare(image_name, [640, 640])

    copy_cmd = 'cp ./inputs.npz ./ppyoloe_r'
    print(copy_cmd)
    run(copy_cmd, shell=True)
def onnx2mlir():
    """Run model_transform.py inside ./ppyoloe_r/workspace to turn the ONNX
    model into ppyoloe_r_crn_s_3x_dota.mlir, then step back two directory
    levels."""
    options = [
        'model_transform.py',
        '--model_name ppyoloe_r_crn_s_3x_dota',
        '--model_def ../ppyoloe_r_crn_s_3x_dota.onnx',
        '--input_shapes [[1,3,1024,1024],[1,2]]',
        '--keep_aspect_ratio',
        '--pixel_format rgb',
        '--mlir ppyoloe_r_crn_s_3x_dota.mlir',
    ]
    command = ' '.join(options)
    os.chdir('./ppyoloe_r/workspace')
    print(command)
    run(command, shell=True)
    os.chdir('../../')
def mlir2bmodel():
    """Run model_deploy.py inside ./ppyoloe_r/workspace to compile the MLIR
    file into ppyoloe_r_crn_s_3x_dota_1684x_f32.bmodel (F32, chip bm1684x),
    then step back two directory levels."""
    options = [
        'model_deploy.py',
        '--mlir ppyoloe_r_crn_s_3x_dota.mlir',
        '--quantize F32',
        '--chip bm1684x',
        '--model ppyoloe_r_crn_s_3x_dota_1684x_f32.bmodel',
    ]
    command = ' '.join(options)
    os.chdir('./ppyoloe_r/workspace')
    print(command)
    run(command, shell=True)
    os.chdir('../../')
if __name__ == "__main__":
    args = parse_arguments()

    # With --auto, run the whole export -> ONNX -> MLIR -> bmodel pipeline.
    if args.auto:
        export_model(args)
        paddle2onnx()
        mlir_prepare()
        image_prepare()
        onnx2mlir()
        mlir2bmodel()

    # In --auto mode use the artifacts produced above: mlir2bmodel() writes
    # ppyoloe_r_crn_s_3x_dota_1684x_f32.bmodel into ./ppyoloe_r/workspace.
    model_file = './ppyoloe_r/workspace/ppyoloe_r_crn_s_3x_dota_1684x_f32.bmodel' if args.auto else args.model_file
    params_file = ""
    config_file = './ppyoloe_r_crn_s_3x_dota/infer_cfg.yml' if args.auto else args.config_file
    image_file = './P0861__1.0__1154___824.png' if args.auto else args.image

    # Configure the runtime and load the model.
    runtime_option = fd.RuntimeOption()
    runtime_option.use_sophgo()

    model = fd.vision.detection.PPYOLOER(
        model_file,
        params_file,
        config_file,
        runtime_option=runtime_option,
        model_format=fd.ModelFormat.SOPHGO)

    # Run detection on the image.
    im = cv2.imread(image_file)
    result = model.predict(im)
    print(result)

    # Visualize the result.
    vis_im = fd.vision.vis_detection(im, result, score_threshold=0.1)
    cv2.imwrite("sophgo_result_ppyoloe_r.jpg", vis_im)
    print("Visualized result save in ./sophgo_result_ppyoloe_r.jpg")

View File

@@ -176,6 +176,7 @@ bool SophgoBackend::Infer(std::vector<FDTensor>& inputs,
bm_tensor_t input_tensors[input_size];
bm_status_t status = BM_SUCCESS;
RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
bm_data_type_t* input_dtypes = net_info_->input_dtypes;
for (int i = 0; i < input_size; i++) {
status = bm_malloc_device_byte(handle_, &input_tensors[i].device_mem,
@@ -198,12 +199,14 @@ bool SophgoBackend::Infer(std::vector<FDTensor>& inputs,
assert(BM_SUCCESS == status);
}
RUNTIME_PROFILE_LOOP_BEGIN(1)
bool launch_status = bmrt_launch_tensor_ex(
p_bmrt_, net_name_.c_str(), input_tensors, net_info_->input_num,
output_tensors, net_info_->output_num, true, false);
assert(launch_status);
status = bm_thread_sync(handle_);
assert(status == BM_SUCCESS);
RUNTIME_PROFILE_LOOP_END
outputs->resize(outputs_desc_.size());
bm_data_type_t* output_dtypes = net_info_->output_dtypes;
@@ -231,6 +234,7 @@ bool SophgoBackend::Infer(std::vector<FDTensor>& inputs,
for (int i = 0; i < output_size; i++) {
bm_free_device(handle_, output_tensors[i].device_mem);
}
RUNTIME_PROFILE_LOOP_H2D_D2H_END
return true;
}

View File

@@ -83,6 +83,7 @@ std::string Mask::Str() {
DetectionResult::DetectionResult(const DetectionResult& res) {
boxes.assign(res.boxes.begin(), res.boxes.end());
rotated_boxes.assign(res.rotated_boxes.begin(), res.rotated_boxes.end());
scores.assign(res.scores.begin(), res.scores.end());
label_ids.assign(res.label_ids.begin(), res.label_ids.end());
contain_masks = res.contain_masks;
@@ -98,6 +99,7 @@ DetectionResult::DetectionResult(const DetectionResult& res) {
DetectionResult& DetectionResult::operator=(DetectionResult&& other) {
if (&other != this) {
boxes = std::move(other.boxes);
rotated_boxes = std::move(other.rotated_boxes);
scores = std::move(other.scores);
label_ids = std::move(other.label_ids);
contain_masks = std::move(other.contain_masks);
@@ -111,6 +113,7 @@ DetectionResult& DetectionResult::operator=(DetectionResult&& other) {
void DetectionResult::Free() {
std::vector<std::array<float, 4>>().swap(boxes);
std::vector<std::array<float, 8>>().swap(rotated_boxes);
std::vector<float>().swap(scores);
std::vector<int32_t>().swap(label_ids);
std::vector<Mask>().swap(masks);
@@ -119,6 +122,7 @@ void DetectionResult::Free() {
void DetectionResult::Clear() {
boxes.clear();
rotated_boxes.clear();
scores.clear();
label_ids.clear();
masks.clear();
@@ -127,6 +131,7 @@ void DetectionResult::Clear() {
void DetectionResult::Reserve(int size) {
boxes.reserve(size);
rotated_boxes.reserve(size);
scores.reserve(size);
label_ids.reserve(size);
if (contain_masks) {
@@ -136,6 +141,7 @@ void DetectionResult::Reserve(int size) {
void DetectionResult::Resize(int size) {
boxes.resize(size);
rotated_boxes.resize(size);
scores.resize(size);
label_ids.resize(size);
if (contain_masks) {
@@ -163,6 +169,19 @@ std::string DetectionResult::Str() {
out += ", " + masks[i].Str();
}
}
for (size_t i = 0; i < rotated_boxes.size(); ++i) {
out = out + std::to_string(rotated_boxes[i][0]) + "," +
std::to_string(rotated_boxes[i][1]) + ", " +
std::to_string(rotated_boxes[i][2]) + ", " +
std::to_string(rotated_boxes[i][3]) + ", " +
std::to_string(rotated_boxes[i][4]) + "," +
std::to_string(rotated_boxes[i][5]) + ", " +
std::to_string(rotated_boxes[i][6]) + ", " +
std::to_string(rotated_boxes[i][7]) + ", " +
std::to_string(scores[i]) + ", " + std::to_string(label_ids[i]);
out += "\n";
}
return out;
}

View File

@@ -108,6 +108,9 @@ struct FASTDEPLOY_DECL DetectionResult : public BaseResult {
/** \brief All the detected object boxes for an input image, the size of `boxes` is the number of detected objects, and the element of `boxes` is a array of 4 float values, means [xmin, ymin, xmax, ymax]
*/
std::vector<std::array<float, 4>> boxes;
/** \brief All the detected rotated object boxes for an input image, the size of `rotated_boxes` is the number of detected objects, and each element of `rotated_boxes` is an array of 8 float values, meaning [x1, y1, x2, y2, x3, y3, x4, y4]
*/
std::vector<std::array<float, 8>> rotated_boxes;
/** \brief The confidence for all the detected objects
*/
std::vector<float> scores;

View File

@@ -15,6 +15,7 @@
#pragma once
#include "fastdeploy/vision/detection/ppdet/base.h"
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
#include "fastdeploy/vision/detection/ppdet/multiclass_nms_rotated.h"
namespace fastdeploy {
namespace vision {
@@ -439,6 +440,23 @@ class FASTDEPLOY_DECL GFL : public PPDetBase {
virtual std::string ModelName() const { return "PaddleDetection/GFL"; }
};
/*! @brief PP-YOLOE-R detection model, built on top of PPDetBase.
 */
class FASTDEPLOY_DECL PPYOLOER : public PPDetBase {
 public:
  /** \brief Set the paths of the model files and the runtime configuration.
   *
   * \param[in] model_file Path of the model file
   * \param[in] params_file Path of the parameter file
   * \param[in] config_file Path of the configuration file
   * \param[in] custom_option RuntimeOption for inference, default-constructed if not given
   * \param[in] model_format Model format of the loaded model, ModelFormat::PADDLE by default
   */
  PPYOLOER(const std::string& model_file, const std::string& params_file,
           const std::string& config_file,
           const RuntimeOption& custom_option = RuntimeOption(),
           const ModelFormat& model_format = ModelFormat::PADDLE)
      : PPDetBase(model_file, params_file, config_file, custom_option,
                  model_format) {
    // Declare the allowed backends per device before initializing.
    valid_sophgonpu_backends = {Backend::SOPHGOTPU};
    valid_gpu_backends = {Backend::PDINFER};
    valid_cpu_backends = {Backend::PDINFER};
    initialized = Initialize();
  }

  /// Name reported for this model.
  virtual std::string ModelName() const { return "PPYOLOER"; }
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -61,7 +61,7 @@ struct PaddleMultiClassNMS {
const std::vector<int64_t>& boxes_dim,
const std::vector<int64_t>& scores_dim);
void SetNMSOption(const struct NMSOption &nms_option){
void SetNMSOption(const struct NMSOption &nms_option) {
background_label = nms_option.background_label;
keep_top_k = nms_option.keep_top_k;
nms_eta = nms_option.nms_eta;

View File

@@ -0,0 +1,477 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/detection/ppdet/multiclass_nms_rotated.h"
#include <algorithm>
#include <cmath>
#include <opencv2/opencv.hpp>
#include <vector>
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/utils/utils.h"
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
namespace fastdeploy {
namespace vision {
namespace detection {
// Plain aggregate describing a rotated rectangle by its center, size and a
// fifth parameter `a`.
template <typename T>
struct RotatedBox {
  T x_ctr;  // x coordinate of the box center
  T y_ctr;  // y coordinate of the box center
  T w;      // box width
  T h;      // box height
  T a;      // presumably the rotation angle; units not visible here -- TODO confirm
};
// Minimal 2D point / vector with component-wise addition, subtraction and
// scaling by a scalar. Both coordinates default to 0.
template <typename T>
struct Point {
  T x, y;

  Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}

  // In-place component-wise addition.
  Point& operator+=(const Point& p) {
    x += p.x;
    y += p.y;
    return *this;
  }

  // Component-wise sum, leaving both operands untouched.
  Point operator+(const Point& p) const {
    Point sum(x + p.x, y + p.y);
    return sum;
  }

  // Component-wise difference, leaving both operands untouched.
  Point operator-(const Point& p) const {
    Point diff(x - p.x, y - p.y);
    return diff;
  }

  // Scales both coordinates by coeff.
  Point operator*(const T coeff) const {
    Point scaled(x * coeff, y * coeff);
    return scaled;
  }
};
// Dot product of the 2D vectors A and B.
template <typename T>
T Dot2D(const Point<T>& A, const Point<T>& B) {
  const T x_term = A.x * B.x;
  const T y_term = A.y * B.y;
  return x_term + y_term;
}
// Scalar 2D cross product of A and B: A.x * B.y - B.x * A.y.
template <typename T>
T Cross2D(const Point<T>& A, const Point<T>& B) {
  const T forward = A.x * B.y;
  const T backward = B.x * A.y;
  return forward - backward;
}
/// Collects the candidate vertices of the intersection of two quadrilaterals.
///
/// Three groups of points are appended to `intersections`, in this order:
///   1. edge/edge crossing points (at most 4 x 4 = 16),
///   2. vertices of pts1 that lie inside pts2 (at most 4),
///   3. vertices of pts2 that lie inside pts1 (at most 4).
/// Duplicates are not removed, hence the capacity of 24.
///
/// \param[in] pts1 corners of the first quadrilateral, in traversal order
/// \param[in] pts2 corners of the second quadrilateral, in traversal order
/// \param[out] intersections receives the candidate points
/// \return number of points written to `intersections`
///
/// NOTE: the "inside" tests project onto two adjacent edges only, i.e. they
/// assume each quadrilateral is a rectangle (adjacent edges perpendicular).
template <typename T>
int GetIntersectionPoints(const Point<T> (&pts1)[4],
                          const Point<T> (&pts2)[4],
                          Point<T> (&intersections)[24]) {
  // Line vector
  // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
  Point<T> vec1[4], vec2[4];
  for (int i = 0; i < 4; i++) {
    vec1[i] = pts1[(i + 1) % 4] - pts1[i];
    vec2[i] = pts2[(i + 1) % 4] - pts2[i];
  }
  // Line test - test all line combos for intersection
  int num = 0;  // number of intersections
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j++) {
      // Solve for 2x2 Ax=b
      T det = Cross2D<T>(vec2[j], vec1[i]);
      // This takes care of parallel lines
      if (fabs(det) <= 1e-14) {
        continue;
      }
      auto vec12 = pts2[j] - pts1[i];
      // t1 / t2 are the line parameters along edge i of pts1 and edge j of
      // pts2; the segments cross only if both lie in [0, 1].
      T t1 = Cross2D<T>(vec2[j], vec12) / det;
      T t2 = Cross2D<T>(vec1[i], vec12) / det;
      if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
        intersections[num++] = pts1[i] + vec1[i] * t1;
      }
    }
  }
  // Check for vertices of rect1 inside rect2
  {
    const auto& AB = vec2[0];
    const auto& DA = vec2[3];
    auto ABdotAB = Dot2D<T>(AB, AB);
    auto ADdotAD = Dot2D<T>(DA, DA);
    for (int i = 0; i < 4; i++) {
      // assume ABCD is the rectangle, and P is the point to be judged
      // P is inside ABCD iff. P's projection on AB lies within AB
      // and P's projection on AD lies within AD
      auto AP = pts1[i] - pts2[0];
      auto APdotAB = Dot2D<T>(AP, AB);
      // DA points from D to A, so negating the dot product gives AP . AD.
      auto APdotAD = -Dot2D<T>(AP, DA);
      if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
          (APdotAD <= ADdotAD)) {
        intersections[num++] = pts1[i];
      }
    }
  }
  // Reverse the check - check for vertices of rect2 inside rect1
  {
    const auto& AB = vec1[0];
    const auto& DA = vec1[3];
    auto ABdotAB = Dot2D<T>(AB, AB);
    auto ADdotAD = Dot2D<T>(DA, DA);
    for (int i = 0; i < 4; i++) {
      auto AP = pts2[i] - pts1[0];
      auto APdotAB = Dot2D<T>(AP, AB);
      auto APdotAD = -Dot2D<T>(AP, DA);
      if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
          (APdotAD <= ADdotAD)) {
        intersections[num++] = pts2[i];
      }
    }
  }
  return num;
}
/// Graham-scan convex hull of the first `num_in` points of `p`.
///
/// \param[in] p input points (only the first `num_in` entries are read)
/// \param[in] num_in number of valid input points, must be >= 2
/// \param[out] q receives the hull vertices, ordered by angle around the
///               lowest (then left-most) input point
/// \param[in] shift_to_zero when true the hull is left translated so that the
///               starting point is the origin (sufficient for area
///               computation); when false the original coordinates are
///               restored
/// \return number of hull vertices written to `q`. If all inputs coincide the
///         result is 1 and q[0] holds the original (unshifted) point.
template <typename T>
int ConvexHullGraham(const Point<T> (&p)[24], const int& num_in,
                     Point<T> (&q)[24], bool shift_to_zero = false) {
  assert(num_in >= 2);
  // Step 1:
  // Find point with minimum y
  // if more than 1 points have the same minimum y,
  // pick the one with the minimum x.
  int t = 0;
  for (int i = 1; i < num_in; i++) {
    if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
      t = i;
    }
  }
  auto& start = p[t];  // starting point
  // Step 2:
  // Subtract starting point from every points (for sorting in the next step)
  for (int i = 0; i < num_in; i++) {
    q[i] = p[i] - start;
  }
  // Swap the starting point to position 0
  auto tmp = q[0];
  q[0] = q[t];
  q[t] = tmp;
  // Step 3:
  // Sort point 1 ~ num_in according to their relative cross-product values
  // (essentially sorting according to angles)
  // If the angles are the same, sort according to their distance to origin
  std::sort(q + 1, q + num_in,
            [](const Point<T>& A, const Point<T>& B) -> bool {
              T temp = Cross2D<T>(A, B);
              if (fabs(temp) < 1e-6) {
                return Dot2D<T>(A, A) < Dot2D<T>(B, B);
              } else {
                return temp > 0;
              }
            });
  // The squared distances must be computed AFTER the sort: the sort permutes
  // q, so distances taken beforehand would describe different points than
  // the q[k] inspected in Step 4.
  T dist[24];
  for (int i = 0; i < num_in; i++) {
    dist[i] = Dot2D<T>(q[i], q[i]);
  }
  // Step 4:
  // Make sure there are at least 2 points (that don't overlap with each other)
  // in the stack
  int k;  // index of the non-overlapped second point
  for (k = 1; k < num_in; k++) {
    if (dist[k] > 1e-8) {
      break;
    }
  }
  if (k == num_in) {
    // We reach the end, which means the convex hull is just one point
    q[0] = p[t];
    return 1;
  }
  q[1] = q[k];
  int m = 2;  // 2 points in the stack
  // Step 5:
  // Finally we can start the scanning process.
  // When a non-convex relationship between the 3 points is found
  // (either concave shape or duplicated points),
  // we pop the previous point from the stack
  // until the 3-point relationship is convex again, or
  // until the stack only contains two points
  for (int i = k + 1; i < num_in; i++) {
    while (m > 1 && Cross2D<T>(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
      m--;
    }
    q[m++] = q[i];
  }
  // Step 6 (Optional):
  // In general sense we need the original coordinates, so we
  // need to shift the points back (reverting Step 2)
  // But if we're only interested in getting the area/perimeter of the shape
  // We can simply return.
  if (!shift_to_zero) {
    for (int i = 0; i < m; i++) {
      q[i] += start;
    }
  }
  return m;
}
template <typename T>
T PolygonArea(const Point<T> (&q)[24], const int& m) {
if (m <= 2) {
return 0;
}
T area = 0;
for (int i = 1; i < m - 1; i++) {
area += fabs(Cross2D<T>(q[i] - q[0], q[i + 1] - q[0]));
}
return area / 2.0;
}
template <typename T>
T RboxesIntersection(T const* const poly1_raw, T const* const poly2_raw) {
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
// from rotated_rect_intersection_pts
Point<T> intersectPts[24], orderedPts[24];
Point<T> pts1[4];
Point<T> pts2[4];
for (int i = 0; i < 4; i++) {
pts1[i] = Point<T>(poly1_raw[2 * i], poly1_raw[2 * i + 1]);
pts2[i] = Point<T>(poly2_raw[2 * i], poly2_raw[2 * i + 1]);
}
int num = GetIntersectionPoints<T>(pts1, pts2, intersectPts);
if (num <= 2) {
return 0.0;
}
// Convex Hull to order the intersection points in clockwise order and find
// the contour area.
int num_convex = ConvexHullGraham<T>(intersectPts, num, orderedPts, true);
return PolygonArea<T>(orderedPts, num_convex);
}
/// Absolute area of a quadrilateral given as eight values x0,y0,...,x3,y3.
/// Uses the trapezoid form of the shoelace formula, pairing every vertex with
/// its predecessor (vertex 0 is paired with vertex 3).
template <typename T>
T PolyArea(T const* const poly_raw) {
  T signed_twice_area = 0.0;
  for (int cur = 0; cur < 4; ++cur) {
    const int prev = (cur + 3) % 4;
    // (x[prev] + x[cur]) * (y[prev] - y[cur])
    signed_twice_area += (poly_raw[2 * prev] + poly_raw[2 * cur]) *
                         (poly_raw[2 * prev + 1] - poly_raw[2 * cur + 1]);
  }
  return std::abs(signed_twice_area / 2.0);
}
template <typename T>
void Poly2Rbox(T const* const poly_raw, RotatedBox<T>& box) {
std::vector<cv::Point2f> contour_poly{
cv::Point2f(poly_raw[0], poly_raw[1]),
cv::Point2f(poly_raw[2], poly_raw[3]),
cv::Point2f(poly_raw[4], poly_raw[5]),
cv::Point2f(poly_raw[6], poly_raw[7]),
};
cv::RotatedRect rotate_rect = cv::minAreaRect(contour_poly);
box.x_ctr = rotate_rect.center.x;
box.y_ctr = rotate_rect.center.y;
box.w = rotate_rect.size.width;
box.h = rotate_rect.size.height;
box.a = rotate_rect.angle;
}
/// Intersection-over-union of two quadrilaterals.
///
/// \param[in] poly1_raw eight values x0,y0,...,x3,y3 of the first box
/// \param[in] poly2_raw eight values x0,y0,...,x3,y3 of the second box
/// \return intersection / (area1 + area2 - intersection), or 0 when the union
///         is not strictly positive (e.g. two zero-area boxes), so callers
///         never receive NaN or infinity from a 0/0 division.
template <typename T>
T RboxIouSingle(T const* const poly1_raw, T const* const poly2_raw) {
  const T area1 = PolyArea(poly1_raw);
  const T area2 = PolyArea(poly2_raw);
  const T intersection = RboxesIntersection<T>(poly1_raw, poly2_raw);
  const T union_area = area1 + area2 - intersection;
  // Written as a negated comparison so that a NaN union is rejected as well.
  if (!(union_area > 0)) {
    return 0;
  }
  return intersection / union_area;
}
/// Comparator ordering (score, payload) pairs by descending score.
/// Pairs with equal scores compare as equivalent, so a stable sort keeps
/// their original relative order.
template <typename T>
bool SortScorePairDescendRotated(const std::pair<float, T>& pair1,
                                 const std::pair<float, T>& pair2) {
  return pair2.first < pair1.first;
}
/// Selects the candidates whose score exceeds `threshold`, best first.
///
/// \param[in] scores array of `score_size` scores
/// \param[in] score_size number of entries in `scores`; values <= 0 select
///            nothing
/// \param[in] threshold only scores strictly greater than this are kept
/// \param[in] top_k when > -1, at most this many entries are kept
/// \param[out] sorted_indices (score, index) pairs are appended and the whole
///             vector is then sorted by descending score (stable, so equal
///             scores keep ascending index order)
void GetMaxScoreIndexRotated(
    const float* scores, const int& score_size, const float& threshold,
    const int& top_k, std::vector<std::pair<float, int>>* sorted_indices) {
  // Signed loop index: comparing a size_t counter against a negative
  // score_size would wrap around and read far out of bounds.
  for (int i = 0; i < score_size; ++i) {
    if (scores[i] > threshold) {
      sorted_indices->push_back(std::make_pair(scores[i], i));
    }
  }
  // Sort the score pair according to the scores in descending order
  std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
                   [](const std::pair<float, int>& lhs,
                      const std::pair<float, int>& rhs) {
                     return lhs.first > rhs.first;
                   });
  // Keep top_k scores if needed.
  if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
    sorted_indices->resize(top_k);
  }
}
/// Greedy rotated-box NMS for a single class.
///
/// Candidates above `score_threshold` (limited to `nms_top_k`) are visited in
/// descending score order; a candidate is kept unless its rotated IoU with an
/// already kept box exceeds the current threshold.
///
/// \param[in] boxes flat array of boxes, eight floats (x0,y0,...,x3,y3) each
/// \param[in] scores per-box scores of the class being processed
/// \param[in] num_boxes number of boxes / scores
/// \param[in,out] keep_indices indices of kept boxes are appended; entries
///                already present take part in the overlap test
void PaddleMultiClassNMSRotated::FastNMSRotated(
    const float* boxes, const float* scores, const int& num_boxes,
    std::vector<int>* keep_indices) {
  std::vector<std::pair<float, int>> sorted_indices;
  GetMaxScoreIndexRotated(scores, num_boxes, score_threshold, nms_top_k,
                          &sorted_indices);
  float adaptive_threshold = nms_threshold;
  // Walk the sorted candidates in place; repeatedly erasing the front element
  // of the vector would make this loop quadratic for no benefit.
  for (const auto& candidate : sorted_indices) {
    const int idx = candidate.second;
    bool keep = true;
    for (size_t k = 0; k < keep_indices->size(); ++k) {
      const int kept_idx = (*keep_indices)[k];
      const float overlap =
          RboxIouSingle<float>(boxes + idx * 8, boxes + kept_idx * 8);
      // Kept as "<=" so that a NaN overlap still suppresses the candidate.
      keep = overlap <= adaptive_threshold;
      if (!keep) {
        break;
      }
    }
    if (keep) {
      keep_indices->push_back(idx);
    }
    // Adaptive NMS: tighten the threshold after every kept box while
    // nms_eta < 1 and the threshold is still above 0.5.
    if (keep && nms_eta < 1.0 && adaptive_threshold > 0.5) {
      adaptive_threshold *= nms_eta;
    }
  }
}
/// Runs per-class rotated NMS for one sample and applies the keep_top_k cap.
///
/// \param[in] boxes boxes of the sample, eight floats each
/// \param[in] scores scores laid out class-major: class c starts at
///            scores + c * num_boxes
/// \param[in] num_boxes number of boxes in the sample
/// \param[in] num_classes number of classes in `scores`
/// \param[out] keep_indices maps class label -> kept box indices
/// \return number of detections kept for the sample
int PaddleMultiClassNMSRotated::NMSRotatedForEachSample(
    const float* boxes, const float* scores, int num_boxes, int num_classes,
    std::map<int, std::vector<int>>* keep_indices) {
  using LabelAndBox = std::pair<int, int>;

  // Per-class NMS; the background class is skipped entirely.
  for (int cls = 0; cls < num_classes; ++cls) {
    if (cls != background_label) {
      FastNMSRotated(boxes, scores + cls * num_boxes, num_boxes,
                     &((*keep_indices)[cls]));
    }
  }

  int num_det = 0;
  for (const auto& entry : *keep_indices) {
    num_det += entry.second.size();
  }

  // Nothing to trim when the cap is disabled or not exceeded.
  if (keep_top_k <= -1 || num_det <= keep_top_k) {
    return num_det;
  }

  // Gather every survivor with its score, rank them across classes ...
  std::vector<std::pair<float, LabelAndBox>> ranked;
  for (const auto& entry : *keep_indices) {
    const int label = entry.first;
    const float* class_scores = scores + label * num_boxes;
    for (const int box_idx : entry.second) {
      ranked.push_back(
          std::make_pair(class_scores[box_idx], std::make_pair(label, box_idx)));
    }
  }
  std::stable_sort(ranked.begin(), ranked.end(),
                   SortScorePairDescendRotated<LabelAndBox>);
  ranked.resize(keep_top_k);

  // ... and rebuild the per-class map from the best keep_top_k of them.
  std::map<int, std::vector<int>> trimmed;
  for (const auto& item : ranked) {
    trimmed[item.second.first].push_back(item.second.second);
  }
  keep_indices->swap(trimmed);
  return static_cast<int>(keep_top_k);
}
void PaddleMultiClassNMSRotated::Compute(
const float* boxes_data, const float* scores_data,
const std::vector<int64_t>& boxes_dim,
const std::vector<int64_t>& scores_dim) {
int score_size = scores_dim.size();
int64_t batch_size = scores_dim[0];
int64_t box_dim = boxes_dim[2];
int64_t out_dim = box_dim + 2;
int num_nmsed_out = 0;
FDASSERT(score_size == 3,
"Require rank of input scores be 3, but now it's %d.", score_size);
FDASSERT(boxes_dim[2] == 8,
"Require the 3-dimension of input boxes be 8, but now it's %lld.",
box_dim);
out_num_rois_data.resize(batch_size);
std::vector<std::map<int, std::vector<int>>> all_indices;
for (size_t i = 0; i < batch_size; ++i) {
std::map<int, std::vector<int>> indices; // indices kept for each class
const float* current_boxes_ptr =
boxes_data + i * boxes_dim[1] * boxes_dim[2];
const float* current_scores_ptr =
scores_data + i * scores_dim[1] * scores_dim[2];
int num = NMSRotatedForEachSample(current_boxes_ptr, current_scores_ptr,
boxes_dim[1], scores_dim[1], &indices);
num_nmsed_out += num;
out_num_rois_data[i] = num;
all_indices.emplace_back(indices);
}
std::vector<int64_t> out_box_dims = {num_nmsed_out, 10};
std::vector<int64_t> out_index_dims = {num_nmsed_out, 1};
if (num_nmsed_out == 0) {
for (size_t i = 0; i < batch_size; ++i) {
out_num_rois_data[i] = 0;
}
return;
}
out_box_data.resize(num_nmsed_out * 10);
out_index_data.resize(num_nmsed_out);
int count = 0;
for (size_t i = 0; i < batch_size; ++i) {
const float* current_boxes_ptr =
boxes_data + i * boxes_dim[1] * boxes_dim[2];
const float* current_scores_ptr =
scores_data + i * scores_dim[1] * scores_dim[2];
for (const auto& it : all_indices[i]) {
int label = it.first;
const auto& indices = it.second;
const float* current_scores_class_ptr =
current_scores_ptr + label * scores_dim[2];
for (size_t j = 0; j < indices.size(); ++j) {
int start = count * 10;
out_box_data[start] = label;
out_box_data[start + 1] = current_scores_class_ptr[indices[j]];
for (int k = 0; k < 8; k++) {
out_box_data[start + 2 + k] = current_boxes_ptr[indices[j] * 8 + k];
}
out_index_data[count] = i * boxes_dim[1] + indices[j];
count += 1;
}
}
}
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,76 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
namespace fastdeploy {
namespace vision {
namespace detection {
/** \brief Config for PaddleMultiClassNMSRotated
 * \param[in] background_label class index that is skipped during NMS (-1 skips no class)
 * \param[in] keep_top_k maximum number of detections kept per image after NMS over all classes, -1 disables the limit
 * \param[in] nms_eta factor the IoU threshold is multiplied by after each kept box while it is below 1 and the threshold is above 0.5
 * \param[in] nms_threshold rotated-IoU threshold, a box overlapping an already kept box by more than this is suppressed
 * \param[in] nms_top_k maximum number of highest-scoring candidates per class that enter NMS, -1 disables the limit
 * \param[in] normalized stored with the other options; not read by the rotated NMS implementation
 * \param[in] score_threshold bbox threshold, only boxes with a score greater than it are considered.
 */
struct NMSRotatedOption{
  NMSRotatedOption() = default;
  int64_t background_label = -1;
  int64_t keep_top_k = -1;
  float nms_eta = 1.0;
  float nms_threshold = 0.1;
  int64_t nms_top_k = 2000;
  bool normalized = false;
  float score_threshold = 0.1;
};
/// Multi-class NMS for rotated boxes given as quadrilaterals (eight floats
/// per box). Configure it with SetNMSRotatedOption(), run Compute() and read
/// the results from the out_* members.
struct PaddleMultiClassNMSRotated {
  // Every option carries the same default as NMSRotatedOption, so an object
  // that never received SetNMSRotatedOption() is still fully initialised.
  int64_t background_label = -1;
  int64_t keep_top_k = -1;
  float nms_eta = 1.0;
  float nms_threshold = 0.1;
  int64_t nms_top_k = 2000;
  bool normalized = false;
  float score_threshold = 0.1;

  /// Number of detections per input sample.
  std::vector<int32_t> out_num_rois_data;
  /// For every detection: sample index * boxes per sample + box index.
  std::vector<int32_t> out_index_data;
  /// Ten floats per detection: label, score, x0, y0, ..., x3, y3.
  std::vector<float> out_box_data;

  /// Greedy NMS over the boxes of one class; appends kept box indices.
  void FastNMSRotated(const float* boxes, const float* scores, const int& num_boxes,
                      std::vector<int>* keep_indices);

  /// NMS over all classes of one sample; returns the number of kept boxes.
  int NMSRotatedForEachSample(const float* boxes, const float* scores, int num_boxes,
                              int num_classes,
                              std::map<int, std::vector<int>>* keep_indices);

  /// NMS over a whole batch; fills the out_* members.
  void Compute(const float* ploy_boxes, const float* scores,
               const std::vector<int64_t>& boxes_dim,
               const std::vector<int64_t>& scores_dim);

  /// Copies every field of `nms_rotated_option` into this object.
  void SetNMSRotatedOption(const struct NMSRotatedOption &nms_rotated_option) {
    background_label = nms_rotated_option.background_label;
    keep_top_k = nms_rotated_option.keep_top_k;
    nms_eta = nms_rotated_option.nms_eta;
    nms_threshold = nms_rotated_option.nms_threshold;
    nms_top_k = nms_rotated_option.nms_top_k;
    normalized = nms_rotated_option.normalized;
    score_threshold = nms_rotated_option.score_threshold;
  }
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -264,6 +264,57 @@ bool PaddleDetPostprocessor::ProcessSolov2(
return true;
}
bool PaddleDetPostprocessor::ProcessPPYOLOER(
const std::vector<FDTensor>& tensors,
std::vector<DetectionResult>* results) {
if (tensors.size() != 2) {
FDERROR << "The size of tensors for PPYOLOER must be 2." << std::endl;
return false;
}
int boxes_index = 0;
int scores_index = 1;
multi_class_nms_rotated_.Compute(
static_cast<const float*>(tensors[boxes_index].Data()),
static_cast<const float*>(tensors[scores_index].Data()),
tensors[boxes_index].shape, tensors[scores_index].shape);
auto num_boxes = multi_class_nms_rotated_.out_num_rois_data;
auto box_data =
static_cast<const float*>(multi_class_nms_rotated_.out_box_data.data());
// Get boxes for each input image
results->resize(num_boxes.size());
int offset = 0;
for (size_t i = 0; i < num_boxes.size(); ++i) {
const float* ptr = box_data + offset;
(*results)[i].Reserve(num_boxes[i]);
for (size_t j = 0; j < num_boxes[i]; ++j) {
(*results)[i].label_ids.push_back(
static_cast<int32_t>(round(ptr[j * 10])));
(*results)[i].scores.push_back(ptr[j * 10 + 1]);
(*results)[i].rotated_boxes.push_back(std::array<float, 8>(
{ptr[j * 10 + 2], ptr[j * 10 + 3], ptr[j * 10 + 4], ptr[j * 10 + 5],
ptr[j * 10 + 6], ptr[j * 10 + 7], ptr[j * 10 + 8],
ptr[j * 10 + 9]}));
}
offset += (num_boxes[i] * 10);
}
// do scale
if (GetScaleFactor()[0] != 0) {
for (auto& result : *results) {
for (int i = 0; i < result.rotated_boxes.size(); i++) {
for (int j = 0; j < 8; j++) {
auto scale = i % 2 == 0 ? GetScaleFactor()[1] : GetScaleFactor()[0];
result.rotated_boxes[i][j] /= float(scale);
}
}
}
}
return true;
}
bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors,
std::vector<DetectionResult>* results) {
if (arch_ == "SOLOv2") {
@@ -272,6 +323,10 @@ bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors,
// The fourth output of solov2 is mask
return ProcessMask(tensors[3], results);
} else {
if (tensors[0].Shape().size() == 3 && tensors[0].Shape()[2] == 8) { // PPYOLOER
return ProcessPPYOLOER(tensors, results);
}
// Do process according to whether NMS exists.
if (with_nms_) {
if (!ProcessWithNMS(tensors, results)) {

View File

@@ -16,6 +16,7 @@
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
#include "fastdeploy/vision/detection/ppdet/multiclass_nms_rotated.h"
namespace fastdeploy {
namespace vision {
@@ -28,6 +29,7 @@ class FASTDEPLOY_DECL PaddleDetPostprocessor {
// There may be no NMS config in the yaml file,
// so we need to give a initial value to multi_class_nms_.
multi_class_nms_.SetNMSOption(NMSOption());
multi_class_nms_rotated_.SetNMSRotatedOption(NMSRotatedOption());
}
/** \brief Create a preprocessor instance for PaddleDet serials model
@@ -40,6 +42,7 @@ class FASTDEPLOY_DECL PaddleDetPostprocessor {
// There may be no NMS config in the yaml file,
// so we need to give a initial value to multi_class_nms_.
multi_class_nms_.SetNMSOption(NMSOption());
multi_class_nms_rotated_.SetNMSRotatedOption(NMSRotatedOption());
}
/** \brief Process the result of runtime and fill to ClassifyResult structure
@@ -55,6 +58,12 @@ class FASTDEPLOY_DECL PaddleDetPostprocessor {
/// only available for those model exported without box decoding and nms.
void ApplyNMS() { with_nms_ = false; }
/// Set the rotated NMS parameters from code, as an alternative to editing
/// the Yaml configuration file. The option is forwarded unchanged to the
/// internal rotated multi-class NMS object.
void SetNMSRotatedOption(const NMSRotatedOption& option) {
  multi_class_nms_rotated_.SetNMSRotatedOption(option);
}
/// If you do not want to modify the Yaml configuration file,
/// you can use this function to set NMS parameters.
void SetNMSOption(const NMSOption& option) {
@@ -79,6 +88,8 @@ class FASTDEPLOY_DECL PaddleDetPostprocessor {
PaddleMultiClassNMS multi_class_nms_{};
PaddleMultiClassNMSRotated multi_class_nms_rotated_{};
// Process for General tensor without nms.
bool ProcessWithoutNMS(const std::vector<FDTensor>& tensors,
std::vector<DetectionResult>* results);
@@ -91,6 +102,10 @@ class FASTDEPLOY_DECL PaddleDetPostprocessor {
bool ProcessSolov2(const std::vector<FDTensor>& tensors,
std::vector<DetectionResult>* results);
// Process PPYOLOER
bool ProcessPPYOLOER(const std::vector<FDTensor>& tensors,
std::vector<DetectionResult>* results);
// Process mask tensor for MaskRCNN
bool ProcessMask(const FDTensor& tensor,
std::vector<DetectionResult>* results);

View File

@@ -78,6 +78,11 @@ void BindPPDet(pybind11::module& m) {
vision::detection::NMSOption option) {
self.SetNMSOption(option);
})
.def("set_nms_rotated_option",
[](vision::detection::PaddleDetPostprocessor& self,
vision::detection::NMSRotatedOption option) {
self.SetNMSRotatedOption(option);
})
.def("apply_nms",
[](vision::detection::PaddleDetPostprocessor& self) {
self.ApplyNMS();
@@ -233,5 +238,26 @@ void BindPPDet(pybind11::module& m) {
m, "SOLOv2")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>());
pybind11::class_<vision::detection::PPYOLOER, vision::detection::PPDetBase>(
m, "PPYOLOER")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>());
pybind11::class_<vision::detection::NMSRotatedOption>(m, "NMSRotatedOption")
.def(pybind11::init())
.def_readwrite("background_label",
&vision::detection::NMSRotatedOption::background_label)
.def_readwrite("keep_top_k",
&vision::detection::NMSRotatedOption::keep_top_k)
.def_readwrite("nms_eta", &vision::detection::NMSRotatedOption::nms_eta)
.def_readwrite("nms_threshold",
&vision::detection::NMSRotatedOption::nms_threshold)
.def_readwrite("nms_top_k",
&vision::detection::NMSRotatedOption::nms_top_k)
.def_readwrite("normalized",
&vision::detection::NMSRotatedOption::normalized)
.def_readwrite("score_threshold",
&vision::detection::NMSRotatedOption::score_threshold);
}
} // namespace fastdeploy

View File

@@ -84,13 +84,14 @@ void BindVision(pybind11::module& m) {
.def(pybind11::init())
.def_readwrite("boxes", &vision::DetectionResult::boxes)
.def_readwrite("scores", &vision::DetectionResult::scores)
.def_readwrite("rotated_boxes", &vision::DetectionResult::rotated_boxes)
.def_readwrite("label_ids", &vision::DetectionResult::label_ids)
.def_readwrite("masks", &vision::DetectionResult::masks)
.def_readwrite("contain_masks", &vision::DetectionResult::contain_masks)
.def(pybind11::pickle(
[](const vision::DetectionResult& d) {
return pybind11::make_tuple(d.boxes, d.scores, d.label_ids, d.masks,
d.contain_masks);
return pybind11::make_tuple(d.boxes, d.scores, d.rotated_boxes,
d.label_ids, d.masks, d.contain_masks);
},
[](pybind11::tuple t) {
if (t.size() != 5)
@@ -99,6 +100,7 @@ void BindVision(pybind11::module& m) {
vision::DetectionResult d;
d.boxes = t[0].cast<std::vector<std::array<float, 4>>>();
d.rotated_boxes = t[0].cast<std::vector<std::array<float, 8>>>();
d.scores = t[1].cast<std::vector<float>>();
d.label_ids = t[2].cast<std::vector<int32_t>>();
d.masks = t[3].cast<std::vector<vision::Mask>>();

View File

@@ -22,7 +22,7 @@ namespace vision {
cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
float score_threshold, int line_size, float font_size) {
if (result.boxes.empty()) {
if (result.boxes.empty() && result.rotated_boxes.empty()) {
return im;
}
if (result.contain_masks) {
@@ -38,6 +38,45 @@ cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
int h = im.rows;
int w = im.cols;
auto vis_im = im.clone();
// Draw every rotated box whose score reaches score_threshold: the four edges
// as white lines and a "<label id>, <score>" caption at the first vertex.
for (size_t i = 0; i < result.rotated_boxes.size(); ++i) {
  if (result.scores[i] < score_threshold) {
    continue;
  }
  // NOTE(review): rect_color and text_size are computed but not used below;
  // lines and text are always drawn in white.
  int c0 = color_map[3 * result.label_ids[i] + 0];
  int c1 = color_map[3 * result.label_ids[i] + 1];
  int c2 = color_map[3 * result.label_ids[i] + 2];
  cv::Scalar rect_color = cv::Scalar(c0, c1, c2);
  std::string id = std::to_string(result.label_ids[i]);
  std::string score = std::to_string(result.scores[i]);
  // Keep at most four characters of the score text.
  if (score.size() > 4) {
    score = score.substr(0, 4);
  }
  std::string text = id + ", " + score;
  int font = cv::FONT_HERSHEY_SIMPLEX;
  cv::Size text_size = cv::getTextSize(text, font, font_size, 1, nullptr);
  // Edge j runs from vertex j to vertex j + 1; the last edge closes the
  // polygon back to vertex 0, where the caption is drawn once.
  for (int j = 0; j < 4; j++) {
    auto start = cv::Point(
        static_cast<int>(round(result.rotated_boxes[i][2 * j])),
        static_cast<int>(round(result.rotated_boxes[i][2 * j + 1])));
    cv::Point end;
    if (j != 3) {
      end = cv::Point(
          static_cast<int>(round(result.rotated_boxes[i][2 * (j + 1)])),
          static_cast<int>(round(result.rotated_boxes[i][2 * (j + 1) + 1])));
    } else {
      end = cv::Point(static_cast<int>(round(result.rotated_boxes[i][0])),
                      static_cast<int>(round(result.rotated_boxes[i][1])));
      cv::putText(vis_im, text, end, font, font_size,
                  cv::Scalar(255, 255, 255), 1);
    }
    cv::line(vis_im, start, end, cv::Scalar(255, 255, 255), 3, cv::LINE_AA,
             0);
  }
}
for (size_t i = 0; i < result.boxes.size(); ++i) {
if (result.scores[i] < score_threshold) {
continue;
@@ -125,6 +164,44 @@ cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
int h = im.rows;
int w = im.cols;
auto vis_im = im.clone();
// Draw every rotated box whose score reaches score_threshold: the four edges
// as white lines and a "<label id>, <score>" caption at the first vertex.
for (size_t i = 0; i < result.rotated_boxes.size(); ++i) {
  if (result.scores[i] < score_threshold) {
    continue;
  }
  // NOTE(review): rect_color and text_size are computed but not used below;
  // lines and text are always drawn in white.
  int c0 = color_map[3 * result.label_ids[i] + 0];
  int c1 = color_map[3 * result.label_ids[i] + 1];
  int c2 = color_map[3 * result.label_ids[i] + 2];
  cv::Scalar rect_color = cv::Scalar(c0, c1, c2);
  std::string id = std::to_string(result.label_ids[i]);
  std::string score = std::to_string(result.scores[i]);
  // Keep at most four characters of the score text.
  if (score.size() > 4) {
    score = score.substr(0, 4);
  }
  std::string text = id + ", " + score;
  int font = cv::FONT_HERSHEY_SIMPLEX;
  cv::Size text_size = cv::getTextSize(text, font, font_size, 1, nullptr);
  // Edge j runs from vertex j to vertex j + 1; the last edge closes the
  // polygon back to vertex 0, where the caption is drawn once.
  // (The branch used to be inverted: edges 0..2 all ended at vertex 0, the
  // caption was drawn three times and the closing edge had zero length.)
  for (int j = 0; j < 4; j++) {
    auto start = cv::Point(
        static_cast<int>(round(result.rotated_boxes[i][2 * j])),
        static_cast<int>(round(result.rotated_boxes[i][2 * j + 1])));
    cv::Point end;
    if (j != 3) {
      end = cv::Point(
          static_cast<int>(round(result.rotated_boxes[i][2 * (j + 1)])),
          static_cast<int>(round(result.rotated_boxes[i][2 * (j + 1) + 1])));
    } else {
      end = cv::Point(static_cast<int>(round(result.rotated_boxes[i][0])),
                      static_cast<int>(round(result.rotated_boxes[i][1])));
      cv::putText(vis_im, text, end, font, font_size,
                  cv::Scalar(255, 255, 255), 1);
    }
    cv::line(vis_im, start, end, cv::Scalar(255, 255, 255), 3, cv::LINE_AA,
             0);
  }
}
for (size_t i = 0; i < result.boxes.size(); ++i) {
if (result.scores[i] < score_threshold) {
continue;

View File

@@ -50,6 +50,15 @@ class NMSOption:
return self.nms_option.background_label
class NMSRotatedOption:
    """Python-side holder of the native rotated-NMS option object.

    The wrapped object is available as ``nms_rotated_option``; only
    ``background_label`` is re-exposed here, as a read-only property.
    """

    def __init__(self):
        # Native option object created with its own default values.
        self.nms_rotated_option = C.vision.detection.NMSRotatedOption()

    @property
    def background_label(self):
        """Value of ``background_label`` on the wrapped native option."""
        return self.nms_rotated_option.background_label
class PaddleDetPostprocessor:
def __init__(self):
"""Create a postprocessor for PaddleDetection Model
@@ -75,6 +84,14 @@ class PaddleDetPostprocessor:
nms_option = NMSOption()
self._postprocessor.set_nms_option(self, nms_option.nms_option)
def set_nms_rotated_option(self, nms_rotated_option=None):
    """Set the rotated NMS parameters used in the postprocess step.

    :param nms_rotated_option: (NMSRotatedOption)Options to apply; when None,
        a default-constructed NMSRotatedOption is used
    """
    if nms_rotated_option is None:
        nms_rotated_option = NMSRotatedOption()
    # `set_nms_rotated_option` is looked up on the postprocessor instance and
    # is therefore already bound; passing `self` again would hand the native
    # binding one positional argument too many.
    self._postprocessor.set_nms_rotated_option(
        nms_rotated_option.nms_rotated_option)
class PPYOLOE(FastDeployModel):
def __init__(self,
@@ -781,3 +798,40 @@ class GFL(PPYOLOE):
config_file, self._runtime_option,
model_format)
assert self.initialized, "GFL model initialize failed."
class PPYOLOER(PPYOLOE):
    def __init__(self,
                 model_file,
                 params_file,
                 config_file,
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
        """Load a PPYOLOER model exported by PaddleDetection.

        :param model_file: (str)Path of model file, e.g ppyoloe_r/model.pdmodel
        :param params_file: (str)Path of parameters file, e.g ppyoloe_r/model.pdiparams, if the model_format is ModelFormat.ONNX, this param will be ignored, can be set as empty string
        :param config_file: (str)Path of configuration file for deployment, e.g ppyoloe_r/infer_cfg.yml
        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
        """
        # super(PPYOLOE, self) starts the lookup after PPYOLOE, so
        # PPYOLOE.__init__ itself is not run (presumably because it would
        # create a PPYOLOE model - confirm against PPYOLOE).
        super(PPYOLOE, self).__init__(runtime_option)

        self._model = C.vision.detection.PPYOLOER(
            model_file, params_file, config_file, self._runtime_option,
            model_format)
        assert self.initialized, "PPYOLOER model initialize failed."

    def clone(self):
        """Clone PPYOLOER object

        :return: a new PPYOLOER object
        """

        class PPYOLOERClone(PPYOLOER):
            def __init__(self, model):
                # Wrap an already-built native model without loading again.
                self._model = model

        clone_model = PPYOLOERClone(self._model.clone())
        return clone_model