Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Hackthon_4th 177] Support PP-YOLOE-R with BM1684 (#1809)
* first draft
* add rbox iou
* add benchmark for ppyoloe_r
* remove trash code
* fix bugs
* add pybind nms rotated option
* add missing header file
* fix bug
* fix bug2
* fix shape bug

Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
@@ -22,6 +22,8 @@ add_executable(benchmark_ppmatting ${PROJECT_SOURCE_DIR}/benchmark_ppmatting.cc)
add_executable(benchmark_ppocr_det ${PROJECT_SOURCE_DIR}/benchmark_ppocr_det.cc)
add_executable(benchmark_ppocr_cls ${PROJECT_SOURCE_DIR}/benchmark_ppocr_cls.cc)
add_executable(benchmark_ppocr_rec ${PROJECT_SOURCE_DIR}/benchmark_ppocr_rec.cc)
add_executable(benchmark_ppyoloe_r ${PROJECT_SOURCE_DIR}/benchmark_ppyoloe_r.cc)
add_executable(benchmark_ppyoloe_r_sophgo ${PROJECT_SOURCE_DIR}/benchmark_ppyoloe_r_sophgo.cc)
add_executable(benchmark_ppyolo ${PROJECT_SOURCE_DIR}/benchmark_ppyolo.cc)
add_executable(benchmark_yolov3 ${PROJECT_SOURCE_DIR}/benchmark_yolov3.cc)
add_executable(benchmark_fasterrcnn ${PROJECT_SOURCE_DIR}/benchmark_fasterrcnn.cc)
@@ -44,6 +46,8 @@ if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
  target_link_libraries(benchmark_ppyolov8 ${FASTDEPLOY_LIBS} gflags pthread)
  target_link_libraries(benchmark_ppyolox ${FASTDEPLOY_LIBS} gflags pthread)
  target_link_libraries(benchmark_ppyoloe ${FASTDEPLOY_LIBS} gflags pthread)
  target_link_libraries(benchmark_ppyoloe_r ${FASTDEPLOY_LIBS} gflags pthread)
  target_link_libraries(benchmark_ppyoloe_r_sophgo ${FASTDEPLOY_LIBS} gflags pthread)
  target_link_libraries(benchmark_picodet ${FASTDEPLOY_LIBS} gflags pthread)
  target_link_libraries(benchmark_ppcls ${FASTDEPLOY_LIBS} gflags pthread)
  target_link_libraries(benchmark_ppseg ${FASTDEPLOY_LIBS} gflags pthread)
@@ -72,6 +76,8 @@ else()
  target_link_libraries(benchmark_ppyolov8 ${FASTDEPLOY_LIBS} gflags)
  target_link_libraries(benchmark_ppyolox ${FASTDEPLOY_LIBS} gflags)
  target_link_libraries(benchmark_ppyoloe ${FASTDEPLOY_LIBS} gflags)
  target_link_libraries(benchmark_ppyoloe_r ${FASTDEPLOY_LIBS} gflags)
  target_link_libraries(benchmark_ppyoloe_r_sophgo ${FASTDEPLOY_LIBS} gflags)
  target_link_libraries(benchmark_picodet ${FASTDEPLOY_LIBS} gflags)
  target_link_libraries(benchmark_ppcls ${FASTDEPLOY_LIBS} gflags)
  target_link_libraries(benchmark_ppseg ${FASTDEPLOY_LIBS} gflags)
benchmark/cpp/benchmark_ppyoloe_r.cc (new file, 60 lines)
@@ -0,0 +1,60 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <fstream>

#include "flags.h"
#include "macros.h"
#include "option.h"

namespace vision = fastdeploy::vision;
namespace benchmark = fastdeploy::benchmark;

DEFINE_bool(no_nms, false, "Whether the model contains nms.");

int main(int argc, char* argv[]) {
#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION)
  // Initialization
  auto option = fastdeploy::RuntimeOption();
  if (!CreateRuntimeOption(&option, argc, argv, true)) {
    return -1;
  }
  auto im = cv::imread(FLAGS_image);
  std::unordered_map<std::string, std::string> config_info;
  benchmark::ResultManager::LoadBenchmarkConfig(FLAGS_config_path,
                                                &config_info);
  std::string model_name, params_name, config_name;
  auto model_format = fastdeploy::ModelFormat::PADDLE;
  if (!UpdateModelResourceName(&model_name, &params_name, &config_name,
                               &model_format, config_info)) {
    return -1;
  }
  auto model_file = FLAGS_model + sep + model_name;
  auto params_file = FLAGS_model + sep + params_name;
  auto config_file = FLAGS_model + sep + config_name;

  auto model_ppyoloe_r = vision::detection::PPYOLOER(
      model_file, params_file, config_file, option, model_format);

  vision::DetectionResult res;

  // Run profiling
  BENCHMARK_MODEL(model_ppyoloe_r, model_ppyoloe_r.Predict(im, &res))

  auto vis_im = vision::VisDetection(im, res);
  cv::imwrite("vis_result.jpg", vis_im);
  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
#endif
  return 0;
}
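The benchmark above reduces to: load the benchmark config, build a RuntimeOption, construct a PPYOLOER model, and time Predict over repeated runs. A rough Python sketch of the same flow is shown below; it assumes the fd.vision.detection.PPYOLOER binding added later in this PR, an exported ppyoloe_r_crn_s_3x_dota model directory, a DOTA-style test image, and 50 timed repeats, all of which are illustrative choices rather than part of the benchmark harness.

```python
# Rough Python analogue of benchmark_ppyoloe_r.cc (sketch only).
import time

import cv2
import fastdeploy as fd

model_dir = "ppyoloe_r_crn_s_3x_dota"   # hypothetical exported model directory
option = fd.RuntimeOption()             # CPU by default; call option.use_gpu() if needed
model = fd.vision.detection.PPYOLOER(
    model_dir + "/model.pdmodel", model_dir + "/model.pdiparams",
    model_dir + "/infer_cfg.yml", runtime_option=option)

im = cv2.imread("P0861__1.0__1154___824.png")
model.predict(im)                       # rough warm-up before timing
repeats = 50
start = time.perf_counter()
for _ in range(repeats):
    result = model.predict(im)
print("avg latency: %.2f ms" % ((time.perf_counter() - start) * 1000 / repeats))
cv2.imwrite("vis_result.jpg", fd.vision.vis_detection(im, result))
```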
benchmark/cpp/benchmark_ppyoloe_r_sophgo.cc (new file, 61 lines)
@@ -0,0 +1,61 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <fstream>

#include "flags.h"
#include "macros.h"
#include "option.h"

namespace vision = fastdeploy::vision;
namespace benchmark = fastdeploy::benchmark;

DEFINE_bool(no_nms, false, "Whether the model contains nms.");

int main(int argc, char* argv[]) {
#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION)
  // Initialization
  auto option = fastdeploy::RuntimeOption();
  if (!CreateRuntimeOption(&option, argc, argv, true)) {
    return -1;
  }
  auto im = cv::imread(FLAGS_image);
  std::unordered_map<std::string, std::string> config_info;
  benchmark::ResultManager::LoadBenchmarkConfig(FLAGS_config_path,
                                                &config_info);
  std::string model_name, params_name, config_name;
  auto model_format = fastdeploy::ModelFormat::SOPHGO;
  if (!UpdateModelResourceName(&model_name, &params_name, &config_name,
                               &model_format, config_info)) {
    return -1;
  }

  auto model_file = FLAGS_model + sep + model_name;
  auto params_file = FLAGS_model + sep + params_name;
  auto config_file = FLAGS_model + sep + config_name;

  auto model_ppyoloe_r = vision::detection::PPYOLOER(
      model_file, params_file, config_file, option, model_format);

  vision::DetectionResult res;

  // Run profiling
  BENCHMARK_MODEL(model_ppyoloe_r, model_ppyoloe_r.Predict(im, &res))

  auto vis_im = vision::VisDetection(im, res);
  cv::imwrite("vis_result.jpg", vis_im);
  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
#endif
  return 0;
}
@@ -158,5 +158,14 @@ static bool UpdateModelResourceName(
      return false;
    }
  }

  if (config_info["backend"] == "sophgo") {
    *model_format = fastdeploy::ModelFormat::SOPHGO;
    if (!GetModelResoucesNameFromDir(FLAGS_model, model_name, "bmodel")) {
      std::cout << "Can not find sophgo model resources." << std::endl;
      return false;
    }
  }

  return true;
}
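The new branch above resolves the .bmodel file name from the model directory via GetModelResoucesNameFromDir. A hedged Python sketch of that lookup, with the helper's behaviour inferred from its usage here:

```python
import os

def find_model_resource(model_dir: str, suffix: str = "bmodel"):
    """Return the first file in model_dir ending with `.` + suffix, or None.

    Sketch of what GetModelResoucesNameFromDir is used for above; the real
    helper lives in the benchmark sources and may differ in details.
    """
    for name in sorted(os.listdir(model_dir)):
        if name.endswith("." + suffix):
            return name
    return None

# Usage mirroring the sophgo branch in UpdateModelResourceName
model_name = find_model_resource("./ppyoloe_r_bm1684_model")  # hypothetical path
if model_name is None:
    print("Can not find sophgo model resources.")
```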
@@ -103,6 +103,9 @@ static bool CreateRuntimeOption(fastdeploy::RuntimeOption* option,
    if (config_info["use_fp16"] == "true") {
      option->paddle_lite_option.enable_fp16 = true;
    }
  } else if (config_info["backend"] == "sophgo") {
    option->UseSophgo();
    option->UseSophgoBackend();
  } else if (config_info["backend"] == "default") {
    PrintBenchmarkInfo(config_info);
    return true;
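For reference, the same backend switch can be expressed through the Python API. A minimal sketch, assuming only the use_sophgo() call that the SOPHGO Python example later in this PR relies on; the config dict is an illustrative stand-in for the parsed benchmark config:

```python
import fastdeploy as fd

config_info = {"backend": "sophgo"}   # illustrative stand-in for the benchmark config
option = fd.RuntimeOption()
if config_info["backend"] == "sophgo":
    option.use_sophgo()               # same effect as option->UseSophgo() above
```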
@@ -50,6 +50,7 @@ typedef struct FD_C_OneDimMask {

typedef struct FD_C_DetectionResult {
  FD_C_TwoDimArrayFloat boxes;
  FD_C_TwoDimArrayFloat rotated_boxes;
  FD_C_OneDimArrayFloat scores;
  FD_C_OneDimArrayInt32 label_ids;
  FD_C_OneDimMask masks;
@@ -135,6 +135,7 @@ public struct FD_OneDimMask {
[StructLayout(LayoutKind.Sequential)]
public struct FD_DetectionResult {
  public FD_TwoDimArrayFloat boxes;
  public FD_TwoDimArrayFloat rotated_boxes;
  public FD_OneDimArrayFloat scores;
  public FD_OneDimArrayInt32 label_ids;
  public FD_OneDimMask masks;
@@ -10,6 +10,7 @@ English | [简体中文](README_CN.md)
Now FastDeploy supports the deployment of the following models

- [PP-YOLOE (including PP-YOLOE+) models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)
- [PP-YOLOE-R models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/rotate/ppyoloe_r)
- [PicoDet models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet)
- [PP-YOLO models (including v2)](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyolo)
- [YOLOv3 models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/yolov3)
@@ -43,7 +44,7 @@ Before deployment, PaddleDetection needs to be exported into the deployment mode
## Download Pre-trained Model

For developers' testing, models exported by PaddleDetection are provided below. Developers can download them directly.

The accuracy metric is from the model descriptions in PaddleDetection. Refer to them for details.
@@ -10,6 +10,7 @@
FastDeploy currently supports deployment of the following models

- [PP-YOLOE (including PP-YOLOE+) models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)
- [PP-YOLOE-R models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/rotate/ppyoloe_r)
- [PicoDet models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet)
- [PP-YOLO models (including v2)](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyolo)
- [YOLOv3 models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/yolov3)
@@ -15,6 +15,9 @@ target_link_libraries(infer_faster_rcnn_demo ${FASTDEPLOY_LIBS})
add_executable(infer_ppyoloe_demo ${PROJECT_SOURCE_DIR}/infer_ppyoloe.cc)
target_link_libraries(infer_ppyoloe_demo ${FASTDEPLOY_LIBS})

add_executable(infer_ppyoloe_r_demo ${PROJECT_SOURCE_DIR}/infer_ppyoloe_r.cc)
target_link_libraries(infer_ppyoloe_r_demo ${FASTDEPLOY_LIBS})

add_executable(infer_picodet_demo ${PROJECT_SOURCE_DIR}/infer_picodet.cc)
target_link_libraries(infer_picodet_demo ${FASTDEPLOY_LIBS})
@@ -0,0 +1,98 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/vision.h"

#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif

void CpuInfer(const std::string& model_dir, const std::string& image_file) {
  auto model_file = model_dir + sep + "model.pdmodel";
  auto params_file = model_dir + sep + "model.pdiparams";
  auto config_file = model_dir + sep + "infer_cfg.yml";
  auto option = fastdeploy::RuntimeOption();
  option.UseCpu();
  auto model = fastdeploy::vision::detection::PPYOLOER(model_file, params_file,
                                                       config_file, option);
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize." << std::endl;
    return;
  }

  auto im = cv::imread(image_file);
  std::cout << im.cols << " vs " << im.rows << std::endl;
  fastdeploy::vision::DetectionResult res;
  if (!model.Predict(im, &res)) {
    std::cerr << "Failed to predict." << std::endl;
    return;
  }

  std::cout << res.Str() << std::endl;
  auto vis_im = fastdeploy::vision::VisDetection(im, res, 0.5);
  cv::imwrite("vis_result.jpg", vis_im);
  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}

void GpuInfer(const std::string& model_dir, const std::string& image_file) {
  auto model_file = model_dir + sep + "model.pdmodel";
  auto params_file = model_dir + sep + "model.pdiparams";
  auto config_file = model_dir + sep + "infer_cfg.yml";

  auto option = fastdeploy::RuntimeOption();
  option.UseGpu();
  auto model = fastdeploy::vision::detection::PPYOLOER(model_file, params_file,
                                                       config_file, option);
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize." << std::endl;
    return;
  }

  const cv::Mat im = cv::imread(image_file);

  fastdeploy::vision::DetectionResult res;
  if (!model.Predict(im, &res)) {
    std::cerr << "Failed to predict." << std::endl;
    return;
  }

  std::cout << res.Str() << std::endl;
  auto vis_im = fastdeploy::vision::VisDetection(im, res, 0.1);
  cv::imwrite("vis_result.jpg", vis_im);
  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}

int main(int argc, char* argv[]) {
  if (argc < 4) {
    std::cout
        << "Usage: infer_ppyoloe_r path/to/model_dir path/to/image run_option, "
           "e.g ./infer_ppyoloe_r ./ppyoloe_model_dir ./test.jpeg 0"
        << std::endl;
    std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
                 "with gpu; 2: run with gpu and use tensorrt backend; 3: run "
                 "with kunlunxin."
              << std::endl;
    return -1;
  }

  if (std::atoi(argv[3]) == 0) {
    CpuInfer(argv[1], argv[2]);
  } else if (std::atoi(argv[3]) == 1) {
    GpuInfer(argv[1], argv[2]);
  }

  return 0;
}
@@ -0,0 +1,78 @@
import cv2
import os

import fastdeploy as fd


def parse_arguments():
    import argparse
    import ast
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        default=None,
        help="Path of PaddleDetection model directory")
    parser.add_argument(
        "--image", default=None, help="Path of test image file.")
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Type of inference device, support 'kunlunxin', 'cpu' or 'gpu'.")
    parser.add_argument(
        "--use_trt",
        type=ast.literal_eval,
        default=False,
        help="Whether to use tensorrt.")
    return parser.parse_args()


def build_option(args):
    option = fd.RuntimeOption()

    if args.device.lower() == "kunlunxin":
        option.use_kunlunxin()

    if args.device.lower() == "ascend":
        option.use_ascend()

    if args.device.lower() == "gpu":
        option.use_gpu()

    if args.use_trt:
        option.use_trt_backend()
    return option


args = parse_arguments()

if args.model_dir is None:
    model_dir = fd.download_model(name='ppyoloe_crn_l_300e_coco')
else:
    model_dir = args.model_dir

model_file = os.path.join(model_dir, "model.pdmodel")
params_file = os.path.join(model_dir, "model.pdiparams")
config_file = os.path.join(model_dir, "infer_cfg.yml")

# Configure the runtime and load the model
runtime_option = build_option(args)
model = fd.vision.detection.PPYOLOER(
    model_file, params_file, config_file, runtime_option=runtime_option)

# Predict detection results for the image
if args.image is None:
    image = fd.utils.get_detection_test_image()
else:
    image = args.image
im = cv2.imread(image)
result = model.predict(im)
print(result)

# Visualize the prediction results
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result saved in ./visualized_result.jpg")
@@ -6,6 +6,7 @@
- [PP-YOLOE models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)
- [PicoDet models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet)
- [YOLOv8 models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4)
- [PP-YOLOE-R models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/rotate/ppyoloe_r)

## Prepare the PP-YOLOE, YOLOv8 or PicoDet deployment model and convert the model
@@ -13,9 +13,11 @@ include_directories(${FASTDEPLOY_INCS})
include_directories(${FastDeploy_INCLUDE_DIRS})

add_executable(infer_ppyoloe ${PROJECT_SOURCE_DIR}/infer_ppyoloe.cc)
add_executable(infer_ppyoloe_r ${PROJECT_SOURCE_DIR}/infer_ppyoloe_r.cc)
add_executable(infer_picodet ${PROJECT_SOURCE_DIR}/infer_picodet.cc)
add_executable(infer_yolov8 ${PROJECT_SOURCE_DIR}/infer_yolov8.cc)
# Add the FastDeploy library dependency
target_link_libraries(infer_ppyoloe ${FASTDEPLOY_LIBS})
target_link_libraries(infer_ppyoloe_r ${FASTDEPLOY_LIBS})
target_link_libraries(infer_picodet ${FASTDEPLOY_LIBS})
target_link_libraries(infer_yolov8 ${FASTDEPLOY_LIBS})
@@ -1,6 +1,6 @@
# PaddleDetection C++ Deployment Example

This directory provides `infer_ppyoloe.cc`, `infer_yolov8.cc` and `infer_picodet.cc` as examples for quickly deploying PP-YOLOE, YOLOv8 and PicoDet models with acceleration on the SOPHGO BM1684x board.
This directory provides `infer_ppyoloe.cc`, `infer_ppyoloe_r.cc`, `infer_yolov8.cc` and `infer_picodet.cc` as examples for quickly deploying PP-YOLOE, PP-YOLOE-R, YOLOv8 and PicoDet models with acceleration on the SOPHGO BM1684x board.

Before deployment, confirm the following two steps:
@@ -18,6 +18,7 @@
├── build              # build folder
├── image              # folder for test images
├── infer_ppyoloe.cc
├── infer_ppyoloe_r.cc
├── infer_picodet.cc
├── infer_yolov8.cc
└── model              # folder for model files
@@ -37,6 +38,9 @@
```bash
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
cp 000000014439.jpg ./images

wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/P0861__1.0__1154___824.png
cp P0861__1.0__1154___824.png ./images
```

### Build the example
@@ -53,6 +57,9 @@ make
# ppyoloe inference example
./infer_ppyoloe model images/000000014439.jpg

# ppyoloe_r inference example
./infer_ppyoloe_r model images/P0861__1.0__1154___824.png

# picodet inference example
./infer_picodet model images/000000014439.jpg
```
@@ -0,0 +1,58 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sys/time.h>

#include <iostream>
#include <string>

#include "fastdeploy/vision.h"

void SophgoInfer(const std::string& model_dir, const std::string& image_file) {
  auto model_file = model_dir + "/ppyoloe_r_crn_s_3x_dota_1684x_f32.bmodel";
  auto params_file = "";
  auto config_file = model_dir + "/infer_cfg.yml";

  auto option = fastdeploy::RuntimeOption();
  option.UseSophgo();
  option.UseSophgoBackend();

  auto format = fastdeploy::ModelFormat::SOPHGO;
  auto model = fastdeploy::vision::detection::PPYOLOER(
      model_file, params_file, config_file, option, format);

  auto im = cv::imread(image_file);

  fastdeploy::vision::DetectionResult res;
  if (!model.Predict(&im, &res)) {
    std::cerr << "Failed to predict." << std::endl;
    return;
  }

  std::cout << res.Str() << std::endl;
  auto vis_im = fastdeploy::vision::VisDetection(im, res, 0.1);
  cv::imwrite("infer_sophgo.jpg", vis_im);
  std::cout << "Visualized result saved in ./infer_sophgo.jpg" << std::endl;
}

int main(int argc, char* argv[]) {
  if (argc < 3) {
    std::cout
        << "Usage: infer_demo path/to/model_dir path/to/image run_option, "
           "e.g ./infer_model ./model_dir ./test.jpeg"
        << std::endl;
    return -1;
  }
  SophgoInfer(argv[1], argv[2]);
  return 0;
}
@@ -13,6 +13,7 @@ cd FastDeploy/examples/vision/detection/paddledetection/sophgo/python

# Download test images
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/P0861__1.0__1154___824.png

# Inference
# ppyoloe inference example
@@ -33,12 +34,19 @@ python3 infer_picodet.py --auto False --pp_detect_path '' --model_file model/ppy
python3 infer_yolov8.py --model_file model/yolov8s_s_300e_coco_1684x_f32.bmodel --config_file model/infer_cfg.yml --image ./000000014439.jpg
# After the run finishes, the result is returned as follows
The visualized result is stored in sophgo_result.jpg

# ppyoloe_r inference example
# Specify --auto to automatically prepare, convert and run the model; the PaddleDetection path must be given via --pp_detect_path
python3 infer_ppyoloe_r.py --model_file model/ppyoloe_r_crn_s_3x_dota_1684x_f32.bmodel --image P0861__1.0__1154___824.png --config_file model/infer_cfg.yml
The visualized result is stored in sophgo_result_ppyoloe_r.jpg
```

## Other Documents
- [PP-YOLOE C++ deployment](../cpp)
- [PicoDet C++ deployment](../cpp)
- [YOLOv8 C++ deployment](../cpp)
- [PP-YOLOE-R C++ deployment](../cpp)
- [Converting the PicoDet SOPHGO model](../README.md)
- [Converting the PP-YOLOE SOPHGO model](../README.md)
- [Converting the YOLOv8 SOPHGO model](../README.md)
- [Converting the PP-YOLOE-R SOPHGO model](../README.md)
@@ -0,0 +1,159 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
from subprocess import run
from prepare_npz import prepare


def parse_arguments():
    import argparse
    import ast
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--auto",
        action="store_true",
        help="Auto download, convert, compile and infer if True")
    parser.add_argument(
        "--pp_detect_path",
        default='/workspace/PaddleDetection',
        help="Path of PaddleDetection folder")
    parser.add_argument(
        "--model_file", required=True, help="Path of sophgo model.")
    parser.add_argument("--config_file", required=True, help="Path of config.")
    parser.add_argument(
        "--image", type=str, required=True, help="Path of test image file.")
    return parser.parse_args()


def export_model(args):
    PPDetection_path = args.pp_detect_path
    export_str = 'python3 tools/export_model.py \
        -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_s_3x_dota.yml \
        --output_dir=output_inference \
        -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_s_3x_dota.pdparams'

    cur_path = os.getcwd()
    os.chdir(PPDetection_path)
    print(export_str)
    run(export_str, shell=True)
    cp_str = 'cp -r ./output_inference/ppyoloe_r_crn_s_3x_dota ' + cur_path
    print(cp_str)
    run(cp_str, shell=True)
    os.chdir(cur_path)


def paddle2onnx():
    convert_str = 'paddle2onnx --model_dir ppyoloe_r_crn_s_3x_dota \
        --model_filename model.pdmodel \
        --params_filename model.pdiparams \
        --save_file ppyoloe_r_crn_s_3x_dota.onnx \
        --enable_dev_version True'

    print(convert_str)
    run(convert_str, shell=True)


def mlir_prepare():
    mlir_path = os.getenv("MODEL_ZOO_PATH")
    mlir_path = mlir_path[:-13]
    regression_path = os.path.join(mlir_path, 'regression')
    mv_str_list = [
        'mkdir ppyoloe_r', 'cp -rf ' + os.path.join(
            regression_path, 'dataset/COCO2017/') + ' ./ppyoloe_r',
        'cp -rf ' + os.path.join(regression_path, 'image/') + ' ./ppyoloe_r',
        'cp ppyoloe_r_crn_s_3x_dota.onnx ./ppyoloe_r',
        'mkdir ./ppyoloe_r/workspace'
    ]
    for cmd in mv_str_list:
        print(cmd)
        run(cmd, shell=True)


def image_prepare():
    img_str = 'wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/P0861__1.0__1154___824.png'
    if not os.path.exists('P0861__1.0__1154___824.png'):
        print(img_str)
        run(img_str, shell=True)
    prepare('P0861__1.0__1154___824.png', [640, 640])
    cp_npz_str = 'cp ./inputs.npz ./ppyoloe_r'
    print(cp_npz_str)
    run(cp_npz_str, shell=True)


def onnx2mlir():
    transform_str = 'model_transform.py \
        --model_name ppyoloe_r_crn_s_3x_dota \
        --model_def ../ppyoloe_r_crn_s_3x_dota.onnx \
        --input_shapes [[1,3,1024,1024],[1,2]] \
        --keep_aspect_ratio \
        --pixel_format rgb \
        --mlir ppyoloe_r_crn_s_3x_dota.mlir'

    os.chdir('./ppyoloe_r/workspace')
    print(transform_str)
    run(transform_str, shell=True)
    os.chdir('../../')


def mlir2bmodel():
    deploy_str = 'model_deploy.py \
        --mlir ppyoloe_r_crn_s_3x_dota.mlir \
        --quantize F32 \
        --chip bm1684x \
        --model ppyoloe_r_crn_s_3x_dota_1684x_f32.bmodel'

    os.chdir('./ppyoloe_r/workspace')
    print(deploy_str)
    run(deploy_str, shell=True)
    os.chdir('../../')


if __name__ == "__main__":
    args = parse_arguments()

    if args.auto:
        export_model(args)
        paddle2onnx()
        mlir_prepare()
        image_prepare()
        onnx2mlir()
        mlir2bmodel()

    model_file = './ppyoloe_r/workspace/ppyoloe_r_crn_s_3x_dota_1684x_f32.bmodel' if args.auto else args.model_file
    params_file = ""
    config_file = './ppyoloe_r_crn_s_3x_dota/infer_cfg.yml' if args.auto else args.config_file
    image_file = './P0861__1.0__1154___824.png' if args.auto else args.image

    # Configure the runtime and load the model
    runtime_option = fd.RuntimeOption()
    runtime_option.use_sophgo()

    model = fd.vision.detection.PPYOLOER(
        model_file,
        params_file,
        config_file,
        runtime_option=runtime_option,
        model_format=fd.ModelFormat.SOPHGO)

    # Predict detection results for the image
    im = cv2.imread(image_file)
    result = model.predict(im)
    print(result)

    # Visualize the results
    vis_im = fd.vision.vis_detection(im, result, score_threshold=0.1)
    cv2.imwrite("sophgo_result_ppyoloe_r.jpg", vis_im)
    print("Visualized result saved in ./sophgo_result_ppyoloe_r.jpg")
@@ -176,6 +176,7 @@ bool SophgoBackend::Infer(std::vector<FDTensor>& inputs,
  bm_tensor_t input_tensors[input_size];
  bm_status_t status = BM_SUCCESS;

  RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
  bm_data_type_t* input_dtypes = net_info_->input_dtypes;
  for (int i = 0; i < input_size; i++) {
    status = bm_malloc_device_byte(handle_, &input_tensors[i].device_mem,
@@ -198,12 +199,14 @@ bool SophgoBackend::Infer(std::vector<FDTensor>& inputs,
    assert(BM_SUCCESS == status);
  }

  RUNTIME_PROFILE_LOOP_BEGIN(1)
  bool launch_status = bmrt_launch_tensor_ex(
      p_bmrt_, net_name_.c_str(), input_tensors, net_info_->input_num,
      output_tensors, net_info_->output_num, true, false);
  assert(launch_status);
  status = bm_thread_sync(handle_);
  assert(status == BM_SUCCESS);
  RUNTIME_PROFILE_LOOP_END

  outputs->resize(outputs_desc_.size());
  bm_data_type_t* output_dtypes = net_info_->output_dtypes;
@@ -231,6 +234,7 @@ bool SophgoBackend::Infer(std::vector<FDTensor>& inputs,
  for (int i = 0; i < output_size; i++) {
    bm_free_device(handle_, output_tensors[i].device_mem);
  }
  RUNTIME_PROFILE_LOOP_H2D_D2H_END

  return true;
}
@@ -83,6 +83,7 @@ std::string Mask::Str() {

DetectionResult::DetectionResult(const DetectionResult& res) {
  boxes.assign(res.boxes.begin(), res.boxes.end());
  rotated_boxes.assign(res.rotated_boxes.begin(), res.rotated_boxes.end());
  scores.assign(res.scores.begin(), res.scores.end());
  label_ids.assign(res.label_ids.begin(), res.label_ids.end());
  contain_masks = res.contain_masks;
@@ -98,6 +99,7 @@ DetectionResult::DetectionResult(const DetectionResult& res) {
DetectionResult& DetectionResult::operator=(DetectionResult&& other) {
  if (&other != this) {
    boxes = std::move(other.boxes);
    rotated_boxes = std::move(other.rotated_boxes);
    scores = std::move(other.scores);
    label_ids = std::move(other.label_ids);
    contain_masks = std::move(other.contain_masks);
@@ -111,6 +113,7 @@ DetectionResult& DetectionResult::operator=(DetectionResult&& other) {

void DetectionResult::Free() {
  std::vector<std::array<float, 4>>().swap(boxes);
  std::vector<std::array<float, 8>>().swap(rotated_boxes);
  std::vector<float>().swap(scores);
  std::vector<int32_t>().swap(label_ids);
  std::vector<Mask>().swap(masks);
@@ -119,6 +122,7 @@ void DetectionResult::Free() {

void DetectionResult::Clear() {
  boxes.clear();
  rotated_boxes.clear();
  scores.clear();
  label_ids.clear();
  masks.clear();
@@ -127,6 +131,7 @@ void DetectionResult::Clear() {

void DetectionResult::Reserve(int size) {
  boxes.reserve(size);
  rotated_boxes.reserve(size);
  scores.reserve(size);
  label_ids.reserve(size);
  if (contain_masks) {
@@ -136,6 +141,7 @@ void DetectionResult::Reserve(int size) {

void DetectionResult::Resize(int size) {
  boxes.resize(size);
  rotated_boxes.resize(size);
  scores.resize(size);
  label_ids.resize(size);
  if (contain_masks) {
@@ -163,6 +169,19 @@ std::string DetectionResult::Str() {
      out += ", " + masks[i].Str();
    }
  }

  for (size_t i = 0; i < rotated_boxes.size(); ++i) {
    out = out + std::to_string(rotated_boxes[i][0]) + "," +
          std::to_string(rotated_boxes[i][1]) + ", " +
          std::to_string(rotated_boxes[i][2]) + ", " +
          std::to_string(rotated_boxes[i][3]) + ", " +
          std::to_string(rotated_boxes[i][4]) + "," +
          std::to_string(rotated_boxes[i][5]) + ", " +
          std::to_string(rotated_boxes[i][6]) + ", " +
          std::to_string(rotated_boxes[i][7]) + ", " +
          std::to_string(scores[i]) + ", " + std::to_string(label_ids[i]);
    out += "\n";
  }
  return out;
}
@@ -108,6 +108,9 @@ struct FASTDEPLOY_DECL DetectionResult : public BaseResult {
  /** \brief All the detected object boxes for an input image, the size of `boxes` is the number of detected objects, and the element of `boxes` is a array of 4 float values, means [xmin, ymin, xmax, ymax]
   */
  std::vector<std::array<float, 4>> boxes;
  /** \brief All the detected rotated object boxes for an input image, the size of `rotated_boxes` is the number of detected objects, and the element of `rotated_boxes` is an array of 8 float values, means [x1, y1, x2, y2, x3, y3, x4, y4]
   */
  std::vector<std::array<float, 8>> rotated_boxes;
  /** \brief The confidence for all the detected objects
   */
  std::vector<float> scores;
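Each rotated_boxes entry stores the four corner points flattened as [x1, y1, ..., x4, y4]. Below is a small sketch of consuming that layout from Python, assuming the Python DetectionResult binding exposes rotated_boxes the same way the C and C# structs above do, and that `res` and `im` come from an earlier predict call:

```python
# Draw the first detected rotated box (sketch; assumes res.rotated_boxes exists).
import numpy as np
import cv2

box = np.array(res.rotated_boxes[0], dtype=np.float32).reshape(4, 2)  # 4 corner points
im = cv2.polylines(im, [box.astype(np.int32)], isClosed=True,
                   color=(0, 255, 0), thickness=2)
cv2.imwrite("rotated_box.jpg", im)
```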
@@ -15,6 +15,7 @@
#pragma once
#include "fastdeploy/vision/detection/ppdet/base.h"
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
#include "fastdeploy/vision/detection/ppdet/multiclass_nms_rotated.h"

namespace fastdeploy {
namespace vision {
@@ -439,6 +440,23 @@ class FASTDEPLOY_DECL GFL : public PPDetBase {
  virtual std::string ModelName() const { return "PaddleDetection/GFL"; }
};

class FASTDEPLOY_DECL PPYOLOER : public PPDetBase {
 public:
  PPYOLOER(const std::string& model_file, const std::string& params_file,
           const std::string& config_file,
           const RuntimeOption& custom_option = RuntimeOption(),
           const ModelFormat& model_format = ModelFormat::PADDLE)
      : PPDetBase(model_file, params_file, config_file, custom_option,
                  model_format) {
    valid_cpu_backends = {Backend::PDINFER};
    valid_gpu_backends = {Backend::PDINFER};
    valid_sophgonpu_backends = {Backend::SOPHGOTPU};
    initialized = Initialize();
  }

  virtual std::string ModelName() const { return "PPYOLOER"; }
};

}  // namespace detection
}  // namespace vision
}  // namespace fastdeploy
@@ -61,7 +61,7 @@ struct PaddleMultiClassNMS {
               const std::vector<int64_t>& boxes_dim,
               const std::vector<int64_t>& scores_dim);

  void SetNMSOption(const struct NMSOption &nms_option){
  void SetNMSOption(const struct NMSOption &nms_option) {
    background_label = nms_option.background_label;
    keep_top_k = nms_option.keep_top_k;
    nms_eta = nms_option.nms_eta;
fastdeploy/vision/detection/ppdet/multiclass_nms_rotated.cc (new file, 477 lines)
@@ -0,0 +1,477 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/detection/ppdet/multiclass_nms_rotated.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include <vector>
|
||||
|
||||
#include "fastdeploy/core/fd_tensor.h"
|
||||
#include "fastdeploy/utils/utils.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
template <typename T>
|
||||
struct RotatedBox {
|
||||
T x_ctr, y_ctr, w, h, a;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Point {
|
||||
T x, y;
|
||||
Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
|
||||
Point operator+(const Point& p) const { return Point(x + p.x, y + p.y); }
|
||||
Point& operator+=(const Point& p) {
|
||||
x += p.x;
|
||||
y += p.y;
|
||||
return *this;
|
||||
}
|
||||
Point operator-(const Point& p) const { return Point(x - p.x, y - p.y); }
|
||||
Point operator*(const T coeff) const { return Point(x * coeff, y * coeff); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
T Dot2D(const Point<T>& A, const Point<T>& B) {
|
||||
return A.x * B.x + A.y * B.y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T Cross2D(const Point<T>& A, const Point<T>& B) {
|
||||
return A.x * B.y - B.x * A.y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int GetIntersectionPoints(const Point<T> (&pts1)[4],
|
||||
const Point<T> (&pts2)[4],
|
||||
Point<T> (&intersections)[24]) {
|
||||
// Line vector
|
||||
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
|
||||
Point<T> vec1[4], vec2[4];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
vec1[i] = pts1[(i + 1) % 4] - pts1[i];
|
||||
vec2[i] = pts2[(i + 1) % 4] - pts2[i];
|
||||
}
|
||||
|
||||
// Line test - test all line combos for intersection
|
||||
int num = 0; // number of intersections
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int j = 0; j < 4; j++) {
|
||||
// Solve for 2x2 Ax=b
|
||||
T det = Cross2D<T>(vec2[j], vec1[i]);
|
||||
|
||||
// This takes care of parallel lines
|
||||
if (fabs(det) <= 1e-14) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto vec12 = pts2[j] - pts1[i];
|
||||
|
||||
T t1 = Cross2D<T>(vec2[j], vec12) / det;
|
||||
T t2 = Cross2D<T>(vec1[i], vec12) / det;
|
||||
|
||||
if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
|
||||
intersections[num++] = pts1[i] + vec1[i] * t1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for vertices of rect1 inside rect2
|
||||
{
|
||||
const auto& AB = vec2[0];
|
||||
const auto& DA = vec2[3];
|
||||
auto ABdotAB = Dot2D<T>(AB, AB);
|
||||
auto ADdotAD = Dot2D<T>(DA, DA);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
// assume ABCD is the rectangle, and P is the point to be judged
|
||||
// P is inside ABCD iff. P's projection on AB lies within AB
|
||||
// and P's projection on AD lies within AD
|
||||
|
||||
auto AP = pts1[i] - pts2[0];
|
||||
|
||||
auto APdotAB = Dot2D<T>(AP, AB);
|
||||
auto APdotAD = -Dot2D<T>(AP, DA);
|
||||
|
||||
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
|
||||
(APdotAD <= ADdotAD)) {
|
||||
intersections[num++] = pts1[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse the check - check for vertices of rect2 inside rect1
|
||||
{
|
||||
const auto& AB = vec1[0];
|
||||
const auto& DA = vec1[3];
|
||||
auto ABdotAB = Dot2D<T>(AB, AB);
|
||||
auto ADdotAD = Dot2D<T>(DA, DA);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
auto AP = pts2[i] - pts1[0];
|
||||
|
||||
auto APdotAB = Dot2D<T>(AP, AB);
|
||||
auto APdotAD = -Dot2D<T>(AP, DA);
|
||||
|
||||
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
|
||||
(APdotAD <= ADdotAD)) {
|
||||
intersections[num++] = pts2[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return num;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int ConvexHullGraham(const Point<T> (&p)[24], const int& num_in,
|
||||
Point<T> (&q)[24], bool shift_to_zero = false) {
|
||||
assert(num_in >= 2);
|
||||
|
||||
// Step 1:
|
||||
// Find point with minimum y
|
||||
// if more than 1 points have the same minimum y,
|
||||
// pick the one with the minimum x.
|
||||
int t = 0;
|
||||
for (int i = 1; i < num_in; i++) {
|
||||
if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
|
||||
t = i;
|
||||
}
|
||||
}
|
||||
auto& start = p[t]; // starting point
|
||||
|
||||
// Step 2:
|
||||
// Subtract starting point from every points (for sorting in the next step)
|
||||
for (int i = 0; i < num_in; i++) {
|
||||
q[i] = p[i] - start;
|
||||
}
|
||||
|
||||
// Swap the starting point to position 0
|
||||
auto tmp = q[0];
|
||||
q[0] = q[t];
|
||||
q[t] = tmp;
|
||||
|
||||
// Step 3:
|
||||
// Sort point 1 ~ num_in according to their relative cross-product values
|
||||
// (essentially sorting according to angles)
|
||||
// If the angles are the same, sort according to their distance to origin
|
||||
T dist[24];
|
||||
for (int i = 0; i < num_in; i++) {
|
||||
dist[i] = Dot2D<T>(q[i], q[i]);
|
||||
}
|
||||
|
||||
// CPU version
|
||||
std::sort(q + 1, q + num_in,
|
||||
[](const Point<T>& A, const Point<T>& B) -> bool {
|
||||
T temp = Cross2D<T>(A, B);
|
||||
if (fabs(temp) < 1e-6) {
|
||||
return Dot2D<T>(A, A) < Dot2D<T>(B, B);
|
||||
} else {
|
||||
return temp > 0;
|
||||
}
|
||||
});
|
||||
|
||||
// Step 4:
|
||||
// Make sure there are at least 2 points (that don't overlap with each other)
|
||||
// in the stack
|
||||
int k; // index of the non-overlapped second point
|
||||
for (k = 1; k < num_in; k++) {
|
||||
if (dist[k] > 1e-8) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (k == num_in) {
|
||||
// We reach the end, which means the convex hull is just one point
|
||||
q[0] = p[t];
|
||||
return 1;
|
||||
}
|
||||
q[1] = q[k];
|
||||
int m = 2; // 2 points in the stack
|
||||
// Step 5:
|
||||
// Finally we can start the scanning process.
|
||||
// When a non-convex relationship between the 3 points is found
|
||||
// (either concave shape or duplicated points),
|
||||
// we pop the previous point from the stack
|
||||
// until the 3-point relationship is convex again, or
|
||||
// until the stack only contains two points
|
||||
for (int i = k + 1; i < num_in; i++) {
|
||||
while (m > 1 && Cross2D<T>(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
|
||||
m--;
|
||||
}
|
||||
q[m++] = q[i];
|
||||
}
|
||||
|
||||
// Step 6 (Optional):
|
||||
// In general sense we need the original coordinates, so we
|
||||
// need to shift the points back (reverting Step 2)
|
||||
// But if we're only interested in getting the area/perimeter of the shape
|
||||
// We can simply return.
|
||||
if (!shift_to_zero) {
|
||||
for (int i = 0; i < m; i++) {
|
||||
q[i] += start;
|
||||
}
|
||||
}
|
||||
|
||||
return m;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T PolygonArea(const Point<T> (&q)[24], const int& m) {
|
||||
if (m <= 2) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
T area = 0;
|
||||
for (int i = 1; i < m - 1; i++) {
|
||||
area += fabs(Cross2D<T>(q[i] - q[0], q[i + 1] - q[0]));
|
||||
}
|
||||
|
||||
return area / 2.0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T RboxesIntersection(T const* const poly1_raw, T const* const poly2_raw) {
|
||||
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
|
||||
// from rotated_rect_intersection_pts
|
||||
Point<T> intersectPts[24], orderedPts[24];
|
||||
|
||||
Point<T> pts1[4];
|
||||
|
||||
Point<T> pts2[4];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
pts1[i] = Point<T>(poly1_raw[2 * i], poly1_raw[2 * i + 1]);
|
||||
pts2[i] = Point<T>(poly2_raw[2 * i], poly2_raw[2 * i + 1]);
|
||||
}
|
||||
|
||||
int num = GetIntersectionPoints<T>(pts1, pts2, intersectPts);
|
||||
if (num <= 2) {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Convex Hull to order the intersection points in clockwise order and find
|
||||
// the contour area.
|
||||
int num_convex = ConvexHullGraham<T>(intersectPts, num, orderedPts, true);
|
||||
return PolygonArea<T>(orderedPts, num_convex);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T PolyArea(T const* const poly_raw) {
|
||||
T area = 0.0;
|
||||
int j = 3;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
// area += (x[j] + x[i]) * (y[j] - y[i]);
|
||||
area += (poly_raw[2 * j] + poly_raw[2 * i]) *
|
||||
(poly_raw[2 * j + 1] - poly_raw[2 * i + 1]);
|
||||
j = i;
|
||||
}
|
||||
// return static_cast<T>(abs(static_cast<float>(area) / 2.0));
|
||||
return std::abs(area / 2.0);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Poly2Rbox(T const* const poly_raw, RotatedBox<T>& box) {
|
||||
std::vector<cv::Point2f> contour_poly{
|
||||
cv::Point2f(poly_raw[0], poly_raw[1]),
|
||||
cv::Point2f(poly_raw[2], poly_raw[3]),
|
||||
cv::Point2f(poly_raw[4], poly_raw[5]),
|
||||
cv::Point2f(poly_raw[6], poly_raw[7]),
|
||||
};
|
||||
cv::RotatedRect rotate_rect = cv::minAreaRect(contour_poly);
|
||||
box.x_ctr = rotate_rect.center.x;
|
||||
box.y_ctr = rotate_rect.center.y;
|
||||
box.w = rotate_rect.size.width;
|
||||
box.h = rotate_rect.size.height;
|
||||
box.a = rotate_rect.angle;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T RboxIouSingle(T const* const poly1_raw, T const* const poly2_raw) {
|
||||
const T area1 = PolyArea(poly1_raw);
|
||||
const T area2 = PolyArea(poly2_raw);
|
||||
|
||||
const T intersection = RboxesIntersection<T>(poly1_raw, poly2_raw);
|
||||
const T iou = intersection / (area1 + area2 - intersection);
|
||||
return iou;
|
||||
}
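RboxIouSingle divides the polygon intersection area by the union of the two quadrilateral areas. A quick independent cross-check from Python can be done with shapely, which implements the same polygon intersection; this is a verification aid, not part of FastDeploy:

```python
# Cross-check for RboxIouSingle using shapely (not a FastDeploy dependency).
from shapely.geometry import Polygon

def rbox_iou(poly1, poly2):
    """poly1/poly2: flat lists [x1, y1, ..., x4, y4], same layout as the C++ code."""
    p1 = Polygon(list(zip(poly1[0::2], poly1[1::2])))
    p2 = Polygon(list(zip(poly2[0::2], poly2[1::2])))
    inter = p1.intersection(p2).area
    return inter / (p1.area + p2.area - inter)

# Two unit squares overlapping by half -> IoU = 1/3
print(rbox_iou([0, 0, 1, 0, 1, 1, 0, 1], [0.5, 0, 1.5, 0, 1.5, 1, 0.5, 1]))
```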
|
||||
|
||||
template <typename T>
|
||||
bool SortScorePairDescendRotated(const std::pair<float, T>& pair1,
|
||||
const std::pair<float, T>& pair2) {
|
||||
return pair1.first > pair2.first;
|
||||
}
|
||||
|
||||
void GetMaxScoreIndexRotated(
|
||||
const float* scores, const int& score_size, const float& threshold,
|
||||
const int& top_k, std::vector<std::pair<float, int>>* sorted_indices) {
|
||||
for (size_t i = 0; i < score_size; ++i) {
|
||||
if (scores[i] > threshold) {
|
||||
sorted_indices->push_back(std::make_pair(scores[i], i));
|
||||
}
|
||||
}
|
||||
// Sort the score pair according to the scores in descending order
|
||||
std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
|
||||
SortScorePairDescendRotated<int>);
|
||||
// Keep top_k scores if needed.
|
||||
if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
|
||||
sorted_indices->resize(top_k);
|
||||
}
|
||||
}
|
||||
|
||||
void PaddleMultiClassNMSRotated::FastNMSRotated(
|
||||
const float* boxes, const float* scores, const int& num_boxes,
|
||||
std::vector<int>* keep_indices) {
|
||||
std::vector<std::pair<float, int>> sorted_indices;
|
||||
GetMaxScoreIndexRotated(scores, num_boxes, score_threshold, nms_top_k,
|
||||
&sorted_indices);
|
||||
// printf("nms thrd: %f, sort dim: %d\n", nms_threshold,
|
||||
// int(sorted_indices.size()));
|
||||
float adaptive_threshold = nms_threshold;
|
||||
while (sorted_indices.size() != 0) {
|
||||
const int idx = sorted_indices.front().second;
|
||||
bool keep = true;
|
||||
for (size_t k = 0; k < keep_indices->size(); ++k) {
|
||||
if (!keep) {
|
||||
break;
|
||||
}
|
||||
const int kept_idx = (*keep_indices)[k];
|
||||
float overlap =
|
||||
RboxIouSingle<float>(boxes + idx * 8, boxes + kept_idx * 8);
|
||||
|
||||
keep = overlap <= adaptive_threshold;
|
||||
}
|
||||
if (keep) {
|
||||
keep_indices->push_back(idx);
|
||||
}
|
||||
sorted_indices.erase(sorted_indices.begin());
|
||||
if (keep && nms_eta < 1.0 && adaptive_threshold > 0.5) {
|
||||
adaptive_threshold *= nms_eta;
|
||||
}
|
||||
}
|
||||
}
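FastNMSRotated is a standard greedy NMS over score-sorted rotated boxes, with an adaptive threshold controlled by nms_eta. A compact Python sketch of the greedy core, reusing rbox_iou from the previous snippet; the nms_eta decay is omitted and the default thresholds mirror NMSRotatedOption:

```python
# Greedy rotated NMS sketch mirroring FastNMSRotated (adaptive nms_eta omitted).
import numpy as np

def fast_nms_rotated(boxes, scores, score_threshold=0.1, nms_threshold=0.1,
                     nms_top_k=2000):
    """boxes: (N, 8) polygon corners, scores: (N,). Returns kept indices."""
    order = np.argsort(-scores)
    order = order[scores[order] > score_threshold][:nms_top_k]
    keep = []
    for idx in order:
        # keep the box only if it overlaps every kept box by at most nms_threshold
        if all(rbox_iou(boxes[idx], boxes[k]) <= nms_threshold for k in keep):
            keep.append(int(idx))
    return keep
```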
|
||||
|
||||
int PaddleMultiClassNMSRotated::NMSRotatedForEachSample(
|
||||
const float* boxes, const float* scores, int num_boxes, int num_classes,
|
||||
std::map<int, std::vector<int>>* keep_indices) {
|
||||
for (int i = 0; i < num_classes; ++i) {
|
||||
if (i == background_label) {
|
||||
continue;
|
||||
}
|
||||
const float* score_for_class_i = scores + i * num_boxes;
|
||||
FastNMSRotated(boxes, score_for_class_i, num_boxes, &((*keep_indices)[i]));
|
||||
}
|
||||
int num_det = 0;
|
||||
for (auto iter = keep_indices->begin(); iter != keep_indices->end(); ++iter) {
|
||||
num_det += iter->second.size();
|
||||
}
|
||||
|
||||
if (keep_top_k > -1 && num_det > keep_top_k) {
|
||||
std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
|
||||
for (const auto& it : *keep_indices) {
|
||||
int label = it.first;
|
||||
const float* current_score = scores + label * num_boxes;
|
||||
auto& label_indices = it.second;
|
||||
for (size_t j = 0; j < label_indices.size(); ++j) {
|
||||
int idx = label_indices[j];
|
||||
score_index_pairs.push_back(
|
||||
std::make_pair(current_score[idx], std::make_pair(label, idx)));
|
||||
}
|
||||
}
|
||||
|
||||
std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(),
|
||||
SortScorePairDescendRotated<std::pair<int, int>>);
|
||||
score_index_pairs.resize(keep_top_k);
|
||||
|
||||
std::map<int, std::vector<int>> new_indices;
|
||||
for (size_t j = 0; j < score_index_pairs.size(); ++j) {
|
||||
int label = score_index_pairs[j].second.first;
|
||||
int idx = score_index_pairs[j].second.second;
|
||||
new_indices[label].push_back(idx);
|
||||
}
|
||||
new_indices.swap(*keep_indices);
|
||||
num_det = keep_top_k;
|
||||
}
|
||||
return num_det;
|
||||
}
|
||||
|
||||
void PaddleMultiClassNMSRotated::Compute(
|
||||
const float* boxes_data, const float* scores_data,
|
||||
const std::vector<int64_t>& boxes_dim,
|
||||
const std::vector<int64_t>& scores_dim) {
|
||||
int score_size = scores_dim.size();
|
||||
|
||||
int64_t batch_size = scores_dim[0];
|
||||
int64_t box_dim = boxes_dim[2];
|
||||
int64_t out_dim = box_dim + 2;
|
||||
|
||||
int num_nmsed_out = 0;
|
||||
FDASSERT(score_size == 3,
|
||||
"Require rank of input scores be 3, but now it's %d.", score_size);
|
||||
FDASSERT(boxes_dim[2] == 8,
|
||||
"Require the 3-dimension of input boxes be 8, but now it's %lld.",
|
||||
box_dim);
|
||||
out_num_rois_data.resize(batch_size);
|
||||
|
||||
std::vector<std::map<int, std::vector<int>>> all_indices;
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
std::map<int, std::vector<int>> indices; // indices kept for each class
|
||||
const float* current_boxes_ptr =
|
||||
boxes_data + i * boxes_dim[1] * boxes_dim[2];
|
||||
const float* current_scores_ptr =
|
||||
scores_data + i * scores_dim[1] * scores_dim[2];
|
||||
int num = NMSRotatedForEachSample(current_boxes_ptr, current_scores_ptr,
|
||||
boxes_dim[1], scores_dim[1], &indices);
|
||||
num_nmsed_out += num;
|
||||
out_num_rois_data[i] = num;
|
||||
all_indices.emplace_back(indices);
|
||||
}
|
||||
std::vector<int64_t> out_box_dims = {num_nmsed_out, 10};
|
||||
std::vector<int64_t> out_index_dims = {num_nmsed_out, 1};
|
||||
if (num_nmsed_out == 0) {
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
out_num_rois_data[i] = 0;
|
||||
}
|
||||
return;
|
||||
}
|
||||
out_box_data.resize(num_nmsed_out * 10);
|
||||
out_index_data.resize(num_nmsed_out);
|
||||
|
||||
int count = 0;
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
const float* current_boxes_ptr =
|
||||
boxes_data + i * boxes_dim[1] * boxes_dim[2];
|
||||
const float* current_scores_ptr =
|
||||
scores_data + i * scores_dim[1] * scores_dim[2];
|
||||
for (const auto& it : all_indices[i]) {
|
||||
int label = it.first;
|
||||
const auto& indices = it.second;
|
||||
const float* current_scores_class_ptr =
|
||||
current_scores_ptr + label * scores_dim[2];
|
||||
for (size_t j = 0; j < indices.size(); ++j) {
|
||||
int start = count * 10;
|
||||
out_box_data[start] = label;
|
||||
out_box_data[start + 1] = current_scores_class_ptr[indices[j]];
|
||||
for (int k = 0; k < 8; k++) {
|
||||
out_box_data[start + 2 + k] = current_boxes_ptr[indices[j] * 8 + k];
|
||||
}
|
||||
out_index_data[count] = i * boxes_dim[1] + indices[j];
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
fastdeploy/vision/detection/ppdet/multiclass_nms_rotated.h (new file, 76 lines)
@@ -0,0 +1,76 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <map>
#include <string>
#include <vector>

namespace fastdeploy {
namespace vision {
namespace detection {
/** \brief Config for PaddleMultiClassNMSRotated
 * \param[in] background_label the value of the background label
 * \param[in] keep_top_k the value of keep_top_k
 * \param[in] nms_eta the value of nms_eta
 * \param[in] nms_threshold the IoU threshold used to suppress overlapping rotated boxes
 * \param[in] nms_top_k if there are more than max_num bboxes after NMS, only top max_num will be kept.
 * \param[in] normalized determine whether normalized coordinates are required
 * \param[in] score_threshold bbox threshold, bboxes with scores lower than it will not be considered.
 */
struct NMSRotatedOption {
  NMSRotatedOption() = default;
  int64_t background_label = -1;
  int64_t keep_top_k = -1;
  float nms_eta = 1.0;
  float nms_threshold = 0.1;
  int64_t nms_top_k = 2000;
  bool normalized = false;
  float score_threshold = 0.1;
};

struct PaddleMultiClassNMSRotated {
  int64_t background_label = -1;
  int64_t keep_top_k = -1;
  float nms_eta;
  float nms_threshold = 0.1;
  int64_t nms_top_k;
  bool normalized;
  float score_threshold;

  std::vector<int32_t> out_num_rois_data;
  std::vector<int32_t> out_index_data;
  std::vector<float> out_box_data;
  void FastNMSRotated(const float* boxes, const float* scores,
                      const int& num_boxes, std::vector<int>* keep_indices);
  int NMSRotatedForEachSample(const float* boxes, const float* scores,
                              int num_boxes, int num_classes,
                              std::map<int, std::vector<int>>* keep_indices);
  void Compute(const float* poly_boxes, const float* scores,
               const std::vector<int64_t>& boxes_dim,
               const std::vector<int64_t>& scores_dim);

  void SetNMSRotatedOption(const struct NMSRotatedOption& nms_rotated_option) {
    background_label = nms_rotated_option.background_label;
    keep_top_k = nms_rotated_option.keep_top_k;
    nms_eta = nms_rotated_option.nms_eta;
    nms_threshold = nms_rotated_option.nms_threshold;
    nms_top_k = nms_rotated_option.nms_top_k;
    normalized = nms_rotated_option.normalized;
    score_threshold = nms_rotated_option.score_threshold;
  }
};
}  // namespace detection
}  // namespace vision
}  // namespace fastdeploy
@@ -264,6 +264,57 @@ bool PaddleDetPostprocessor::ProcessSolov2(
  return true;
}

bool PaddleDetPostprocessor::ProcessPPYOLOER(
    const std::vector<FDTensor>& tensors,
    std::vector<DetectionResult>* results) {
  if (tensors.size() != 2) {
    FDERROR << "The size of tensors for PPYOLOER must be 2." << std::endl;
    return false;
  }

  int boxes_index = 0;
  int scores_index = 1;
  multi_class_nms_rotated_.Compute(
      static_cast<const float*>(tensors[boxes_index].Data()),
      static_cast<const float*>(tensors[scores_index].Data()),
      tensors[boxes_index].shape, tensors[scores_index].shape);
  auto num_boxes = multi_class_nms_rotated_.out_num_rois_data;
  auto box_data =
      static_cast<const float*>(multi_class_nms_rotated_.out_box_data.data());

  // Get boxes for each input image
  results->resize(num_boxes.size());
  int offset = 0;
  for (size_t i = 0; i < num_boxes.size(); ++i) {
    const float* ptr = box_data + offset;
    (*results)[i].Reserve(num_boxes[i]);
    for (size_t j = 0; j < num_boxes[i]; ++j) {
      (*results)[i].label_ids.push_back(
          static_cast<int32_t>(round(ptr[j * 10])));
      (*results)[i].scores.push_back(ptr[j * 10 + 1]);
      (*results)[i].rotated_boxes.push_back(std::array<float, 8>(
          {ptr[j * 10 + 2], ptr[j * 10 + 3], ptr[j * 10 + 4], ptr[j * 10 + 5],
           ptr[j * 10 + 6], ptr[j * 10 + 7], ptr[j * 10 + 8],
           ptr[j * 10 + 9]}));
    }
    offset += (num_boxes[i] * 10);
  }
  // Scale the rotated boxes back to the original image size. The stored
  // scale factor is [scale_h, scale_w]; even coordinate indices are x values
  // and odd indices are y values, so the divisor is chosen by j.
  if (GetScaleFactor()[0] != 0) {
    for (auto& result : *results) {
      for (size_t i = 0; i < result.rotated_boxes.size(); i++) {
        for (int j = 0; j < 8; j++) {
          auto scale = j % 2 == 0 ? GetScaleFactor()[1] : GetScaleFactor()[0];
          result.rotated_boxes[i][j] /= static_cast<float>(scale);
        }
      }
    }
  }
  return true;
}
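ProcessPPYOLOER reads the packed buffer that PaddleMultiClassNMSRotated fills: every kept detection occupies 10 consecutive floats — class id, score, then the four (x, y) vertices of the rotated box — and out_num_rois_data carries the per-image counts. A small NumPy sketch of the same unpacking, with stand-in buffer names (not FastDeploy API):

import numpy as np

def unpack_rotated_boxes(flat_box_data, num_rois_per_image):
    # flat_box_data is laid out per kept box as
    # [label, score, x1, y1, x2, y2, x3, y3, x4, y4].
    results, offset = [], 0
    for n in num_rois_per_image:
        chunk = np.asarray(flat_box_data[offset:offset + n * 10]).reshape(n, 10)
        results.append({
            "label_ids": chunk[:, 0].astype(np.int32),
            "scores": chunk[:, 1],
            "rotated_boxes": chunk[:, 2:],  # 4 vertices = 8 floats per box
        })
        offset += n * 10
    return results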
bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors,
                                 std::vector<DetectionResult>* results) {
  if (arch_ == "SOLOv2") {
@@ -272,6 +323,10 @@ bool PaddleDetPostprocessor::Run(const std::vector<FDTensor>& tensors,
    // The fourth output of solov2 is mask
    return ProcessMask(tensors[3], results);
  } else {
    if (tensors[0].Shape().size() == 3 && tensors[0].Shape()[2] == 8) {  // PPYOLOER
      return ProcessPPYOLOER(tensors, results);
    }

    // Do process according to whether NMS exists.
    if (with_nms_) {
      if (!ProcessWithNMS(tensors, results)) {
@@ -16,6 +16,7 @@
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
#include "fastdeploy/vision/detection/ppdet/multiclass_nms_rotated.h"

namespace fastdeploy {
namespace vision {
@@ -28,6 +29,7 @@ class FASTDEPLOY_DECL PaddleDetPostprocessor {
    // There may be no NMS config in the yaml file,
    // so we need to give an initial value to multi_class_nms_.
    multi_class_nms_.SetNMSOption(NMSOption());
    multi_class_nms_rotated_.SetNMSRotatedOption(NMSRotatedOption());
  }

  /** \brief Create a postprocessor instance for PaddleDet series models
@@ -40,6 +42,7 @@ class FASTDEPLOY_DECL PaddleDetPostprocessor {
    // There may be no NMS config in the yaml file,
    // so we need to give an initial value to multi_class_nms_.
    multi_class_nms_.SetNMSOption(NMSOption());
    multi_class_nms_rotated_.SetNMSRotatedOption(NMSRotatedOption());
  }

  /** \brief Process the result of runtime and fill to DetectionResult structure
@@ -55,6 +58,12 @@ class FASTDEPLOY_DECL PaddleDetPostprocessor {
  /// only available for those models exported without box decoding and nms.
  void ApplyNMS() { with_nms_ = false; }

  /// If you do not want to modify the Yaml configuration file,
  /// you can use this function to set rotated NMS parameters.
  void SetNMSRotatedOption(const NMSRotatedOption& option) {
    multi_class_nms_rotated_.SetNMSRotatedOption(option);
  }

  /// If you do not want to modify the Yaml configuration file,
  /// you can use this function to set NMS parameters.
  void SetNMSOption(const NMSOption& option) {
@@ -79,6 +88,8 @@ class FASTDEPLOY_DECL PaddleDetPostprocessor {

  PaddleMultiClassNMS multi_class_nms_{};

  PaddleMultiClassNMSRotated multi_class_nms_rotated_{};

  // Process for General tensor without nms.
  bool ProcessWithoutNMS(const std::vector<FDTensor>& tensors,
                         std::vector<DetectionResult>* results);
@@ -91,6 +102,10 @@ class FASTDEPLOY_DECL PaddleDetPostprocessor {
  bool ProcessSolov2(const std::vector<FDTensor>& tensors,
                     std::vector<DetectionResult>* results);

  // Process PPYOLOER
  bool ProcessPPYOLOER(const std::vector<FDTensor>& tensors,
                       std::vector<DetectionResult>* results);

  // Process mask tensor for MaskRCNN
  bool ProcessMask(const FDTensor& tensor,
                   std::vector<DetectionResult>* results);
@@ -78,6 +78,11 @@ void BindPPDet(pybind11::module& m) {
              vision::detection::NMSOption option) {
             self.SetNMSOption(option);
           })
      .def("set_nms_rotated_option",
           [](vision::detection::PaddleDetPostprocessor& self,
              vision::detection::NMSRotatedOption option) {
             self.SetNMSRotatedOption(option);
           })
      .def("apply_nms",
           [](vision::detection::PaddleDetPostprocessor& self) {
             self.ApplyNMS();
@@ -233,5 +238,26 @@ void BindPPDet(pybind11::module& m) {
      m, "SOLOv2")
      .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
                          ModelFormat>());

  pybind11::class_<vision::detection::PPYOLOER, vision::detection::PPDetBase>(
      m, "PPYOLOER")
      .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
                          ModelFormat>());

  pybind11::class_<vision::detection::NMSRotatedOption>(m, "NMSRotatedOption")
      .def(pybind11::init())
      .def_readwrite("background_label",
                     &vision::detection::NMSRotatedOption::background_label)
      .def_readwrite("keep_top_k",
                     &vision::detection::NMSRotatedOption::keep_top_k)
      .def_readwrite("nms_eta", &vision::detection::NMSRotatedOption::nms_eta)
      .def_readwrite("nms_threshold",
                     &vision::detection::NMSRotatedOption::nms_threshold)
      .def_readwrite("nms_top_k",
                     &vision::detection::NMSRotatedOption::nms_top_k)
      .def_readwrite("normalized",
                     &vision::detection::NMSRotatedOption::normalized)
      .def_readwrite("score_threshold",
                     &vision::detection::NMSRotatedOption::score_threshold);
}
} // namespace fastdeploy
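The new PPYOLOER binding mirrors the other PPDet model classes: it is constructed from the model, params, and config files plus a RuntimeOption and ModelFormat, and is normally reached through the Python wrapper added near the end of this commit (a fuller end-to-end sketch follows that wrapper). A minimal, hedged construction sketch using the paths from the wrapper's docstring; the module path of the high-level class is an assumption:

import fastdeploy as fd

# Hedged sketch: default RuntimeOption and ModelFormat.PADDLE are used.
model = fd.vision.detection.PPYOLOER("ppyoloe_r/model.pdmodel",
                                     "ppyoloe_r/model.pdiparams",
                                     "ppyoloe_r/infer_cfg.yml")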
@@ -84,13 +84,14 @@ void BindVision(pybind11::module& m) {
      .def(pybind11::init())
      .def_readwrite("boxes", &vision::DetectionResult::boxes)
      .def_readwrite("scores", &vision::DetectionResult::scores)
      .def_readwrite("rotated_boxes", &vision::DetectionResult::rotated_boxes)
      .def_readwrite("label_ids", &vision::DetectionResult::label_ids)
      .def_readwrite("masks", &vision::DetectionResult::masks)
      .def_readwrite("contain_masks", &vision::DetectionResult::contain_masks)
      .def(pybind11::pickle(
          [](const vision::DetectionResult& d) {
            return pybind11::make_tuple(d.boxes, d.scores, d.label_ids, d.masks,
                                        d.contain_masks);
            return pybind11::make_tuple(d.boxes, d.scores, d.rotated_boxes,
                                        d.label_ids, d.masks, d.contain_masks);
          },
          [](pybind11::tuple t) {
            if (t.size() != 5)
@@ -99,6 +100,7 @@ void BindVision(pybind11::module& m) {

            vision::DetectionResult d;
            d.boxes = t[0].cast<std::vector<std::array<float, 4>>>();
            d.rotated_boxes = t[0].cast<std::vector<std::array<float, 8>>>();
            d.scores = t[1].cast<std::vector<float>>();
            d.label_ids = t[2].cast<std::vector<int32_t>>();
            d.masks = t[3].cast<std::vector<vision::Mask>>();
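With rotated_boxes now bound on DetectionResult, Python code can read the eight flattened vertex coordinates of each rotated box alongside the usual labels and scores. A small, hedged sketch; the result argument stands for any DetectionResult returned by a PP-YOLOE-R model:

def print_rotated_detections(result):
    # 'result' is a DetectionResult; rotated_boxes holds 8 floats per detection,
    # interleaved as x1, y1, ..., x4, y4.
    for label, score, poly in zip(result.label_ids, result.scores,
                                  result.rotated_boxes):
        xs, ys = poly[0::2], poly[1::2]
        print(f"class {label} ({score:.2f}): vertices {list(zip(xs, ys))}")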
@@ -22,7 +22,7 @@ namespace vision {

cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
                     float score_threshold, int line_size, float font_size) {
  if (result.boxes.empty()) {
  if (result.boxes.empty() && result.rotated_boxes.empty()) {
    return im;
  }
  if (result.contain_masks) {
@@ -38,6 +38,45 @@ cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
  int h = im.rows;
  int w = im.cols;
  auto vis_im = im.clone();
  for (size_t i = 0; i < result.rotated_boxes.size(); ++i) {
    if (result.scores[i] < score_threshold) {
      continue;
    }

    int c0 = color_map[3 * result.label_ids[i] + 0];
    int c1 = color_map[3 * result.label_ids[i] + 1];
    int c2 = color_map[3 * result.label_ids[i] + 2];
    cv::Scalar rect_color = cv::Scalar(c0, c1, c2);
    std::string id = std::to_string(result.label_ids[i]);
    std::string score = std::to_string(result.scores[i]);
    if (score.size() > 4) {
      score = score.substr(0, 4);
    }
    std::string text = id + ", " + score;
    int font = cv::FONT_HERSHEY_SIMPLEX;
    cv::Size text_size = cv::getTextSize(text, font, font_size, 1, nullptr);

    for (int j = 0; j < 4; j++) {
      auto start = cv::Point(
          static_cast<int>(round(result.rotated_boxes[i][2 * j])),
          static_cast<int>(round(result.rotated_boxes[i][2 * j + 1])));

      cv::Point end;
      if (j != 3) {
        end = cv::Point(
            static_cast<int>(round(result.rotated_boxes[i][2 * (j + 1)])),
            static_cast<int>(round(result.rotated_boxes[i][2 * (j + 1) + 1])));
      } else {
        end = cv::Point(static_cast<int>(round(result.rotated_boxes[i][0])),
                        static_cast<int>(round(result.rotated_boxes[i][1])));
        cv::putText(vis_im, text, end, font, font_size,
                    cv::Scalar(255, 255, 255), 1);
      }
      cv::line(vis_im, start, end, cv::Scalar(255, 255, 255), 3, cv::LINE_AA,
               0);
    }
  }

  for (size_t i = 0; i < result.boxes.size(); ++i) {
    if (result.scores[i] < score_threshold) {
      continue;
@@ -125,6 +164,44 @@ cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result,
  int h = im.rows;
  int w = im.cols;
  auto vis_im = im.clone();
  for (size_t i = 0; i < result.rotated_boxes.size(); ++i) {
    if (result.scores[i] < score_threshold) {
      continue;
    }

    int c0 = color_map[3 * result.label_ids[i] + 0];
    int c1 = color_map[3 * result.label_ids[i] + 1];
    int c2 = color_map[3 * result.label_ids[i] + 2];
    cv::Scalar rect_color = cv::Scalar(c0, c1, c2);
    std::string id = std::to_string(result.label_ids[i]);
    std::string score = std::to_string(result.scores[i]);
    if (score.size() > 4) {
      score = score.substr(0, 4);
    }
    std::string text = id + ", " + score;
    int font = cv::FONT_HERSHEY_SIMPLEX;
    cv::Size text_size = cv::getTextSize(text, font, font_size, 1, nullptr);

    for (int j = 0; j < 4; j++) {
      auto start = cv::Point(
          static_cast<int>(round(result.rotated_boxes[i][2 * j])),
          static_cast<int>(round(result.rotated_boxes[i][2 * j + 1])));

      cv::Point end;
      if (j != 3) {
        end = cv::Point(
            static_cast<int>(round(result.rotated_boxes[i][2 * (j + 1)])),
            static_cast<int>(round(result.rotated_boxes[i][2 * (j + 1) + 1])));
      } else {
        end = cv::Point(static_cast<int>(round(result.rotated_boxes[i][0])),
                        static_cast<int>(round(result.rotated_boxes[i][1])));
        cv::putText(vis_im, text, end, font, font_size,
                    cv::Scalar(255, 255, 255), 1);
      }
      cv::line(vis_im, start, end, cv::Scalar(255, 255, 255), 3, cv::LINE_AA,
               0);
    }
  }
  for (size_t i = 0; i < result.boxes.size(); ++i) {
    if (result.scores[i] < score_threshold) {
      continue;
@@ -50,6 +50,15 @@ class NMSOption:
        return self.nms_option.background_label


class NMSRotatedOption:
    def __init__(self):
        self.nms_rotated_option = C.vision.detection.NMSRotatedOption()

    @property
    def background_label(self):
        return self.nms_rotated_option.background_label


class PaddleDetPostprocessor:
    def __init__(self):
        """Create a postprocessor for PaddleDetection Model
@@ -75,6 +84,14 @@ class PaddleDetPostprocessor:
            nms_option = NMSOption()
        self._postprocessor.set_nms_option(nms_option.nms_option)

    def set_nms_rotated_option(self, nms_rotated_option=None):
        """This function will enable box decoding and rotated NMS in the postprocess step.
        """
        if nms_rotated_option is None:
            nms_rotated_option = NMSRotatedOption()
        self._postprocessor.set_nms_rotated_option(
            nms_rotated_option.nms_rotated_option)
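Calling the method with no argument applies the NMSRotatedOption defaults (score_threshold 0.1, nms_threshold 0.1, nms_top_k 2000), so enabling rotated NMS can be a one-liner; a custom option is only needed when those values must be tuned. A hedged helper sketch; it assumes the postprocessor argument is a PaddleDetPostprocessor taken from a PP-YOLOE-R model, and NMSRotatedOption is the wrapper defined above:

def enable_rotated_nms(postprocessor, iou_threshold=None):
    # Assumption: 'postprocessor' is a PaddleDetPostprocessor instance.
    if iou_threshold is None:
        postprocessor.set_nms_rotated_option()   # use the defaults
        return
    option = NMSRotatedOption()
    option.nms_rotated_option.nms_threshold = iou_threshold
    postprocessor.set_nms_rotated_option(option)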
class PPYOLOE(FastDeployModel):
    def __init__(self,
@@ -781,3 +798,40 @@ class GFL(PPYOLOE):
                                             config_file, self._runtime_option,
                                             model_format)
        assert self.initialized, "GFL model initialize failed."


class PPYOLOER(PPYOLOE):
    def __init__(self,
                 model_file,
                 params_file,
                 config_file,
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
"""Load a PPYOLOER model exported by PaddleDetection.
|
||||
|
||||
:param model_file: (str)Path of model file, e.g ppyoloe_r/model.pdmodel
|
||||
:param params_file: (str)Path of parameters file, e.g ppyoloe_r/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
|
||||
:param config_file: (str)Path of configuration file for deployment, e.g ppyoloe_r/infer_cfg.yml
|
||||
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
|
||||
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model
|
||||
"""
|
||||
|
||||
super(PPYOLOE, self).__init__(runtime_option)
|
||||
|
||||
self._model = C.vision.detection.PPYOLOER(
|
||||
model_file, params_file, config_file, self._runtime_option,
|
||||
model_format)
|
||||
assert self.initialized, "PicoDet model initialize failed."

    def clone(self):
        """Clone PPYOLOER object

        :return: a new PPYOLOER object
        """

        class PPYOLOERClone(PPYOLOER):
            def __init__(self, model):
                self._model = model

        clone_model = PPYOLOERClone(self._model.clone())
        return clone_model
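A hedged end-to-end sketch of the new class, mirroring the docstring above: load the exported model, run prediction on an image, and draw the rotated boxes. It assumes the usual FastDeploy Python helpers (predict on the model and fastdeploy.vision.vis_detection) behave for PPYOLOER as they do for the other PPYOLOE variants; the model paths come from the docstring and the image filename is a placeholder:

import cv2
import fastdeploy as fd

model = fd.vision.detection.PPYOLOER("ppyoloe_r/model.pdmodel",
                                     "ppyoloe_r/model.pdiparams",
                                     "ppyoloe_r/infer_cfg.yml")
im = cv2.imread("test.jpg")                     # placeholder input image
result = model.predict(im)                      # DetectionResult with rotated_boxes
vis = fd.vision.vis_detection(im, result, score_threshold=0.3)
cv2.imwrite("ppyoloe_r_result.jpg", vis)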