mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[feature][vision] Add YOLOv7 End2End model with TRT NMS (#157)
* [feature][vision] Add YOLOv7 End2End model with TRT NMS * [docs] update yolov7end2end_trt examples docs Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
@@ -27,7 +27,9 @@ bool FastDeployModel::InitRuntime() {
|
||||
}
|
||||
if (runtime_option.backend != Backend::UNKNOWN) {
|
||||
if (!IsBackendAvailable(runtime_option.backend)) {
|
||||
FDERROR << Str(runtime_option.backend) << " is not compiled with current FastDeploy library." << std::endl;
|
||||
FDERROR << Str(runtime_option.backend)
|
||||
<< " is not compiled with current FastDeploy library."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -70,7 +72,7 @@ bool FastDeployModel::InitRuntime() {
|
||||
FDWARNING << "FastDeploy will choose " << Str(valid_gpu_backends[0])
|
||||
<< " for model inference." << std::endl;
|
||||
} else {
|
||||
FDASSERT(valid_gpu_backends.size() > 0,
|
||||
FDASSERT(valid_cpu_backends.size() > 0,
|
||||
"There's no valid cpu backend for %s.", ModelName().c_str());
|
||||
FDWARNING << "FastDeploy will choose " << Str(valid_cpu_backends[0])
|
||||
<< " for model inference." << std::endl;
|
||||
|
@@ -23,6 +23,7 @@
|
||||
#include "fastdeploy/vision/detection/contrib/yolov5lite.h"
|
||||
#include "fastdeploy/vision/detection/contrib/yolov6.h"
|
||||
#include "fastdeploy/vision/detection/contrib/yolov7.h"
|
||||
#include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
|
||||
#include "fastdeploy/vision/detection/contrib/yolov7end2end_ort.h"
|
||||
#include "fastdeploy/vision/detection/contrib/yolox.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/model.h"
|
||||
|
267
csrc/fastdeploy/vision/detection/contrib/yolov7end2end_trt.cc
Normal file
267
csrc/fastdeploy/vision/detection/contrib/yolov7end2end_trt.cc
Normal file
@@ -0,0 +1,267 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
void YOLOv7End2EndTRT::LetterBox(Mat* mat, const std::vector<int>& size,
|
||||
const std::vector<float>& color, bool _auto,
|
||||
bool scale_fill, bool scale_up, int stride) {
|
||||
float scale =
|
||||
std::min(size[1] * 1.0 / mat->Height(), size[0] * 1.0 / mat->Width());
|
||||
if (!scale_up) {
|
||||
scale = std::min(scale, 1.0f);
|
||||
}
|
||||
|
||||
int resize_h = int(round(mat->Height() * scale));
|
||||
int resize_w = int(round(mat->Width() * scale));
|
||||
|
||||
int pad_w = size[0] - resize_w;
|
||||
int pad_h = size[1] - resize_h;
|
||||
if (_auto) {
|
||||
pad_h = pad_h % stride;
|
||||
pad_w = pad_w % stride;
|
||||
} else if (scale_fill) {
|
||||
pad_h = 0;
|
||||
pad_w = 0;
|
||||
resize_h = size[1];
|
||||
resize_w = size[0];
|
||||
}
|
||||
if (resize_h != mat->Height() || resize_w != mat->Width()) {
|
||||
Resize::Run(mat, resize_w, resize_h);
|
||||
}
|
||||
if (pad_h > 0 || pad_w > 0) {
|
||||
float half_h = pad_h * 1.0 / 2;
|
||||
int top = int(round(half_h - 0.1));
|
||||
int bottom = int(round(half_h + 0.1));
|
||||
float half_w = pad_w * 1.0 / 2;
|
||||
int left = int(round(half_w - 0.1));
|
||||
int right = int(round(half_w + 0.1));
|
||||
Pad::Run(mat, top, bottom, left, right, color);
|
||||
}
|
||||
}
|
||||
|
||||
YOLOv7End2EndTRT::YOLOv7End2EndTRT(const std::string& model_file,
|
||||
const std::string& params_file,
|
||||
const RuntimeOption& custom_option,
|
||||
const Frontend& model_format) {
|
||||
if (model_format == Frontend::ONNX) {
|
||||
valid_cpu_backends = {}; // NO CPU
|
||||
valid_gpu_backends = {Backend::TRT}; // NO ORT
|
||||
} else {
|
||||
valid_cpu_backends = {Backend::PDINFER};
|
||||
valid_gpu_backends = {Backend::PDINFER};
|
||||
}
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = model_format;
|
||||
runtime_option.model_file = model_file;
|
||||
if (runtime_option.device != Device::GPU) {
|
||||
FDWARNING << Str(runtime_option.device)
|
||||
<< " is not support for YOLOv7End2EndTRT,"
|
||||
<< "will fallback to Device::GPU." << std::endl;
|
||||
runtime_option.device = Device::GPU;
|
||||
}
|
||||
if (runtime_option.backend != Backend::UNKNOWN) {
|
||||
if (runtime_option.backend != Backend::TRT) {
|
||||
FDWARNING << Str(runtime_option.backend)
|
||||
<< " is not support for YOLOv7End2EndTRT,"
|
||||
<< "will fallback to Backend::TRT." << std::endl;
|
||||
runtime_option.backend = Backend::TRT;
|
||||
}
|
||||
}
|
||||
initialized = Initialize();
|
||||
}
|
||||
|
||||
bool YOLOv7End2EndTRT::Initialize() {
|
||||
// parameters for preprocess
|
||||
size = {640, 640};
|
||||
padding_value = {114.0, 114.0, 114.0};
|
||||
is_mini_pad = false;
|
||||
is_no_pad = false;
|
||||
is_scale_up = false;
|
||||
stride = 32;
|
||||
|
||||
if (!InitRuntime()) {
|
||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||
return false;
|
||||
}
|
||||
// Check if the input shape is dynamic after Runtime already initialized,
|
||||
// Note that, We need to force is_mini_pad 'false' to keep static
|
||||
// shape after padding (LetterBox) when the is_dynamic_shape is 'false'.
|
||||
is_dynamic_input_ = false;
|
||||
auto shape = InputInfoOfRuntime(0).shape;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
// if height or width is dynamic
|
||||
if (i >= 2 && shape[i] <= 0) {
|
||||
is_dynamic_input_ = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!is_dynamic_input_) {
|
||||
is_mini_pad = false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool YOLOv7End2EndTRT::Preprocess(
|
||||
Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info) {
|
||||
float ratio = std::min(size[1] * 1.0f / static_cast<float>(mat->Height()),
|
||||
size[0] * 1.0f / static_cast<float>(mat->Width()));
|
||||
if (ratio != 1.0) {
|
||||
int interp = cv::INTER_AREA;
|
||||
if (ratio > 1.0) {
|
||||
interp = cv::INTER_LINEAR;
|
||||
}
|
||||
int resize_h = int(mat->Height() * ratio);
|
||||
int resize_w = int(mat->Width() * ratio);
|
||||
Resize::Run(mat, resize_w, resize_h, -1, -1, interp);
|
||||
}
|
||||
YOLOv7End2EndTRT::LetterBox(mat, size, padding_value, is_mini_pad, is_no_pad,
|
||||
is_scale_up, stride);
|
||||
BGR2RGB::Run(mat);
|
||||
std::vector<float> alpha = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f};
|
||||
std::vector<float> beta = {0.0f, 0.0f, 0.0f};
|
||||
Convert::Run(mat, alpha, beta);
|
||||
|
||||
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
|
||||
static_cast<float>(mat->Width())};
|
||||
|
||||
HWC2CHW::Run(mat);
|
||||
Cast::Run(mat, "float");
|
||||
mat->ShareWithTensor(output);
|
||||
output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c
|
||||
return true;
|
||||
}
|
||||
|
||||
bool YOLOv7End2EndTRT::Postprocess(
|
||||
std::vector<FDTensor>& infer_results, DetectionResult* result,
|
||||
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||
float conf_threshold) {
|
||||
FDASSERT(infer_results.size() == 4, "Output tensor size must be 4.");
|
||||
FDTensor& num_tensor = infer_results.at(0); // INT32
|
||||
FDTensor& boxes_tensor = infer_results.at(1); // FLOAT
|
||||
FDTensor& scores_tensor = infer_results.at(2); // FLOAT
|
||||
FDTensor& classes_tensor = infer_results.at(3); // INT32
|
||||
FDASSERT(num_tensor.dtype == FDDataType::INT32,
|
||||
"The dtype of num_dets must be INT32.");
|
||||
FDASSERT(boxes_tensor.dtype == FDDataType::FP32,
|
||||
"The dtype of det_boxes_tensor must be FP32.");
|
||||
FDASSERT(scores_tensor.dtype == FDDataType::FP32,
|
||||
"The dtype of det_scores_tensor must be FP32.");
|
||||
FDASSERT(classes_tensor.dtype == FDDataType::INT32,
|
||||
"The dtype of det_classes_tensor must be INT32.");
|
||||
FDASSERT(num_tensor.shape[0] == 1, "Only support batch=1 now.");
|
||||
// post-process for end2end yolov7 after trt nms.
|
||||
float* boxes_data = static_cast<float*>(boxes_tensor.Data()); // (1,100,4)
|
||||
float* scores_data = static_cast<float*>(scores_tensor.Data()); // (1,100)
|
||||
int32_t* classes_data =
|
||||
static_cast<int32_t*>(classes_tensor.Data()); // (1,100)
|
||||
int32_t num_dets_after_trt_nms = static_cast<int32_t*>(num_tensor.Data())[0];
|
||||
if (num_dets_after_trt_nms == 0) {
|
||||
return true;
|
||||
}
|
||||
result->Clear();
|
||||
result->Reserve(num_dets_after_trt_nms);
|
||||
for (size_t i = 0; i < num_dets_after_trt_nms; ++i) {
|
||||
float confidence = scores_data[i];
|
||||
if (confidence <= conf_threshold) {
|
||||
continue;
|
||||
}
|
||||
int32_t label_id = classes_data[i];
|
||||
float x1 = boxes_data[(i * 4) + 0];
|
||||
float y1 = boxes_data[(i * 4) + 1];
|
||||
float x2 = boxes_data[(i * 4) + 2];
|
||||
float y2 = boxes_data[(i * 4) + 3];
|
||||
|
||||
result->boxes.emplace_back(std::array<float, 4>{x1, y1, x2, y2});
|
||||
result->label_ids.push_back(label_id);
|
||||
result->scores.push_back(confidence);
|
||||
}
|
||||
|
||||
if (result->boxes.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// scale the boxes to the origin image shape
|
||||
auto iter_out = im_info.find("output_shape");
|
||||
auto iter_ipt = im_info.find("input_shape");
|
||||
FDASSERT(iter_out != im_info.end() && iter_ipt != im_info.end(),
|
||||
"Cannot find input_shape or output_shape from im_info.");
|
||||
float out_h = iter_out->second[0];
|
||||
float out_w = iter_out->second[1];
|
||||
float ipt_h = iter_ipt->second[0];
|
||||
float ipt_w = iter_ipt->second[1];
|
||||
float scale = std::min(out_h / ipt_h, out_w / ipt_w);
|
||||
float pad_h = (out_h - ipt_h * scale) / 2.0f;
|
||||
float pad_w = (out_w - ipt_w * scale) / 2.0f;
|
||||
if (is_mini_pad) {
|
||||
pad_h = static_cast<float>(static_cast<int>(pad_h) % stride);
|
||||
pad_w = static_cast<float>(static_cast<int>(pad_w) % stride);
|
||||
}
|
||||
for (size_t i = 0; i < result->boxes.size(); ++i) {
|
||||
int32_t label_id = (result->label_ids)[i];
|
||||
result->boxes[i][0] = std::max((result->boxes[i][0] - pad_w) / scale, 0.0f);
|
||||
result->boxes[i][1] = std::max((result->boxes[i][1] - pad_h) / scale, 0.0f);
|
||||
result->boxes[i][2] = std::max((result->boxes[i][2] - pad_w) / scale, 0.0f);
|
||||
result->boxes[i][3] = std::max((result->boxes[i][3] - pad_h) / scale, 0.0f);
|
||||
result->boxes[i][0] = std::min(result->boxes[i][0], ipt_w - 1.0f);
|
||||
result->boxes[i][1] = std::min(result->boxes[i][1], ipt_h - 1.0f);
|
||||
result->boxes[i][2] = std::min(result->boxes[i][2], ipt_w - 1.0f);
|
||||
result->boxes[i][3] = std::min(result->boxes[i][3], ipt_h - 1.0f);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool YOLOv7End2EndTRT::Predict(cv::Mat* im, DetectionResult* result,
|
||||
float conf_threshold) {
|
||||
Mat mat(*im);
|
||||
std::vector<FDTensor> input_tensors(1);
|
||||
|
||||
std::map<std::string, std::array<float, 2>> im_info;
|
||||
|
||||
// Record the shape of image and the shape of preprocessed image
|
||||
im_info["input_shape"] = {static_cast<float>(mat.Height()),
|
||||
static_cast<float>(mat.Width())};
|
||||
im_info["output_shape"] = {static_cast<float>(mat.Height()),
|
||||
static_cast<float>(mat.Width())};
|
||||
|
||||
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||
std::vector<FDTensor> output_tensors;
|
||||
if (!Infer(input_tensors, &output_tensors)) {
|
||||
FDERROR << "Failed to inference." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!Postprocess(output_tensors, result, im_info, conf_threshold)) {
|
||||
FDERROR << "Failed to post process." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
93
csrc/fastdeploy/vision/detection/contrib/yolov7end2end_trt.h
Normal file
93
csrc/fastdeploy/vision/detection/contrib/yolov7end2end_trt.h
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/fastdeploy_model.h"
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace detection {
|
||||
|
||||
class FASTDEPLOY_DECL YOLOv7End2EndTRT : public FastDeployModel {
|
||||
public:
|
||||
YOLOv7End2EndTRT(const std::string& model_file,
|
||||
const std::string& params_file = "",
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const Frontend& model_format = Frontend::ONNX);
|
||||
|
||||
// 定义模型的名称
|
||||
virtual std::string ModelName() const { return "yolov7end2end_trt"; }
|
||||
|
||||
// 模型预测接口,即用户调用的接口
|
||||
// im 为用户的输入数据,目前对于CV均定义为cv::Mat
|
||||
// result 为模型预测的输出结构体
|
||||
// conf_threshold 为后处理的参数
|
||||
// nms_iou_threshold 为后处理的参数
|
||||
virtual bool Predict(cv::Mat* im, DetectionResult* result,
|
||||
float conf_threshold = 0.25);
|
||||
|
||||
// 以下为模型在预测时的一些参数,基本是前后处理所需
|
||||
// 用户在创建模型后,可根据模型的要求,以及自己的需求
|
||||
// 对参数进行修改
|
||||
// tuple of (width, height)
|
||||
std::vector<int> size;
|
||||
// padding value, size should be same with Channels
|
||||
std::vector<float> padding_value;
|
||||
// only pad to the minimum rectange which height and width is times of stride
|
||||
bool is_mini_pad;
|
||||
// while is_mini_pad = false and is_no_pad = true, will resize the image to
|
||||
// the set size
|
||||
bool is_no_pad;
|
||||
// if is_scale_up is false, the input image only can be zoom out, the maximum
|
||||
// resize scale cannot exceed 1.0
|
||||
bool is_scale_up;
|
||||
// padding stride, for is_mini_pad
|
||||
int stride;
|
||||
|
||||
private:
|
||||
// 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作
|
||||
bool Initialize();
|
||||
|
||||
// 输入图像预处理操作
|
||||
// Mat为FastDeploy定义的数据结构
|
||||
// FDTensor为预处理后的Tensor数据,传给后端进行推理
|
||||
// im_info为预处理过程保存的数据,在后处理中需要用到
|
||||
bool Preprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info);
|
||||
|
||||
// 后端推理结果后处理,输出给用户
|
||||
// infer_result 为后端推理后的输出Tensor
|
||||
// result 为模型预测的结果
|
||||
// im_info 为预处理记录的信息,后处理用于还原box
|
||||
// conf_threshold 后处理时过滤box的置信度阈值
|
||||
bool Postprocess(std::vector<FDTensor>& infer_results,
|
||||
DetectionResult* result,
|
||||
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||
float conf_threshold);
|
||||
|
||||
// 对图片进行LetterBox处理
|
||||
// mat 为读取到的原图
|
||||
// size 为输入模型的图像尺寸
|
||||
void LetterBox(Mat* mat, const std::vector<int>& size,
|
||||
const std::vector<float>& color, bool _auto,
|
||||
bool scale_fill = false, bool scale_up = true,
|
||||
int stride = 32);
|
||||
|
||||
bool is_dynamic_input_;
|
||||
};
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/pybind/main.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
void BindYOLOv7End2EndTRT(pybind11::module& m) {
|
||||
pybind11::class_<vision::detection::YOLOv7End2EndTRT, FastDeployModel>(
|
||||
m, "YOLOv7End2EndTRT")
|
||||
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
|
||||
.def("predict",
|
||||
[](vision::detection::YOLOv7End2EndTRT& self, pybind11::array& data,
|
||||
float conf_threshold) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::DetectionResult res;
|
||||
self.Predict(&mat, &res, conf_threshold);
|
||||
return res;
|
||||
})
|
||||
.def_readwrite("size", &vision::detection::YOLOv7End2EndTRT::size)
|
||||
.def_readwrite("padding_value",
|
||||
&vision::detection::YOLOv7End2EndTRT::padding_value)
|
||||
.def_readwrite("is_mini_pad",
|
||||
&vision::detection::YOLOv7End2EndTRT::is_mini_pad)
|
||||
.def_readwrite("is_no_pad",
|
||||
&vision::detection::YOLOv7End2EndTRT::is_no_pad)
|
||||
.def_readwrite("is_scale_up",
|
||||
&vision::detection::YOLOv7End2EndTRT::is_scale_up)
|
||||
.def_readwrite("stride", &vision::detection::YOLOv7End2EndTRT::stride);
|
||||
}
|
||||
} // namespace fastdeploy
|
@@ -25,6 +25,7 @@ void BindYOLOv5(pybind11::module& m);
|
||||
void BindYOLOX(pybind11::module& m);
|
||||
void BindNanoDetPlus(pybind11::module& m);
|
||||
void BindPPDet(pybind11::module& m);
|
||||
void BindYOLOv7End2EndTRT(pybind11::module& m);
|
||||
void BindYOLOv7End2EndORT(pybind11::module& m);
|
||||
|
||||
void BindDetection(pybind11::module& m) {
|
||||
@@ -39,6 +40,7 @@ void BindDetection(pybind11::module& m) {
|
||||
BindYOLOv5(detection_module);
|
||||
BindYOLOX(detection_module);
|
||||
BindNanoDetPlus(detection_module);
|
||||
BindYOLOv7End2EndTRT(detection_module);
|
||||
BindYOLOv7End2EndORT(detection_module);
|
||||
}
|
||||
} // namespace fastdeploy
|
||||
|
43
examples/vision/detection/yolov7end2end_trt/README.md
Normal file
43
examples/vision/detection/yolov7end2end_trt/README.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# YOLOv7End2EndTRT 准备部署模型
|
||||
|
||||
YOLOv7End2EndTRT 部署实现来自[YOLOv7](https://github.com/WongKinYiu/yolov7/tree/v0.1)分支代码,和[基于COCO的预训练模型](https://github.com/WongKinYiu/yolov7/releases/tag/v0.1)。注意,YOLOv7End2EndTRT 是专门用于推理YOLOv7中导出模型带[TRT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L111) 版本的End2End模型,不带nms的模型推理请使用YOLOv7类,而 [ORT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L87) 版本的End2End模型请使用YOLOv7End2EndORT进行推理。
|
||||
|
||||
- (1)[官方库](https://github.com/WongKinYiu/yolov7/releases/tag/v0.1)提供的*.pt通过[导出ONNX模型](#导出ONNX模型)操作后,可进行部署;*.trt和*.pose模型不支持部署;
|
||||
- (2)自己数据训练的YOLOv7模型,按照[导出ONNX模型](#%E5%AF%BC%E5%87%BAONNX%E6%A8%A1%E5%9E%8B)操作后,参考[详细部署文档](#详细部署文档)完成部署。
|
||||
|
||||
|
||||
|
||||
## 导出ONNX模型
|
||||
|
||||
```bash
|
||||
# 下载yolov7模型文件
|
||||
wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt
|
||||
|
||||
# 导出带TRT_NMS的onnx格式文件 (Tips: 对应 YOLOv7 release v0.1 代码)
|
||||
python export.py --weights yolov7.pt --grid --end2end --simplify --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640
|
||||
# 导出其他模型的命令类似 将yolov7.pt替换成 yolov7x.pt yolov7-d6.pt yolov7-w6.pt ...
|
||||
# 使用YOLOv7End2EndTRT只需提供onnx文件,不需要额外再转trt文件,推理时自动转换
|
||||
```
|
||||
|
||||
## 下载预训练ONNX模型
|
||||
|
||||
为了方便开发者的测试,下面提供了YOLOv7End2EndTRT 导出的各系列模型,开发者可直接下载使用。(下表中模型的精度来源于源官方库)
|
||||
| 模型 | 大小 | 精度 |
|
||||
|:---------------------------------------------------------------- |:----- |:----- |
|
||||
| [yolov7-end2end-trt-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-end2end-trt-nms.onnx) | 141MB | 51.4% |
|
||||
| [yolov7x-end2end-trt-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7x-end2end-trt-nms.onnx) | 273MB | 53.1% |
|
||||
| [yolov7-w6-end2end-trt-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-w6-end2end-trt-nms.onnx) | 269MB | 54.9% |
|
||||
| [yolov7-e6-end2end-trt-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-e6-end2end-trt-nms.onnx) | 372MB | 56.0% |
|
||||
| [yolov7-d6-end2end-trt-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-d6-end2end-trt-nms.onnx) | 511MB | 56.6% |
|
||||
| [yolov7-e6e-end2end-trt-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-e6e-end2end-trt-nms.onnx) | 579MB | 56.8% |
|
||||
|
||||
|
||||
## 详细部署文档
|
||||
|
||||
- [Python部署](python)
|
||||
- [C++部署](cpp)
|
||||
|
||||
|
||||
## 版本说明
|
||||
|
||||
- 本版本文档和代码基于[YOLOv7 0.1](https://github.com/WongKinYiu/yolov7/tree/v0.1) 编写
|
@@ -0,0 +1,14 @@
|
||||
PROJECT(infer_demo C CXX)
|
||||
CMAKE_MINIMUM_REQUIRED (VERSION 3.12)
|
||||
|
||||
# 指定下载解压后的fastdeploy库路径
|
||||
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
|
||||
|
||||
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
|
||||
|
||||
# 添加FastDeploy依赖头文件
|
||||
include_directories(${FASTDEPLOY_INCS})
|
||||
|
||||
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
|
||||
# 添加FastDeploy库依赖
|
||||
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})
|
88
examples/vision/detection/yolov7end2end_trt/cpp/README.md
Normal file
88
examples/vision/detection/yolov7end2end_trt/cpp/README.md
Normal file
@@ -0,0 +1,88 @@
|
||||
# YOLOv7End2EndTRT C++部署示例
|
||||
|
||||
本目录下提供`infer.cc`快速完成GPU上通过TensorRT加速部署的示例,该类只支持TensorRT部署。
|
||||
|
||||
在部署前,需确认以下两个步骤
|
||||
|
||||
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/the%20software%20and%20hardware%20requirements.md)
|
||||
- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/quick_start)
|
||||
|
||||
以Linux上推理为例,在本目录执行如下命令即可完成编译测试
|
||||
|
||||
```bash
|
||||
mkdir build
|
||||
cd build
|
||||
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-gpu-0.2.0.tgz
|
||||
tar xvf fastdeploy-linux-x64-0.2.0.tgz
|
||||
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-gpu-0.2.0
|
||||
make -j
|
||||
# 若预编译库没有支持该类 则请自行从源码develop分支编译最新的FastDeploy C++ SDK
|
||||
|
||||
#下载官方转换好的yolov7模型文件和测试图片
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-end2end-trt-nms.onnx
|
||||
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
|
||||
|
||||
# TensorRT GPU推理
|
||||
./infer_demo yolov7-end2end-trt-nms.onnx 000000014439.jpg 2
|
||||
```
|
||||
|
||||
运行完成可视化结果如下图所示
|
||||
|
||||
<div align='center'>
|
||||
<img width="640" alt="image" src="https://user-images.githubusercontent.com/31974251/186605967-ad0c53f2-3ce8-4032-a90f-6f5c1238e7f4.png">
|
||||
</div>
|
||||
|
||||
以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考:
|
||||
- [如何在Windows中使用FastDeploy C++ SDK](../../../../../docs/compile/how_to_use_sdk_on_windows.md)
|
||||
|
||||
注意,YOLOv7End2EndTRT 是专门用于推理YOLOv7中导出模型带[TRT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L111) 版本的End2End模型,不带nms的模型推理请使用YOLOv7类,而 [ORT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L87) 版本的End2End模型请使用YOLOv7End2EndORT进行推理。
|
||||
|
||||
## YOLOv7End2EndTRT C++接口
|
||||
|
||||
### YOLOv7End2EndTRT 类
|
||||
|
||||
```c++
|
||||
fastdeploy::vision::detection::YOLOv7End2EndTRT(
|
||||
const string& model_file,
|
||||
const string& params_file = "",
|
||||
const RuntimeOption& runtime_option = RuntimeOption(),
|
||||
const Frontend& model_format = Frontend::ONNX)
|
||||
```
|
||||
|
||||
YOLOv7End2EndTRT 模型加载和初始化,其中model_file为导出的ONNX模型格式。
|
||||
|
||||
**参数**
|
||||
|
||||
> * **model_file**(str): 模型文件路径
|
||||
> * **params_file**(str): 参数文件路径,当模型格式为ONNX时,此参数传入空字符串即可
|
||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||
> * **model_format**(Frontend): 模型格式,默认为ONNX格式
|
||||
|
||||
#### Predict函数
|
||||
|
||||
> ```c++
|
||||
> YOLOv7End2EndTRT::Predict(cv::Mat* im, DetectionResult* result,
|
||||
> float conf_threshold = 0.25)
|
||||
> ```
|
||||
>
|
||||
> 模型预测接口,输入图像直接输出检测结果。
|
||||
>
|
||||
> **参数**
|
||||
>
|
||||
> > * **im**: 输入图像,注意需为HWC,BGR格式
|
||||
> > * **result**: 检测结果,包括检测框,各个框的置信度, DetectionResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/)
|
||||
> > * **conf_threshold**: 检测框置信度过滤阈值,但由于YOLOv7 End2End的模型在导出成ONNX时已经指定了score阈值,因此该参数只有在大于已经指定的阈值时才会有效。
|
||||
|
||||
### 类成员变量
|
||||
#### 预处理参数
|
||||
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
|
||||
|
||||
> > * **size**(vector<int>): 通过此参数修改预处理过程中resize的大小,包含两个整型元素,表示[width, height], 默认值为[640, 640]
|
||||
> > * **padding_value**(vector<float>): 通过此参数可以修改图片在resize时候做填充(padding)的值, 包含三个浮点型元素, 分别表示三个通道的值, 默认值为[114, 114, 114]
|
||||
> > * **is_no_pad**(bool): 通过此参数让图片是否通过填充的方式进行resize, `is_no_pad=ture` 表示不使用填充的方式,默认值为`is_no_pad=false`
|
||||
> > * **is_mini_pad**(bool): 通过此参数可以将resize之后图像的宽高这是为最接近`size`成员变量的值, 并且满足填充的像素大小是可以被`stride`成员变量整除的。默认值为`is_mini_pad=false`
|
||||
> > * **stride**(int): 配合`stris_mini_pad`成员变量使用, 默认值为`stride=32`
|
||||
|
||||
- [模型介绍](../../)
|
||||
- [Python部署](../python)
|
||||
- [视觉模型预测结果](../../../../../docs/api/vision_results/)
|
110
examples/vision/detection/yolov7end2end_trt/cpp/infer.cc
Normal file
110
examples/vision/detection/yolov7end2end_trt/cpp/infer.cc
Normal file
@@ -0,0 +1,110 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision.h"
|
||||
|
||||
void CpuInfer(const std::string& model_file, const std::string& image_file) {
|
||||
auto model = fastdeploy::vision::detection::YOLOv7End2EndTRT(model_file);
|
||||
if (!model.Initialized()) {
|
||||
std::cerr << "Failed to initialize." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
auto im = cv::imread(image_file);
|
||||
auto im_bak = im.clone();
|
||||
|
||||
fastdeploy::vision::DetectionResult res;
|
||||
if (!model.Predict(&im, &res)) {
|
||||
std::cerr << "Failed to predict." << std::endl;
|
||||
return;
|
||||
}
|
||||
std::cout << res.Str() << std::endl;
|
||||
|
||||
auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res);
|
||||
cv::imwrite("vis_result.jpg", vis_im);
|
||||
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
|
||||
}
|
||||
|
||||
void GpuInfer(const std::string& model_file, const std::string& image_file) {
|
||||
auto option = fastdeploy::RuntimeOption();
|
||||
option.UseGpu();
|
||||
auto model =
|
||||
fastdeploy::vision::detection::YOLOv7End2EndTRT(model_file, "", option);
|
||||
if (!model.Initialized()) {
|
||||
std::cerr << "Failed to initialize." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
auto im = cv::imread(image_file);
|
||||
auto im_bak = im.clone();
|
||||
|
||||
fastdeploy::vision::DetectionResult res;
|
||||
if (!model.Predict(&im, &res)) {
|
||||
std::cerr << "Failed to predict." << std::endl;
|
||||
return;
|
||||
}
|
||||
std::cout << res.Str() << std::endl;
|
||||
|
||||
auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res);
|
||||
cv::imwrite("vis_result.jpg", vis_im);
|
||||
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
|
||||
}
|
||||
|
||||
void TrtInfer(const std::string& model_file, const std::string& image_file) {
|
||||
auto option = fastdeploy::RuntimeOption();
|
||||
option.UseGpu();
|
||||
option.UseTrtBackend();
|
||||
option.SetTrtInputShape("images", {1, 3, 640, 640});
|
||||
auto model =
|
||||
fastdeploy::vision::detection::YOLOv7End2EndTRT(model_file, "", option);
|
||||
if (!model.Initialized()) {
|
||||
std::cerr << "Failed to initialize." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
auto im = cv::imread(image_file);
|
||||
auto im_bak = im.clone();
|
||||
|
||||
fastdeploy::vision::DetectionResult res;
|
||||
if (!model.Predict(&im, &res)) {
|
||||
std::cerr << "Failed to predict." << std::endl;
|
||||
return;
|
||||
}
|
||||
std::cout << res.Str() << std::endl;
|
||||
|
||||
auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res);
|
||||
cv::imwrite("vis_result.jpg", vis_im);
|
||||
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc < 4) {
|
||||
std::cout << "Usage: infer_demo path/to/model path/to/image run_option, "
|
||||
"e.g ./infer_model ./yolov7-end2end-trt-nms.onnx ./test.jpeg 0"
|
||||
<< std::endl;
|
||||
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
|
||||
"with gpu; 2: run with gpu and use tensorrt backend."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (std::atoi(argv[3]) == 0) {
|
||||
CpuInfer(argv[1], argv[2]);
|
||||
} else if (std::atoi(argv[3]) == 1) {
|
||||
GpuInfer(argv[1], argv[2]);
|
||||
} else if (std::atoi(argv[3]) == 2) {
|
||||
TrtInfer(argv[1], argv[2]);
|
||||
}
|
||||
return 0;
|
||||
}
|
80
examples/vision/detection/yolov7end2end_trt/python/README.md
Normal file
80
examples/vision/detection/yolov7end2end_trt/python/README.md
Normal file
@@ -0,0 +1,80 @@
|
||||
# YOLOv7End2EndTRT Python部署示例
|
||||
|
||||
在部署前,需确认以下两个步骤
|
||||
|
||||
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/the%20software%20and%20hardware%20requirements.md)
|
||||
- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/quick_start)
|
||||
|
||||
本目录下提供`infer.py`快速完成YOLOv7End2EndTRT在TensorRT加速部署的示例。执行如下脚本即可完成
|
||||
|
||||
```bash
|
||||
#下载部署示例代码
|
||||
git clone https://github.com/PaddlePaddle/FastDeploy.git
|
||||
cd FastDeploy/examples/vision/detection/yolov7end2end_trt/python/
|
||||
|
||||
#下载yolov7模型文件和测试图片
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-end2end-trt-nms.onnx
|
||||
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
|
||||
|
||||
# TensorRT GPU推理
|
||||
python infer.py --model yolov7-end2end-trt-nms.onnx --image 000000014439.jpg --device gpu --use_trt True
|
||||
# 若安装的python包没有支持该类 则请自行从源码develop分支编译最新的FastDeploy Python Wheel包进行安装
|
||||
```
|
||||
|
||||
运行完成可视化结果如下图所示
|
||||
|
||||
<div align='center'>
|
||||
<img width="640" alt="image" src="https://user-images.githubusercontent.com/31974251/186605967-ad0c53f2-3ce8-4032-a90f-6f5c1238e7f4.png">
|
||||
</div>
|
||||
|
||||
注意,YOLOv7End2EndTRT 是专门用于推理YOLOv7中导出模型带[TRT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L111) 版本的End2End模型,不带nms的模型推理请使用YOLOv7类,而 [ORT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L87) 版本的End2End模型请使用YOLOv7End2EndORT进行推理。
|
||||
|
||||
## YOLOv7End2EndTRT Python接口
|
||||
|
||||
```python
|
||||
fastdeploy.vision.detection.YOLOv7End2EndTRT(model_file, params_file=None, runtime_option=None, model_format=Frontend.ONNX)
|
||||
```
|
||||
|
||||
YOLOv7End2EndTRT 模型加载和初始化,其中model_file为导出的ONNX模型格式
|
||||
|
||||
**参数**
|
||||
|
||||
> * **model_file**(str): 模型文件路径
|
||||
> * **params_file**(str): 参数文件路径,当模型格式为ONNX格式时,此参数无需设定
|
||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||
> * **model_format**(Frontend): 模型格式,默认为ONNX
|
||||
|
||||
### predict函数
|
||||
|
||||
> ```python
|
||||
> YOLOv7End2EndTRT.predict(image_data, conf_threshold=0.25)
|
||||
> ```
|
||||
>
|
||||
> 模型预测结口,输入图像直接输出检测结果。
|
||||
>
|
||||
> **参数**
|
||||
>
|
||||
> > * **image_data**(np.ndarray): 输入数据,注意需为HWC,BGR格式
|
||||
> > * **conf_threshold**(float): 检测框置信度过滤阈值,但由于YOLOv7 End2End的模型在导出成ONNX时已经指定了score阈值,因此该参数只有在大于已经指定的阈值时才会有效。
|
||||
|
||||
> **返回**
|
||||
>
|
||||
> > 返回`fastdeploy.vision.DetectionResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
|
||||
|
||||
### 类成员属性
|
||||
#### 预处理参数
|
||||
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
|
||||
|
||||
> > * **size**(list[int]): 通过此参数修改预处理过程中resize的大小,包含两个整型元素,表示[width, height], 默认值为[640, 640]
|
||||
> > * **padding_value**(list[float]): 通过此参数可以修改图片在resize时候做填充(padding)的值, 包含三个浮点型元素, 分别表示三个通道的值, 默认值为[114, 114, 114]
|
||||
> > * **is_no_pad**(bool): 通过此参数让图片是否通过填充的方式进行resize, `is_no_pad=True` 表示不使用填充的方式,默认值为`is_no_pad=False`
|
||||
> > * **is_mini_pad**(bool): 通过此参数可以将resize之后图像的宽高这是为最接近`size`成员变量的值, 并且满足填充的像素大小是可以被`stride`成员变量整除的。默认值为`is_mini_pad=False`
|
||||
> > * **stride**(int): 配合`stris_mini_padide`成员变量使用, 默认值为`stride=32`
|
||||
|
||||
|
||||
|
||||
## 其它文档
|
||||
|
||||
- [YOLOv7End2EndTRT 模型介绍](..)
|
||||
- [YOLOv7End2EndTRT C++部署](../cpp)
|
||||
- [模型预测结果说明](../../../../../docs/api/vision_results/)
|
53
examples/vision/detection/yolov7end2end_trt/python/infer.py
Normal file
53
examples/vision/detection/yolov7end2end_trt/python/infer.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import fastdeploy as fd
|
||||
import cv2
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
import argparse
|
||||
import ast
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model", required=True, help="Path of yolov7 end2end onnx model.")
|
||||
parser.add_argument(
|
||||
"--image", required=True, help="Path of test image file.")
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default='cpu',
|
||||
help="Type of inference device, support 'cpu' or 'gpu'.")
|
||||
parser.add_argument(
|
||||
"--use_trt",
|
||||
type=ast.literal_eval,
|
||||
default=False,
|
||||
help="Wether to use tensorrt.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_option(args):
|
||||
option = fd.RuntimeOption()
|
||||
|
||||
if args.device.lower() == "gpu":
|
||||
option.use_gpu()
|
||||
|
||||
if args.use_trt:
|
||||
option.use_trt_backend()
|
||||
option.set_trt_input_shape("images", [1, 3, 640, 640])
|
||||
return option
|
||||
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
# 配置runtime,加载模型
|
||||
runtime_option = build_option(args)
|
||||
model = fd.vision.detection.YOLOv7End2EndTRT(
|
||||
args.model, runtime_option=runtime_option)
|
||||
|
||||
# 预测图片检测结果
|
||||
im = cv2.imread(args.image)
|
||||
result = model.predict(im.copy())
|
||||
print(result)
|
||||
|
||||
# 预测结果可视化
|
||||
vis_im = fd.vision.vis_detection(im, result)
|
||||
cv2.imwrite("visualized_result.jpg", vis_im)
|
||||
print("Visualized result save in ./visualized_result.jpg")
|
@@ -21,5 +21,6 @@ from .contrib.yolox import YOLOX
|
||||
from .contrib.yolov5 import YOLOv5
|
||||
from .contrib.yolov5lite import YOLOv5Lite
|
||||
from .contrib.yolov6 import YOLOv6
|
||||
from .contrib.yolov7end2end_trt import YOLOv7End2EndTRT
|
||||
from .contrib.yolov7end2end_ort import YOLOv7End2EndORT
|
||||
from .ppdet import PPYOLOE, PPYOLO, PPYOLOv2, PaddleYOLOX, PicoDet, FasterRCNN, YOLOv3
|
||||
|
105
fastdeploy/vision/detection/contrib/yolov7end2end_trt.py
Normal file
105
fastdeploy/vision/detection/contrib/yolov7end2end_trt.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
import logging
|
||||
from .... import FastDeployModel, Frontend
|
||||
from .... import c_lib_wrap as C
|
||||
|
||||
|
||||
class YOLOv7End2EndTRT(FastDeployModel):
|
||||
def __init__(self,
|
||||
model_file,
|
||||
params_file="",
|
||||
runtime_option=None,
|
||||
model_format=Frontend.ONNX):
|
||||
# 调用基函数进行backend_option的初始化
|
||||
# 初始化后的option保存在self._runtime_option
|
||||
super(YOLOv7End2EndTRT, self).__init__(runtime_option)
|
||||
|
||||
self._model = C.vision.detection.YOLOv7End2EndTRT(
|
||||
model_file, params_file, self._runtime_option, model_format)
|
||||
# 通过self.initialized判断整个模型的初始化是否成功
|
||||
assert self.initialized, "YOLOv7End2EndTRT initialize failed."
|
||||
|
||||
def predict(self, input_image, conf_threshold=0.25):
|
||||
return self._model.predict(input_image, conf_threshold)
|
||||
|
||||
# 一些跟模型有关的属性封装
|
||||
# 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持)
|
||||
@property
|
||||
def size(self):
|
||||
return self._model.size
|
||||
|
||||
@property
|
||||
def padding_value(self):
|
||||
return self._model.padding_value
|
||||
|
||||
@property
|
||||
def is_no_pad(self):
|
||||
return self._model.is_no_pad
|
||||
|
||||
@property
|
||||
def is_mini_pad(self):
|
||||
return self._model.is_mini_pad
|
||||
|
||||
@property
|
||||
def is_scale_up(self):
|
||||
return self._model.is_scale_up
|
||||
|
||||
@property
|
||||
def stride(self):
|
||||
return self._model.stride
|
||||
|
||||
@size.setter
|
||||
def size(self, wh):
|
||||
assert isinstance(wh, (list, tuple)),\
|
||||
"The value to set `size` must be type of tuple or list."
|
||||
assert len(wh) == 2,\
|
||||
"The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
|
||||
len(wh))
|
||||
self._model.size = wh
|
||||
|
||||
@padding_value.setter
|
||||
def padding_value(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
list), "The value to set `padding_value` must be type of list."
|
||||
self._model.padding_value = value
|
||||
|
||||
@is_no_pad.setter
|
||||
def is_no_pad(self, value):
|
||||
assert isinstance(
|
||||
value, bool), "The value to set `is_no_pad` must be type of bool."
|
||||
self._model.is_no_pad = value
|
||||
|
||||
@is_mini_pad.setter
|
||||
def is_mini_pad(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
bool), "The value to set `is_mini_pad` must be type of bool."
|
||||
self._model.is_mini_pad = value
|
||||
|
||||
@is_scale_up.setter
|
||||
def is_scale_up(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
bool), "The value to set `is_scale_up` must be type of bool."
|
||||
self._model.is_scale_up = value
|
||||
|
||||
@stride.setter
|
||||
def stride(self, value):
|
||||
assert isinstance(
|
||||
value, int), "The value to set `stride` must be type of int."
|
||||
self._model.stride = value
|
Reference in New Issue
Block a user