Files
FastDeploy/fastdeploy/vision/detection/contrib/yolov7.h
Wang Xinyu c8d6c8244e [Model] Yolov5/v5lite/v6/v7/v7end2end: CUDA preprocessing (#370)
* add yolo cuda preprocessing

* cmake build cuda src

* yolov5 support cuda preprocessing

* yolov5 cuda preprocessing configurable

* yolov5 update get mat data api

* yolov5 check cuda preprocess args

* refactor cuda function name

* yolo cuda preprocess padding value configurable

* yolov5 release cuda memory

* cuda preprocess pybind api update

* move use_cuda_preprocessing option to yolov5 model

* yolov5lite cuda preprocessing

* yolov6 cuda preprocessing

* yolov7 cuda preprocessing

* yolov7_e2e cuda preprocessing

* remove cuda preprocessing in runtime option

* refine log and cmake variable name

* fix model runtime ptr type

Co-authored-by: Jason <jiangjiajun@baidu.com>
2022-10-19 16:04:58 +08:00

93 lines
3.4 KiB
C++

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace detection {
// YOLOv7 detection model wrapper built on top of FastDeployModel.
//
// The public data members are pre/post-processing settings that are publicly
// writable. Every scalar member carries an in-class initializer so that no
// member is ever read with an indeterminate value; the out-of-line
// constructor / Initialize() (not visible in this header) may overwrite them
// with the effective defaults.
class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
 public:
  // model_file:    name of the model file to load.
  // params_file:   name of a separate parameters file; empty by default.
  // custom_option: runtime configuration; default-constructed by default.
  // model_format:  format of model_file; defaults to ModelFormat::ONNX.
  YOLOv7(const std::string& model_file, const std::string& params_file = "",
         const RuntimeOption& custom_option = RuntimeOption(),
         const ModelFormat& model_format = ModelFormat::ONNX);

  // NOTE(review): presumably releases the CUDA buffers declared below --
  // confirm in the implementation file.
  ~YOLOv7();

  // Returns the fixed model name "yolov7".
  virtual std::string ModelName() const { return "yolov7"; }

  // Runs detection on `im` and stores the detections in `result`.
  // conf_threshold:    confidence threshold, 0.25 by default.
  // nms_iou_threshold: IoU threshold used by NMS, 0.5 by default.
  // NOTE(review): the bool result is presumably true on success -- confirm
  // against the implementation.
  virtual bool Predict(cv::Mat* im, DetectionResult* result,
                       float conf_threshold = 0.25f,
                       float nms_iou_threshold = 0.5f);

  // Switches preprocessing to the CUDA path.
  // max_img_size defaults to 3840 * 2160.
  // NOTE(review): max_img_size presumably bounds the pixel count used to size
  // the CUDA buffers -- confirm in the implementation.
  void UseCudaPreprocessing(int max_img_size = 3840 * 2160);

  // tuple of (width, height)
  std::vector<int> size;
  // padding value, size should be the same as the number of channels
  std::vector<float> padding_value;
  // only pad to the minimum rectangle whose height and width are multiples of
  // stride
  bool is_mini_pad = false;
  // when is_mini_pad = false and is_no_pad = true, the image is resized to
  // the set size
  bool is_no_pad = false;
  // if is_scale_up is false, the input image can only be zoomed out; the
  // maximum resize scale cannot exceed 1.0
  bool is_scale_up = false;
  // padding stride, for is_mini_pad (fallback mirrors LetterBox's default)
  int stride = 32;
  // for offsetting the boxes by classes when using NMS
  // (0.0f is only a safe fallback; the effective value is expected to be
  // assigned during initialization)
  float max_wh = 0.0f;

 private:
  // Sets up the model; called from code outside this header.
  bool Initialize();

  // CPU preprocessing: converts `mat` into the input tensor `output` and
  // records per-image information in `im_info`.
  bool Preprocess(Mat* mat, FDTensor* output,
                  std::map<std::string, std::array<float, 2>>* im_info);

  // CUDA counterpart of Preprocess() with the same parameters.
  bool CudaPreprocess(Mat* mat, FDTensor* output,
                      std::map<std::string, std::array<float, 2>>* im_info);

  // Converts the raw inference output into `result`, using the information
  // recorded in `im_info` and the given confidence / NMS IoU thresholds.
  bool Postprocess(FDTensor& infer_result, DetectionResult* result,
                   const std::map<std::string, std::array<float, 2>>& im_info,
                   float conf_threshold, float nms_iou_threshold);

  // Letterbox-style resize/pad of `mat` to `size`, padding with `color`.
  // NOTE(review): `_auto`, `scale_fill` and `scale_up` presumably correspond
  // to is_mini_pad, is_no_pad and is_scale_up -- confirm at the call site.
  void LetterBox(Mat* mat, const std::vector<int>& size,
                 const std::vector<float>& color, bool _auto,
                 bool scale_fill = false, bool scale_up = true,
                 int stride = 32);

  // whether to run inference with dynamic shape (e.g. ONNX exported with
  // dynamic shape or not).
  // When the input shape is not dynamic, is_mini_pad is forced to 'false'.
  // This value is checked automatically by fastdeploy after the internal
  // Runtime has been initialized.
  bool is_dynamic_input_ = false;

  // NOTE(review): the three buffers below are raw pointers. If the destructor
  // frees them, copying a YOLOv7 would free them twice -- confirm that copying
  // is already disabled (e.g. by the base class) or delete the copy
  // operations.
  // CUDA host buffer for input image
  uint8_t* input_img_cuda_buffer_host_ = nullptr;
  // CUDA device buffer for input image
  uint8_t* input_img_cuda_buffer_device_ = nullptr;
  // CUDA device buffer for TRT input tensor
  float* input_tensor_cuda_buffer_device_ = nullptr;
  // Whether to use CUDA preprocessing
  bool use_cuda_preprocessing_ = false;
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy