[feature][vision] Add YOLOv7 End2End model with TRT NMS (#157)

* [feature][vision] Add YOLOv7 End2End model with TRT NMS

* [docs] update yolov7end2end_trt examples docs

Co-authored-by: Jason <jiangjiajun@baidu.com>
Author: DefTruth
Date: 2022-08-30 15:02:48 +08:00
Committed by: GitHub
Parent: 30bb233db8
Commit: 3c1330e896
14 changed files with 902 additions and 2 deletions


@@ -27,7 +27,9 @@ bool FastDeployModel::InitRuntime() {
   }
   if (runtime_option.backend != Backend::UNKNOWN) {
     if (!IsBackendAvailable(runtime_option.backend)) {
-      FDERROR << Str(runtime_option.backend) << " is not compiled with current FastDeploy library." << std::endl;
+      FDERROR << Str(runtime_option.backend)
+              << " is not compiled with current FastDeploy library."
+              << std::endl;
       return false;
     }
@@ -70,7 +72,7 @@ bool FastDeployModel::InitRuntime() {
     FDWARNING << "FastDeploy will choose " << Str(valid_gpu_backends[0])
               << " for model inference." << std::endl;
   } else {
-    FDASSERT(valid_gpu_backends.size() > 0,
+    FDASSERT(valid_cpu_backends.size() > 0,
             "There's no valid cpu backend for %s.", ModelName().c_str());
     FDWARNING << "FastDeploy will choose " << Str(valid_cpu_backends[0])
               << " for model inference." << std::endl;


@@ -23,6 +23,7 @@
 #include "fastdeploy/vision/detection/contrib/yolov5lite.h"
 #include "fastdeploy/vision/detection/contrib/yolov6.h"
 #include "fastdeploy/vision/detection/contrib/yolov7.h"
+#include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
 #include "fastdeploy/vision/detection/contrib/yolov7end2end_ort.h"
 #include "fastdeploy/vision/detection/contrib/yolox.h"
 #include "fastdeploy/vision/detection/ppdet/model.h"


@@ -0,0 +1,267 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace detection {
void YOLOv7End2EndTRT::LetterBox(Mat* mat, const std::vector<int>& size,
const std::vector<float>& color, bool _auto,
bool scale_fill, bool scale_up, int stride) {
float scale =
std::min(size[1] * 1.0 / mat->Height(), size[0] * 1.0 / mat->Width());
if (!scale_up) {
scale = std::min(scale, 1.0f);
}
int resize_h = int(round(mat->Height() * scale));
int resize_w = int(round(mat->Width() * scale));
int pad_w = size[0] - resize_w;
int pad_h = size[1] - resize_h;
if (_auto) {
pad_h = pad_h % stride;
pad_w = pad_w % stride;
} else if (scale_fill) {
pad_h = 0;
pad_w = 0;
resize_h = size[1];
resize_w = size[0];
}
if (resize_h != mat->Height() || resize_w != mat->Width()) {
Resize::Run(mat, resize_w, resize_h);
}
if (pad_h > 0 || pad_w > 0) {
float half_h = pad_h * 1.0 / 2;
int top = int(round(half_h - 0.1));
int bottom = int(round(half_h + 0.1));
float half_w = pad_w * 1.0 / 2;
int left = int(round(half_w - 0.1));
int right = int(round(half_w + 0.1));
Pad::Run(mat, top, bottom, left, right, color);
}
}
YOLOv7End2EndTRT::YOLOv7End2EndTRT(const std::string& model_file,
const std::string& params_file,
const RuntimeOption& custom_option,
const Frontend& model_format) {
if (model_format == Frontend::ONNX) {
valid_cpu_backends = {}; // NO CPU
valid_gpu_backends = {Backend::TRT}; // NO ORT
} else {
valid_cpu_backends = {Backend::PDINFER};
valid_gpu_backends = {Backend::PDINFER};
}
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
if (runtime_option.device != Device::GPU) {
FDWARNING << Str(runtime_option.device)
<< " is not support for YOLOv7End2EndTRT,"
<< "will fallback to Device::GPU." << std::endl;
runtime_option.device = Device::GPU;
}
if (runtime_option.backend != Backend::UNKNOWN) {
if (runtime_option.backend != Backend::TRT) {
FDWARNING << Str(runtime_option.backend)
<< " is not support for YOLOv7End2EndTRT,"
<< "will fallback to Backend::TRT." << std::endl;
runtime_option.backend = Backend::TRT;
}
}
initialized = Initialize();
}
bool YOLOv7End2EndTRT::Initialize() {
// parameters for preprocess
size = {640, 640};
padding_value = {114.0, 114.0, 114.0};
is_mini_pad = false;
is_no_pad = false;
is_scale_up = false;
stride = 32;
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
// Check whether the input shape is dynamic after the runtime is initialized.
// Note that we need to force is_mini_pad to 'false' to keep a static
// shape after padding (LetterBox) when the input shape is not dynamic.
is_dynamic_input_ = false;
auto shape = InputInfoOfRuntime(0).shape;
for (int i = 0; i < shape.size(); ++i) {
// if height or width is dynamic
if (i >= 2 && shape[i] <= 0) {
is_dynamic_input_ = true;
break;
}
}
if (!is_dynamic_input_) {
is_mini_pad = false;
}
return true;
}
bool YOLOv7End2EndTRT::Preprocess(
Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info) {
float ratio = std::min(size[1] * 1.0f / static_cast<float>(mat->Height()),
size[0] * 1.0f / static_cast<float>(mat->Width()));
if (ratio != 1.0) {
int interp = cv::INTER_AREA;
if (ratio > 1.0) {
interp = cv::INTER_LINEAR;
}
int resize_h = int(mat->Height() * ratio);
int resize_w = int(mat->Width() * ratio);
Resize::Run(mat, resize_w, resize_h, -1, -1, interp);
}
YOLOv7End2EndTRT::LetterBox(mat, size, padding_value, is_mini_pad, is_no_pad,
is_scale_up, stride);
BGR2RGB::Run(mat);
std::vector<float> alpha = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f};
std::vector<float> beta = {0.0f, 0.0f, 0.0f};
Convert::Run(mat, alpha, beta);
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1);  // reshape to n, c, h, w
return true;
}
bool YOLOv7End2EndTRT::Postprocess(
std::vector<FDTensor>& infer_results, DetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
float conf_threshold) {
FDASSERT(infer_results.size() == 4, "Output tensor size must be 4.");
FDTensor& num_tensor = infer_results.at(0); // INT32
FDTensor& boxes_tensor = infer_results.at(1); // FLOAT
FDTensor& scores_tensor = infer_results.at(2); // FLOAT
FDTensor& classes_tensor = infer_results.at(3); // INT32
FDASSERT(num_tensor.dtype == FDDataType::INT32,
"The dtype of num_dets must be INT32.");
FDASSERT(boxes_tensor.dtype == FDDataType::FP32,
"The dtype of det_boxes_tensor must be FP32.");
FDASSERT(scores_tensor.dtype == FDDataType::FP32,
"The dtype of det_scores_tensor must be FP32.");
FDASSERT(classes_tensor.dtype == FDDataType::INT32,
"The dtype of det_classes_tensor must be INT32.");
FDASSERT(num_tensor.shape[0] == 1, "Only support batch=1 now.");
// post-process for end2end yolov7 after trt nms.
float* boxes_data = static_cast<float*>(boxes_tensor.Data()); // (1,100,4)
float* scores_data = static_cast<float*>(scores_tensor.Data()); // (1,100)
int32_t* classes_data =
static_cast<int32_t*>(classes_tensor.Data()); // (1,100)
int32_t num_dets_after_trt_nms = static_cast<int32_t*>(num_tensor.Data())[0];
if (num_dets_after_trt_nms == 0) {
return true;
}
result->Clear();
result->Reserve(num_dets_after_trt_nms);
for (size_t i = 0; i < num_dets_after_trt_nms; ++i) {
float confidence = scores_data[i];
if (confidence <= conf_threshold) {
continue;
}
int32_t label_id = classes_data[i];
float x1 = boxes_data[(i * 4) + 0];
float y1 = boxes_data[(i * 4) + 1];
float x2 = boxes_data[(i * 4) + 2];
float y2 = boxes_data[(i * 4) + 3];
result->boxes.emplace_back(std::array<float, 4>{x1, y1, x2, y2});
result->label_ids.push_back(label_id);
result->scores.push_back(confidence);
}
if (result->boxes.size() == 0) {
return true;
}
// scale the boxes to the origin image shape
auto iter_out = im_info.find("output_shape");
auto iter_ipt = im_info.find("input_shape");
FDASSERT(iter_out != im_info.end() && iter_ipt != im_info.end(),
"Cannot find input_shape or output_shape from im_info.");
float out_h = iter_out->second[0];
float out_w = iter_out->second[1];
float ipt_h = iter_ipt->second[0];
float ipt_w = iter_ipt->second[1];
float scale = std::min(out_h / ipt_h, out_w / ipt_w);
float pad_h = (out_h - ipt_h * scale) / 2.0f;
float pad_w = (out_w - ipt_w * scale) / 2.0f;
if (is_mini_pad) {
pad_h = static_cast<float>(static_cast<int>(pad_h) % stride);
pad_w = static_cast<float>(static_cast<int>(pad_w) % stride);
}
for (size_t i = 0; i < result->boxes.size(); ++i) {
int32_t label_id = (result->label_ids)[i];
result->boxes[i][0] = std::max((result->boxes[i][0] - pad_w) / scale, 0.0f);
result->boxes[i][1] = std::max((result->boxes[i][1] - pad_h) / scale, 0.0f);
result->boxes[i][2] = std::max((result->boxes[i][2] - pad_w) / scale, 0.0f);
result->boxes[i][3] = std::max((result->boxes[i][3] - pad_h) / scale, 0.0f);
result->boxes[i][0] = std::min(result->boxes[i][0], ipt_w - 1.0f);
result->boxes[i][1] = std::min(result->boxes[i][1], ipt_h - 1.0f);
result->boxes[i][2] = std::min(result->boxes[i][2], ipt_w - 1.0f);
result->boxes[i][3] = std::min(result->boxes[i][3], ipt_h - 1.0f);
}
return true;
}
bool YOLOv7End2EndTRT::Predict(cv::Mat* im, DetectionResult* result,
float conf_threshold) {
Mat mat(*im);
std::vector<FDTensor> input_tensors(1);
std::map<std::string, std::array<float, 2>> im_info;
// Record the shape of image and the shape of preprocessed image
im_info["input_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};
im_info["output_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
input_tensors[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
return false;
}
if (!Postprocess(output_tensors, result, im_info, conf_threshold)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}
return true;
}
} // namespace detection
} // namespace vision
} // namespace fastdeploy


@@ -0,0 +1,93 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace detection {
class FASTDEPLOY_DECL YOLOv7End2EndTRT : public FastDeployModel {
public:
YOLOv7End2EndTRT(const std::string& model_file,
const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX);
// Defines the model name
virtual std::string ModelName() const { return "yolov7end2end_trt"; }
// Model prediction interface, i.e. the interface called by users
// im: input data; for CV models this is currently a cv::Mat
// result: the output structure holding the model prediction
// conf_threshold: a post-processing parameter
virtual bool Predict(cv::Mat* im, DetectionResult* result,
float conf_threshold = 0.25);
// The following are parameters used during prediction, mainly for pre/post-processing.
// After creating the model, users can modify them according to the model's
// requirements and their own needs.
// tuple of (width, height)
std::vector<int> size;
// padding value; its size should match the number of channels
std::vector<float> padding_value;
// only pad to the minimum rectangle whose height and width are multiples of stride
bool is_mini_pad;
// when is_mini_pad = false and is_no_pad = true, the image is resized to
// the target size without padding
bool is_no_pad;
// if is_scale_up is false, the input image can only be scaled down; the
// maximum resize scale cannot exceed 1.0
bool is_scale_up;
// padding stride, for is_mini_pad
int stride;
private:
// Initialization, including backend setup and other steps required for inference
bool Initialize();
// Preprocess the input image
// Mat is the data structure defined by FastDeploy
// FDTensor is the preprocessed tensor passed to the backend for inference
// im_info stores data from preprocessing that is needed during postprocessing
bool Preprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info);
// Postprocess the backend inference output and return it to the user
// infer_results: output tensors from the backend
// result: the model prediction result
// im_info: information recorded during preprocessing, used to restore the boxes
// conf_threshold: confidence threshold used to filter boxes during postprocessing
bool Postprocess(std::vector<FDTensor>& infer_results,
DetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
float conf_threshold);
// Apply LetterBox to the image
// mat: the original input image
// size: the image size expected by the model
void LetterBox(Mat* mat, const std::vector<int>& size,
const std::vector<float>& color, bool _auto,
bool scale_fill = false, bool scale_up = true,
int stride = 32);
bool is_dynamic_input_;
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy


@@ -0,0 +1,41 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindYOLOv7End2EndTRT(pybind11::module& m) {
pybind11::class_<vision::detection::YOLOv7End2EndTRT, FastDeployModel>(
m, "YOLOv7End2EndTRT")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::detection::YOLOv7End2EndTRT& self, pybind11::array& data,
float conf_threshold) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res, conf_threshold);
return res;
})
.def_readwrite("size", &vision::detection::YOLOv7End2EndTRT::size)
.def_readwrite("padding_value",
&vision::detection::YOLOv7End2EndTRT::padding_value)
.def_readwrite("is_mini_pad",
&vision::detection::YOLOv7End2EndTRT::is_mini_pad)
.def_readwrite("is_no_pad",
&vision::detection::YOLOv7End2EndTRT::is_no_pad)
.def_readwrite("is_scale_up",
&vision::detection::YOLOv7End2EndTRT::is_scale_up)
.def_readwrite("stride", &vision::detection::YOLOv7End2EndTRT::stride);
}
} // namespace fastdeploy


@@ -25,6 +25,7 @@ void BindYOLOv5(pybind11::module& m);
 void BindYOLOX(pybind11::module& m);
 void BindNanoDetPlus(pybind11::module& m);
 void BindPPDet(pybind11::module& m);
+void BindYOLOv7End2EndTRT(pybind11::module& m);
 void BindYOLOv7End2EndORT(pybind11::module& m);

 void BindDetection(pybind11::module& m) {
@@ -39,6 +40,7 @@ void BindDetection(pybind11::module& m) {
   BindYOLOv5(detection_module);
   BindYOLOX(detection_module);
   BindNanoDetPlus(detection_module);
+  BindYOLOv7End2EndTRT(detection_module);
   BindYOLOv7End2EndORT(detection_module);
 }
 }  // namespace fastdeploy


@@ -0,0 +1,43 @@
# Preparing YOLOv7End2EndTRT Models for Deployment
The YOLOv7End2EndTRT deployment is based on the [YOLOv7](https://github.com/WongKinYiu/yolov7/tree/v0.1) v0.1 branch and its [COCO pretrained models](https://github.com/WongKinYiu/yolov7/releases/tag/v0.1). Note that YOLOv7End2EndTRT is only for inference of End2End models exported from YOLOv7 with the [TRT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L111) plugin. For models exported without NMS, use the YOLOv7 class; for End2End models exported with [ORT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L87), use YOLOv7End2EndORT.
- 1) The *.pt files provided by the [official repository](https://github.com/WongKinYiu/yolov7/releases/tag/v0.1) can be deployed after the [Export the ONNX Model](#export-the-onnx-model) step; *.trt and *.pose models are not supported.
- 2) For YOLOv7 models trained on your own data, follow [Export the ONNX Model](#export-the-onnx-model) and then the [Detailed Deployment Documents](#detailed-deployment-documents) to complete the deployment.
## Export the ONNX Model
```bash
# Download the yolov7 model weights
wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt
# Export an ONNX file with TRT_NMS (tip: use the YOLOv7 release v0.1 code)
python export.py --weights yolov7.pt --grid --end2end --simplify --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640
# The commands for other models are similar; replace yolov7.pt with yolov7x.pt, yolov7-d6.pt, yolov7-w6.pt ...
# YOLOv7End2EndTRT only needs the ONNX file; no extra .trt conversion is required, the engine is built automatically at inference time
```
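Once exported, the ONNX file can be loaded directly and FastDeploy builds the TensorRT engine when the model is created. Below is a minimal Python sketch using the runtime options shown in the examples of this PR; the file name `yolov7-end2end-trt-nms.onnx` is a placeholder and a CUDA-capable GPU with TensorRT support is assumed.
```python
import fastdeploy as fd

# Placeholder file name; replace with your exported End2End ONNX model
model_file = "yolov7-end2end-trt-nms.onnx"

# Use the GPU with the TensorRT backend; the engine is built from the ONNX
# file at load time, so no separate .trt file is required
option = fd.RuntimeOption()
option.use_gpu()
option.use_trt_backend()
option.set_trt_input_shape("images", [1, 3, 640, 640])

model = fd.vision.detection.YOLOv7End2EndTRT(model_file, runtime_option=option)
```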
## Download Pretrained ONNX Models
For developers' convenience, the exported models listed below are provided for direct download and use. (The accuracy figures in the table come from the official repository.)
| Model | Size | Accuracy |
|:---------------------------------------------------------------- |:----- |:----- |
| [yolov7-end2end-trt-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-end2end-trt-nms.onnx) | 141MB | 51.4% |
| [yolov7x-end2end-trt-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7x-end2end-trt-nms.onnx) | 273MB | 53.1% |
| [yolov7-w6-end2end-trt-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-w6-end2end-trt-nms.onnx) | 269MB | 54.9% |
| [yolov7-e6-end2end-trt-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-e6-end2end-trt-nms.onnx) | 372MB | 56.0% |
| [yolov7-d6-end2end-trt-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-d6-end2end-trt-nms.onnx) | 511MB | 56.6% |
| [yolov7-e6e-end2end-trt-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-e6e-end2end-trt-nms.onnx) | 579MB | 56.8% |
## Detailed Deployment Documents
- [Python deployment](python)
- [C++ deployment](cpp)
## Version Notes
- This documentation and code are based on [YOLOv7 0.1](https://github.com/WongKinYiu/yolov7/tree/v0.1)


@@ -0,0 +1,14 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.12)
# Path of the downloaded and extracted FastDeploy SDK
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# Add FastDeploy header search paths
include_directories(${FASTDEPLOY_INCS})
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
# Link the FastDeploy libraries
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})


@@ -0,0 +1,88 @@
# YOLOv7End2EndTRT C++ Deployment Example
This directory provides `infer.cc`, an example of TensorRT-accelerated deployment on GPU. This class only supports the TensorRT backend.
Before deployment, confirm the following two steps:
- 1. The hardware and software environment meets the requirements; see [FastDeploy environment requirements](../../../../../docs/the%20software%20and%20hardware%20requirements.md)
- 2. Download the prebuilt deployment library and sample code for your environment; see [FastDeploy prebuilt libraries](../../../../../docs/quick_start)
Taking inference on Linux as an example, run the following commands in this directory to build and test:
```bash
mkdir build
cd build
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-gpu-0.2.0.tgz
tar xvf fastdeploy-linux-x64-gpu-0.2.0.tgz
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-gpu-0.2.0
make -j
# If the prebuilt library does not include this class, build the latest FastDeploy C++ SDK from the develop branch yourself
# Download the officially converted yolov7 model file and a test image
wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-end2end-trt-nms.onnx
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
# TensorRT GPU inference
./infer_demo yolov7-end2end-trt-nms.onnx 000000014439.jpg 2
```
After running, the visualized result is shown below:
<div align='center'>
<img width="640" alt="image" src="https://user-images.githubusercontent.com/31974251/186605967-ad0c53f2-3ce8-4032-a90f-6f5c1238e7f4.png">
</div>
The commands above apply only to Linux or macOS. For how to use the SDK on Windows, see:
- [Using the FastDeploy C++ SDK on Windows](../../../../../docs/compile/how_to_use_sdk_on_windows.md)
Note: YOLOv7End2EndTRT is only for inference of End2End models exported from YOLOv7 with the [TRT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L111) plugin. For models exported without NMS, use the YOLOv7 class; for End2End models exported with [ORT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L87), use YOLOv7End2EndORT.
## YOLOv7End2EndTRT C++ API
### YOLOv7End2EndTRT Class
```c++
fastdeploy::vision::detection::YOLOv7End2EndTRT(
        const string& model_file,
        const string& params_file = "",
        const RuntimeOption& runtime_option = RuntimeOption(),
        const Frontend& model_format = Frontend::ONNX)
```
Loads and initializes the YOLOv7End2EndTRT model, where model_file is the exported ONNX model.
**Parameters**
> * **model_file**(str): path to the model file
> * **params_file**(str): path to the parameters file; pass an empty string when the model format is ONNX
> * **runtime_option**(RuntimeOption): backend inference configuration; if not set, the default configuration is used
> * **model_format**(Frontend): model format, ONNX by default
#### Predict Function
> ```c++
> YOLOv7End2EndTRT::Predict(cv::Mat* im, DetectionResult* result,
>                           float conf_threshold = 0.25)
> ```
>
> Model prediction interface: takes an input image and directly outputs the detection results.
>
> **Parameters**
>
> > * **im**: input image; note it must be in HWC, BGR format
> > * **result**: detection results, including the detection boxes and the confidence of each box; see [vision model prediction results](../../../../../docs/api/vision_results/) for a description of DetectionResult
> > * **conf_threshold**: confidence threshold for filtering detection boxes; since the YOLOv7 End2End model already has a score threshold baked in when exported to ONNX, this parameter only takes effect when it is larger than the exported threshold.
### Class Member Variables
#### Preprocessing Parameters
Users can modify the following preprocessing parameters according to their needs, which affects the final inference and deployment results
> > * **size**(vector&lt;int&gt;): the target size used by resize during preprocessing, two integers meaning [width, height]; default [640, 640]
> > * **padding_value**(vector&lt;float&gt;): the value used to pad the image during resize, three floats for the three channels; default [114, 114, 114]
> > * **is_no_pad**(bool): whether to resize the image without padding; `is_no_pad=true` means padding is not used; default `is_no_pad=false`
> > * **is_mini_pad**(bool): whether to pad the resized image only up to the size closest to the `size` member such that the padded pixels are divisible by the `stride` member; default `is_mini_pad=false`
> > * **stride**(int): used together with the `is_mini_pad` member; default `stride=32`
- [Model description](../../)
- [Python deployment](../python)
- [Vision model prediction results](../../../../../docs/api/vision_results/)


@@ -0,0 +1,110 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
void CpuInfer(const std::string& model_file, const std::string& image_file) {
auto model = fastdeploy::vision::detection::YOLOv7End2EndTRT(model_file);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::DetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
void GpuInfer(const std::string& model_file, const std::string& image_file) {
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
auto model =
fastdeploy::vision::detection::YOLOv7End2EndTRT(model_file, "", option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::DetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
void TrtInfer(const std::string& model_file, const std::string& image_file) {
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
option.UseTrtBackend();
option.SetTrtInputShape("images", {1, 3, 640, 640});
auto model =
fastdeploy::vision::detection::YOLOv7End2EndTRT(model_file, "", option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::DetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout << "Usage: infer_demo path/to/model path/to/image run_option, "
"e.g ./infer_model ./yolov7-end2end-trt-nms.onnx ./test.jpeg 0"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu; 2: run with gpu and use tensorrt backend."
<< std::endl;
return -1;
}
if (std::atoi(argv[3]) == 0) {
CpuInfer(argv[1], argv[2]);
} else if (std::atoi(argv[3]) == 1) {
GpuInfer(argv[1], argv[2]);
} else if (std::atoi(argv[3]) == 2) {
TrtInfer(argv[1], argv[2]);
}
return 0;
}


@@ -0,0 +1,80 @@
# YOLOv7End2EndTRT Python Deployment Example
Before deployment, confirm the following two steps:
- 1. The hardware and software environment meets the requirements; see [FastDeploy environment requirements](../../../../../docs/the%20software%20and%20hardware%20requirements.md)
- 2. Install the FastDeploy Python wheel; see [FastDeploy Python installation](../../../../../docs/quick_start)
This directory provides `infer.py`, a quick example of deploying YOLOv7End2EndTRT with TensorRT acceleration. Run the following script to complete it:
```bash
# Download the deployment example code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/detection/yolov7end2end_trt/python/
# Download the yolov7 model file and a test image
wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-end2end-trt-nms.onnx
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
# TensorRT GPU inference
python infer.py --model yolov7-end2end-trt-nms.onnx --image 000000014439.jpg --device gpu --use_trt True
# If the installed python package does not include this class, build and install the latest FastDeploy Python wheel from the develop branch yourself
```
After running, the visualized result is shown below:
<div align='center'>
<img width="640" alt="image" src="https://user-images.githubusercontent.com/31974251/186605967-ad0c53f2-3ce8-4032-a90f-6f5c1238e7f4.png">
</div>
Note: YOLOv7End2EndTRT is only for inference of End2End models exported from YOLOv7 with the [TRT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L111) plugin. For models exported without NMS, use the YOLOv7 class; for End2End models exported with [ORT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L87), use YOLOv7End2EndORT.
## YOLOv7End2EndTRT Python API
```python
fastdeploy.vision.detection.YOLOv7End2EndTRT(model_file, params_file=None, runtime_option=None, model_format=Frontend.ONNX)
```
Loads and initializes the YOLOv7End2EndTRT model, where model_file is the exported ONNX model.
**Parameters**
> * **model_file**(str): path to the model file
> * **params_file**(str): path to the parameters file; not required when the model format is ONNX
> * **runtime_option**(RuntimeOption): backend inference configuration; None by default, i.e. the default configuration is used
> * **model_format**(Frontend): model format, ONNX by default
### predict Function
> ```python
> YOLOv7End2EndTRT.predict(image_data, conf_threshold=0.25)
> ```
>
> Model prediction interface: takes an input image and directly outputs the detection results.
>
> **Parameters**
>
> > * **image_data**(np.ndarray): input data; note it must be in HWC, BGR format
> > * **conf_threshold**(float): confidence threshold for filtering detection boxes; since the YOLOv7 End2End model already has a score threshold baked in when exported to ONNX, this parameter only takes effect when it is larger than the exported threshold.
> **Return**
>
> > Returns a `fastdeploy.vision.DetectionResult` structure; see [vision model prediction results](../../../../../docs/api/vision_results/) for details
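For reference, here is a minimal sketch of calling `predict` and reading the result. It assumes the Python `DetectionResult` exposes the same `boxes`, `scores`, and `label_ids` fields as the C++ structure in this PR, that a GPU with TensorRT support is available, and that the model and image file names are placeholders.
```python
import cv2
import fastdeploy as fd

# Placeholder file names; replace with your own End2End ONNX model and image
model = fd.vision.detection.YOLOv7End2EndTRT("yolov7-end2end-trt-nms.onnx")
im = cv2.imread("000000014439.jpg")

# conf_threshold only filters boxes above the score threshold baked into the export
result = model.predict(im, conf_threshold=0.4)

# Each detection is a box in [x1, y1, x2, y2] pixel coordinates on the original image
for box, score, label_id in zip(result.boxes, result.scores, result.label_ids):
    print(label_id, score, box)
```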
### Class Member Properties
#### Preprocessing Parameters
Users can modify the following preprocessing parameters according to their needs, which affects the final inference and deployment results
> > * **size**(list[int]): the target size used by resize during preprocessing, two integers meaning [width, height]; default [640, 640]
> > * **padding_value**(list[float]): the value used to pad the image during resize, three floats for the three channels; default [114, 114, 114]
> > * **is_no_pad**(bool): whether to resize the image without padding; `is_no_pad=True` means padding is not used; default `is_no_pad=False`
> > * **is_mini_pad**(bool): whether to pad the resized image only up to the size closest to the `size` member such that the padded pixels are divisible by the `stride` member; default `is_mini_pad=False`
> > * **stride**(int): used together with the `is_mini_pad` member; default `stride=32`
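A short sketch of adjusting these attributes before prediction is shown below. The values are only illustrative and the model file name is a placeholder; whether sizes other than the export size work depends on how the End2End model was exported.
```python
import fastdeploy as fd

# Placeholder model file name
model = fd.vision.detection.YOLOv7End2EndTRT("yolov7-end2end-trt-nms.onnx")

# Preprocessing attributes can be changed after the model is created
model.size = [640, 640]                      # resize target as [width, height]
model.padding_value = [114.0, 114.0, 114.0]  # per-channel letterbox padding value
model.is_no_pad = False                      # keep letterbox padding during resize
model.stride = 32                            # padding granularity used with is_mini_pad
```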
## Other Documents
- [YOLOv7End2EndTRT model description](..)
- [YOLOv7End2EndTRT C++ deployment](../cpp)
- [Model prediction result description](../../../../../docs/api/vision_results/)


@@ -0,0 +1,53 @@
import fastdeploy as fd
import cv2
def parse_arguments():
    import argparse
    import ast
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", required=True, help="Path of yolov7 end2end onnx model.")
    parser.add_argument(
        "--image", required=True, help="Path of test image file.")
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Type of inference device, support 'cpu' or 'gpu'.")
    parser.add_argument(
        "--use_trt",
        type=ast.literal_eval,
        default=False,
        help="Wether to use tensorrt.")
    return parser.parse_args()


def build_option(args):
    option = fd.RuntimeOption()
    if args.device.lower() == "gpu":
        option.use_gpu()
    if args.use_trt:
        option.use_trt_backend()
        option.set_trt_input_shape("images", [1, 3, 640, 640])
    return option


args = parse_arguments()

# Configure the runtime and load the model
runtime_option = build_option(args)
model = fd.vision.detection.YOLOv7End2EndTRT(
    args.model, runtime_option=runtime_option)

# Predict detection results for an image
im = cv2.imread(args.image)
result = model.predict(im.copy())
print(result)

# Visualize the prediction result
vis_im = fd.vision.vis_detection(im, result)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg")


@@ -21,5 +21,6 @@ from .contrib.yolox import YOLOX
 from .contrib.yolov5 import YOLOv5
 from .contrib.yolov5lite import YOLOv5Lite
 from .contrib.yolov6 import YOLOv6
+from .contrib.yolov7end2end_trt import YOLOv7End2EndTRT
 from .contrib.yolov7end2end_ort import YOLOv7End2EndORT
 from .ppdet import PPYOLOE, PPYOLO, PPYOLOv2, PaddleYOLOX, PicoDet, FasterRCNN, YOLOv3


@@ -0,0 +1,105 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
from .... import FastDeployModel, Frontend
from .... import c_lib_wrap as C
class YOLOv7End2EndTRT(FastDeployModel):
    def __init__(self,
                 model_file,
                 params_file="",
                 runtime_option=None,
                 model_format=Frontend.ONNX):
        # Call the base class to initialize the backend options;
        # the initialized options are stored in self._runtime_option
        super(YOLOv7End2EndTRT, self).__init__(runtime_option)

        self._model = C.vision.detection.YOLOv7End2EndTRT(
            model_file, params_file, self._runtime_option, model_format)
        # self.initialized indicates whether the whole model initialized successfully
        assert self.initialized, "YOLOv7End2EndTRT initialize failed."

    def predict(self, input_image, conf_threshold=0.25):
        return self._model.predict(input_image, conf_threshold)

    # Wrappers for some model-related attributes, mostly preprocessing related.
    # For example, model.size = [1280, 1280] changes the resize size used in
    # preprocessing (provided the model supports it).
    @property
    def size(self):
        return self._model.size

    @property
    def padding_value(self):
        return self._model.padding_value

    @property
    def is_no_pad(self):
        return self._model.is_no_pad

    @property
    def is_mini_pad(self):
        return self._model.is_mini_pad

    @property
    def is_scale_up(self):
        return self._model.is_scale_up

    @property
    def stride(self):
        return self._model.stride

    @size.setter
    def size(self, wh):
        assert isinstance(wh, (list, tuple)),\
            "The value to set `size` must be type of tuple or list."
        assert len(wh) == 2,\
            "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
                len(wh))
        self._model.size = wh

    @padding_value.setter
    def padding_value(self, value):
        assert isinstance(
            value,
            list), "The value to set `padding_value` must be type of list."
        self._model.padding_value = value

    @is_no_pad.setter
    def is_no_pad(self, value):
        assert isinstance(
            value, bool), "The value to set `is_no_pad` must be type of bool."
        self._model.is_no_pad = value

    @is_mini_pad.setter
    def is_mini_pad(self, value):
        assert isinstance(
            value,
            bool), "The value to set `is_mini_pad` must be type of bool."
        self._model.is_mini_pad = value

    @is_scale_up.setter
    def is_scale_up(self, value):
        assert isinstance(
            value,
            bool), "The value to set `is_scale_up` must be type of bool."
        self._model.is_scale_up = value

    @stride.setter
    def stride(self, value):
        assert isinstance(
            value, int), "The value to set `stride` must be type of int."
        self._model.stride = value