[Model] Yolov5/v5lite/v6/v7/v7end2end: CUDA preprocessing (#370)

* add yolo cuda preprocessing

* cmake build cuda src

* yolov5 support cuda preprocessing

* yolov5 cuda preprocessing configurable

* yolov5 update get mat data api

* yolov5 check cuda preprocess args

* refactor cuda function name

* yolo cuda preprocess padding value configurable

* yolov5 release cuda memory

* cuda preprocess pybind api update

* move use_cuda_preprocessing option to yolov5 model

* yolov5lite cuda preprocessing

* yolov6 cuda preprocessing

* yolov7 cuda preprocessing

* yolov7_e2e cuda preprocessing

* remove cuda preprocessing in runtime option

* refine log and cmake variable name

* fix model runtime ptr type
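The net effect of the commits above is a per-model UseCudaPreprocessing(max_image_size) switch (see the diff below). A minimal C++ usage sketch follows; UseCudaPreprocessing and Predict appear in this diff, while the header path, the fastdeploy::vision::detection::YOLOv5 class path, the constructor, and RuntimeOption::UseGpu are assumptions based on the usual FastDeploy detection examples:

#include <iostream>
#include <opencv2/opencv.hpp>
#include "fastdeploy/vision.h"  // assumed umbrella header

int main() {
  fastdeploy::RuntimeOption option;
  option.UseGpu(0);  // CUDA preprocessing is only useful with a GPU runtime
  // Class path assumed; UseCudaPreprocessing() is the switch added by this commit.
  auto model = fastdeploy::vision::detection::YOLOv5("yolov5s.onnx", "", option);
  // max_image_size caps the input image area in pixels; the commit allocates
  // max_image_size * 3 bytes of pinned host and device memory for the raw image.
  model.UseCudaPreprocessing(1920 * 1080);
  cv::Mat im = cv::imread("test.jpg");
  fastdeploy::vision::DetectionResult res;
  if (!model.Predict(&im, &res)) {
    std::cerr << "Failed to predict." << std::endl;
    return -1;
  }
  std::cout << res.Str() << std::endl;
  return 0;
}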

Co-authored-by: Jason <jiangjiajun@baidu.com>
Author: Wang Xinyu
Date: 2022-10-19 16:04:58 +08:00
Committed by: GitHub
Parent: 4b3e93223f
Commit: c8d6c8244e
26 changed files with 752 additions and 24 deletions


@@ -16,6 +16,9 @@
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
#ifdef ENABLE_CUDA_PREPROCESS
#include "fastdeploy/vision/utils/cuda_utils.h"
#endif // ENABLE_CUDA_PREPROCESS
namespace fastdeploy {
namespace vision {
@@ -104,9 +107,20 @@ bool YOLOv5::Initialize() {
// if (!is_dynamic_input_) {
// is_mini_pad_ = false;
// }
return true;
}
YOLOv5::~YOLOv5() {
#ifdef ENABLE_CUDA_PREPROCESS
if (use_cuda_preprocessing_) {
CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
}
#endif // ENABLE_CUDA_PREPROCESS
}
bool YOLOv5::Preprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info,
const std::vector<int>& size,
@@ -156,6 +170,69 @@ bool YOLOv5::Preprocess(Mat* mat, FDTensor* output,
return true;
}
void YOLOv5::UseCudaPreprocessing(int max_image_size) {
#ifdef ENABLE_CUDA_PREPROCESS
use_cuda_preprocessing_ = true;
is_scale_up_ = true;
if (input_img_cuda_buffer_host_ == nullptr) {
// prepare input image cache in pinned (page-locked) host memory
CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3));
// prepare input data cache in GPU device memory
CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size_[0] * size_[1] * sizeof(float)));
}
#else
FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
<< std::endl;
use_cuda_preprocessing_ = false;
#endif
}
bool YOLOv5::CudaPreprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info,
const std::vector<int>& size,
const std::vector<float> padding_value,
bool is_mini_pad, bool is_no_pad, bool is_scale_up,
int stride, float max_wh, bool multi_label) {
#ifdef ENABLE_CUDA_PREPROCESS
if (is_mini_pad || is_no_pad || !is_scale_up) {
FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl;
return false;
}
// Record the shape of image and the shape of preprocessed image
(*im_info)["input_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};
cudaStream_t stream;
CUDA_CHECK(cudaStreamCreate(&stream));
int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
input_img_cuda_buffer_host_,
src_img_buf_size, cudaMemcpyHostToDevice, stream));
utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
mat->Height(), input_tensor_cuda_buffer_device_,
size[0], size[1], padding_value, stream);
cudaStreamSynchronize(stream);
cudaStreamDestroy(stream);
// Record output shape of preprocessed image
(*im_info)["output_shape"] = {static_cast<float>(size[0]), static_cast<float>(size[1])};
output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
input_tensor_cuda_buffer_device_);
output->device = Device::GPU;
output->shape.insert(output->shape.begin(), 1); // reshape to n, c, h, w
return true;
#else
FDERROR << "CUDA src code was not enabled." << std::endl;
return false;
#endif // ENABLE_CUDA_PREPROCESS
}
bool YOLOv5::Postprocess(
std::vector<FDTensor>& infer_results, DetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
@@ -262,11 +339,20 @@ bool YOLOv5::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
std::map<std::string, std::array<float, 2>> im_info;
if (use_cuda_preprocessing_) {
if (!CudaPreprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_,
is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_,
multi_label_)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
} else {
if (!Preprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_,
is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_,
multi_label_)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
}
input_tensors[0].name = InputInfoOfRuntime(0).name;