mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Model] Yolov5/v5lite/v6/v7/v7end2end: CUDA preprocessing (#370)
* add yolo cuda preprocessing * cmake build cuda src * yolov5 support cuda preprocessing * yolov5 cuda preprocessing configurable * yolov5 update get mat data api * yolov5 check cuda preprocess args * refactor cuda function name * yolo cuda preprocess padding value configurable * yolov5 release cuda memory * cuda preprocess pybind api update * move use_cuda_preprocessing option to yolov5 model * yolov5lite cuda preprocessing * yolov6 cuda preprocessing * yolov7 cuda preprocessing * yolov7_e2e cuda preprocessing * remove cuda preprocessing in runtime option * refine log and cmake variable name * fix model runtime ptr type Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
@@ -92,6 +92,16 @@ if(BUILD_ON_JETSON)
|
||||
set(ENABLE_ORT_BACKEND ON)
|
||||
endif()
|
||||
|
||||
# Whether to build CUDA source files in fastdeploy
|
||||
# CUDA source files include CUDA preprocessing, TRT plugins, etc.
|
||||
if(WITH_GPU AND UNIX)
|
||||
set(BUILD_CUDA_SRC ON)
|
||||
enable_language(CUDA)
|
||||
set(CUDA_PROPAGATE_HOST_FLAGS FALSE)
|
||||
else()
|
||||
set(BUILD_CUDA_SRC OFF)
|
||||
endif()
|
||||
|
||||
# config GIT_URL with github mirrors to speed up dependent repos clone
|
||||
option(GIT_URL "Git URL to clone dependent repos" ${GIT_URL})
|
||||
if(NOT GIT_URL)
|
||||
@@ -174,6 +184,7 @@ file(GLOB_RECURSE DEPLOY_TRT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastde
|
||||
file(GLOB_RECURSE DEPLOY_OPENVINO_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/openvino/*.cc)
|
||||
file(GLOB_RECURSE DEPLOY_LITE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/lite/*.cc)
|
||||
file(GLOB_RECURSE DEPLOY_VISION_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cc)
|
||||
file(GLOB_RECURSE DEPLOY_VISION_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cu)
|
||||
file(GLOB_RECURSE DEPLOY_TEXT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/text/*.cc)
|
||||
file(GLOB_RECURSE DEPLOY_PYBIND_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/*.cc ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*_pybind.cc)
|
||||
list(REMOVE_ITEM ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS} ${DEPLOY_PADDLE_SRCS} ${DEPLOY_POROS_SRCS} ${DEPLOY_TRT_SRCS} ${DEPLOY_OPENVINO_SRCS} ${DEPLOY_LITE_SRCS} ${DEPLOY_VISION_SRCS} ${DEPLOY_TEXT_SRCS})
|
||||
@@ -373,6 +384,10 @@ if(ENABLE_VISION)
|
||||
endif()
|
||||
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp)
|
||||
list(APPEND DEPEND_LIBS yaml-cpp)
|
||||
if(BUILD_CUDA_SRC)
|
||||
add_definitions(-DENABLE_CUDA_PREPROCESS)
|
||||
list(APPEND DEPLOY_VISION_SRCS ${DEPLOY_VISION_CUDA_SRCS})
|
||||
endif()
|
||||
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_VISION_SRCS})
|
||||
include_directories(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp/include)
|
||||
include(${PROJECT_SOURCE_DIR}/cmake/opencv.cmake)
|
||||
@@ -428,7 +443,13 @@ elseif(ANDROID)
|
||||
set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS_MINSIZEREL ${COMMON_LINK_FLAGS_REL})
|
||||
elseif(MSVC)
|
||||
else()
|
||||
set_target_properties(${LIBRARY_NAME} PROPERTIES COMPILE_FLAGS "-fvisibility=hidden")
|
||||
if(BUILD_CUDA_SRC)
|
||||
set_target_properties(${LIBRARY_NAME} PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
|
||||
set_target_properties(${LIBRARY_NAME} PROPERTIES INTERFACE_COMPILE_OPTIONS
|
||||
"$<$<BUILD_INTERFACE:$<COMPILE_LANGUAGE:CXX>>:-fvisibility=hidden>$<$<BUILD_INTERFACE:$<COMPILE_LANGUAGE:CUDA>>:-Xcompiler=-fvisibility=hidden>")
|
||||
else()
|
||||
set_target_properties(${LIBRARY_NAME} PROPERTIES COMPILE_FLAGS "-fvisibility=hidden")
|
||||
endif()
|
||||
set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS "-Wl,--exclude-libs,ALL")
|
||||
set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS_RELEASE -s)
|
||||
endif()
|
||||
|
@@ -51,6 +51,7 @@ function(fastdeploy_summary)
|
||||
message(STATUS " WITH_GPU : ${WITH_GPU}")
|
||||
message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}")
|
||||
message(STATUS " TRT_DRECTORY : ${TRT_DIRECTORY}")
|
||||
message(STATUS " BUILD_CUDA_SRC : ${BUILD_CUDA_SRC}")
|
||||
endif()
|
||||
message(STATUS " ENABLE_VISION : ${ENABLE_VISION}")
|
||||
message(STATUS " ENABLE_TEXT : ${ENABLE_TEXT}")
|
||||
|
@@ -57,7 +57,7 @@ bool FastDeployModel::InitRuntime() {
|
||||
}
|
||||
|
||||
if (is_supported) {
|
||||
runtime_ = std::unique_ptr<Runtime>(new Runtime());
|
||||
runtime_ = std::shared_ptr<Runtime>(new Runtime());
|
||||
if (!runtime_->Init(runtime_option)) {
|
||||
return false;
|
||||
}
|
||||
@@ -107,7 +107,7 @@ bool FastDeployModel::CreateCpuBackend() {
|
||||
continue;
|
||||
}
|
||||
runtime_option.backend = valid_cpu_backends[i];
|
||||
runtime_ = std::unique_ptr<Runtime>(new Runtime());
|
||||
runtime_ = std::shared_ptr<Runtime>(new Runtime());
|
||||
if (!runtime_->Init(runtime_option)) {
|
||||
return false;
|
||||
}
|
||||
@@ -130,7 +130,7 @@ bool FastDeployModel::CreateGpuBackend() {
|
||||
continue;
|
||||
}
|
||||
runtime_option.backend = valid_gpu_backends[i];
|
||||
runtime_ = std::unique_ptr<Runtime>(new Runtime());
|
||||
runtime_ = std::shared_ptr<Runtime>(new Runtime());
|
||||
if (!runtime_->Init(runtime_option)) {
|
||||
return false;
|
||||
}
|
||||
|
@@ -99,7 +99,7 @@ class FASTDEPLOY_DECL FastDeployModel {
|
||||
std::vector<Backend> valid_external_backends;
|
||||
|
||||
private:
|
||||
std::unique_ptr<Runtime> runtime_;
|
||||
std::shared_ptr<Runtime> runtime_;
|
||||
bool runtime_initialized_ = false;
|
||||
// whether to record inference time
|
||||
bool enable_record_time_of_runtime_ = false;
|
||||
|
@@ -16,6 +16,9 @@
|
||||
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
#include "fastdeploy/vision/utils/cuda_utils.h"
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
@@ -104,9 +107,20 @@ bool YOLOv5::Initialize() {
|
||||
// if (!is_dynamic_input_) {
|
||||
// is_mini_pad_ = false;
|
||||
// }
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
YOLOv5::~YOLOv5() {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
if (use_cuda_preprocessing_) {
|
||||
CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
|
||||
CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
|
||||
CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
|
||||
}
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
}
|
||||
|
||||
bool YOLOv5::Preprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info,
|
||||
const std::vector<int>& size,
|
||||
@@ -156,6 +170,69 @@ bool YOLOv5::Preprocess(Mat* mat, FDTensor* output,
|
||||
return true;
|
||||
}
|
||||
|
||||
void YOLOv5::UseCudaPreprocessing(int max_image_size) {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
use_cuda_preprocessing_ = true;
|
||||
is_scale_up_ = true;
|
||||
if (input_img_cuda_buffer_host_ == nullptr) {
|
||||
// prepare input data cache in GPU pinned memory
|
||||
CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3));
|
||||
// prepare input data cache in GPU device memory
|
||||
CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
|
||||
CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size_[0] * size_[1] * sizeof(float)));
|
||||
}
|
||||
#else
|
||||
FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
|
||||
<< std::endl;
|
||||
use_cuda_preprocessing_ = false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool YOLOv5::CudaPreprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info,
|
||||
const std::vector<int>& size,
|
||||
const std::vector<float> padding_value,
|
||||
bool is_mini_pad, bool is_no_pad, bool is_scale_up,
|
||||
int stride, float max_wh, bool multi_label) {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) {
|
||||
FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Record the shape of image and the shape of preprocessed image
|
||||
(*im_info)["input_shape"] = {static_cast<float>(mat->Height()),
|
||||
static_cast<float>(mat->Width())};
|
||||
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
|
||||
static_cast<float>(mat->Width())};
|
||||
|
||||
cudaStream_t stream;
|
||||
CUDA_CHECK(cudaStreamCreate(&stream));
|
||||
int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
|
||||
memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
|
||||
CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
|
||||
input_img_cuda_buffer_host_,
|
||||
src_img_buf_size, cudaMemcpyHostToDevice, stream));
|
||||
utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
|
||||
mat->Height(), input_tensor_cuda_buffer_device_,
|
||||
size[0], size[1], padding_value, stream);
|
||||
cudaStreamSynchronize(stream);
|
||||
cudaStreamDestroy(stream);
|
||||
|
||||
// Record output shape of preprocessed image
|
||||
(*im_info)["output_shape"] = {static_cast<float>(size[0]), static_cast<float>(size[1])};
|
||||
|
||||
output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
|
||||
input_tensor_cuda_buffer_device_);
|
||||
output->device = Device::GPU;
|
||||
output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c
|
||||
return true;
|
||||
#else
|
||||
FDERROR << "CUDA src code was not enabled." << std::endl;
|
||||
return false;
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
}
|
||||
|
||||
bool YOLOv5::Postprocess(
|
||||
std::vector<FDTensor>& infer_results, DetectionResult* result,
|
||||
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||
@@ -262,11 +339,20 @@ bool YOLOv5::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
||||
|
||||
std::map<std::string, std::array<float, 2>> im_info;
|
||||
|
||||
if (!Preprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_,
|
||||
is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_,
|
||||
multi_label_)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
if (use_cuda_preprocessing_) {
|
||||
if (!CudaPreprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_,
|
||||
is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_,
|
||||
multi_label_)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!Preprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_,
|
||||
is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_,
|
||||
multi_label_)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||
|
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fastdeploy/fastdeploy_model.h"
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
@@ -27,6 +28,8 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel {
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const ModelFormat& model_format = ModelFormat::ONNX);
|
||||
|
||||
~YOLOv5();
|
||||
|
||||
std::string ModelName() const { return "yolov5"; }
|
||||
|
||||
virtual bool Predict(cv::Mat* im, DetectionResult* result,
|
||||
@@ -42,6 +45,17 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel {
|
||||
bool is_scale_up = false, int stride = 32,
|
||||
float max_wh = 7680.0, bool multi_label = true);
|
||||
|
||||
void UseCudaPreprocessing(int max_img_size = 3840 * 2160);
|
||||
|
||||
bool CudaPreprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info,
|
||||
const std::vector<int>& size = {640, 640},
|
||||
const std::vector<float> padding_value = {114.0, 114.0,
|
||||
114.0},
|
||||
bool is_mini_pad = false, bool is_no_pad = false,
|
||||
bool is_scale_up = false, int stride = 32,
|
||||
float max_wh = 7680.0, bool multi_label = true);
|
||||
|
||||
static bool Postprocess(
|
||||
std::vector<FDTensor>& infer_results, DetectionResult* result,
|
||||
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||
@@ -85,6 +99,14 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel {
|
||||
// value will
|
||||
// auto check by fastdeploy after the internal Runtime already initialized.
|
||||
bool is_dynamic_input_;
|
||||
// CUDA host buffer for input image
|
||||
uint8_t* input_img_cuda_buffer_host_ = nullptr;
|
||||
// CUDA device buffer for input image
|
||||
uint8_t* input_img_cuda_buffer_device_ = nullptr;
|
||||
// CUDA device buffer for TRT input tensor
|
||||
float* input_tensor_cuda_buffer_device_ = nullptr;
|
||||
// Whether to use CUDA preprocessing
|
||||
bool use_cuda_preprocessing_ = false;
|
||||
};
|
||||
|
||||
} // namespace detection
|
||||
|
@@ -27,6 +27,10 @@ void BindYOLOv5(pybind11::module& m) {
|
||||
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
||||
return res;
|
||||
})
|
||||
.def("use_cuda_preprocessing",
|
||||
[](vision::detection::YOLOv5& self, int max_image_size) {
|
||||
self.UseCudaPreprocessing(max_image_size);
|
||||
})
|
||||
.def_static("preprocess",
|
||||
[](pybind11::array& data, const std::vector<int>& size,
|
||||
const std::vector<float> padding_value, bool is_mini_pad,
|
||||
|
@@ -15,6 +15,9 @@
|
||||
#include "fastdeploy/vision/detection/contrib/yolov5lite.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
#include "fastdeploy/vision/utils/cuda_utils.h"
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
@@ -136,6 +139,16 @@ bool YOLOv5Lite::Initialize() {
|
||||
return true;
|
||||
}
|
||||
|
||||
YOLOv5Lite::~YOLOv5Lite() {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
if (use_cuda_preprocessing_) {
|
||||
CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
|
||||
CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
|
||||
CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
|
||||
}
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
}
|
||||
|
||||
bool YOLOv5Lite::Preprocess(
|
||||
Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info) {
|
||||
@@ -176,6 +189,65 @@ bool YOLOv5Lite::Preprocess(
|
||||
return true;
|
||||
}
|
||||
|
||||
void YOLOv5Lite::UseCudaPreprocessing(int max_image_size) {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
use_cuda_preprocessing_ = true;
|
||||
is_scale_up = true;
|
||||
if (input_img_cuda_buffer_host_ == nullptr) {
|
||||
// prepare input data cache in GPU pinned memory
|
||||
CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3));
|
||||
// prepare input data cache in GPU device memory
|
||||
CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
|
||||
CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size[0] * size[1] * sizeof(float)));
|
||||
}
|
||||
#else
|
||||
FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
|
||||
<< std::endl;
|
||||
use_cuda_preprocessing_ = false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool YOLOv5Lite::CudaPreprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info) {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) {
|
||||
FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Record the shape of image and the shape of preprocessed image
|
||||
(*im_info)["input_shape"] = {static_cast<float>(mat->Height()),
|
||||
static_cast<float>(mat->Width())};
|
||||
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
|
||||
static_cast<float>(mat->Width())};
|
||||
|
||||
cudaStream_t stream;
|
||||
CUDA_CHECK(cudaStreamCreate(&stream));
|
||||
int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
|
||||
memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
|
||||
CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
|
||||
input_img_cuda_buffer_host_,
|
||||
src_img_buf_size, cudaMemcpyHostToDevice, stream));
|
||||
utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
|
||||
mat->Height(), input_tensor_cuda_buffer_device_,
|
||||
size[0], size[1], padding_value, stream);
|
||||
cudaStreamSynchronize(stream);
|
||||
cudaStreamDestroy(stream);
|
||||
|
||||
// Record output shape of preprocessed image
|
||||
(*im_info)["output_shape"] = {static_cast<float>(size[0]), static_cast<float>(size[1])};
|
||||
|
||||
output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
|
||||
input_tensor_cuda_buffer_device_);
|
||||
output->device = Device::GPU;
|
||||
output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c
|
||||
return true;
|
||||
#else
|
||||
FDERROR << "CUDA src code was not enabled." << std::endl;
|
||||
return false;
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
}
|
||||
|
||||
bool YOLOv5Lite::PostprocessWithDecode(
|
||||
FDTensor& infer_result, DetectionResult* result,
|
||||
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||
@@ -348,9 +420,16 @@ bool YOLOv5Lite::Predict(cv::Mat* im, DetectionResult* result,
|
||||
im_info["output_shape"] = {static_cast<float>(mat.Height()),
|
||||
static_cast<float>(mat.Width())};
|
||||
|
||||
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
if (use_cuda_preprocessing_) {
|
||||
if (!CudaPreprocess(&mat, &input_tensors[0], &im_info)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||
|
@@ -27,12 +27,16 @@ class FASTDEPLOY_DECL YOLOv5Lite : public FastDeployModel {
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const ModelFormat& model_format = ModelFormat::ONNX);
|
||||
|
||||
~YOLOv5Lite();
|
||||
|
||||
virtual std::string ModelName() const { return "YOLOv5-Lite"; }
|
||||
|
||||
virtual bool Predict(cv::Mat* im, DetectionResult* result,
|
||||
float conf_threshold = 0.45,
|
||||
float nms_iou_threshold = 0.25);
|
||||
|
||||
void UseCudaPreprocessing(int max_img_size = 3840 * 2160);
|
||||
|
||||
// tuple of (width, height)
|
||||
std::vector<int> size;
|
||||
// padding value, size should be same with Channels
|
||||
@@ -79,6 +83,8 @@ class FASTDEPLOY_DECL YOLOv5Lite : public FastDeployModel {
|
||||
bool Preprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info);
|
||||
|
||||
bool CudaPreprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info);
|
||||
|
||||
bool Postprocess(FDTensor& infer_result, DetectionResult* result,
|
||||
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||
@@ -109,6 +115,14 @@ class FASTDEPLOY_DECL YOLOv5Lite : public FastDeployModel {
|
||||
// value will
|
||||
// auto check by fastdeploy after the internal Runtime already initialized.
|
||||
bool is_dynamic_input_;
|
||||
// CUDA host buffer for input image
|
||||
uint8_t* input_img_cuda_buffer_host_ = nullptr;
|
||||
// CUDA device buffer for input image
|
||||
uint8_t* input_img_cuda_buffer_device_ = nullptr;
|
||||
// CUDA device buffer for TRT input tensor
|
||||
float* input_tensor_cuda_buffer_device_ = nullptr;
|
||||
// Whether to use CUDA preprocessing
|
||||
bool use_cuda_preprocessing_ = false;
|
||||
};
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
|
@@ -28,6 +28,10 @@ void BindYOLOv5Lite(pybind11::module& m) {
|
||||
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
||||
return res;
|
||||
})
|
||||
.def("use_cuda_preprocessing",
|
||||
[](vision::detection::YOLOv5Lite& self, int max_image_size) {
|
||||
self.UseCudaPreprocessing(max_image_size);
|
||||
})
|
||||
.def_readwrite("size", &vision::detection::YOLOv5Lite::size)
|
||||
.def_readwrite("padding_value",
|
||||
&vision::detection::YOLOv5Lite::padding_value)
|
||||
|
@@ -15,6 +15,9 @@
|
||||
#include "fastdeploy/vision/detection/contrib/yolov6.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
#include "fastdeploy/vision/utils/cuda_utils.h"
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
@@ -108,6 +111,16 @@ bool YOLOv6::Initialize() {
|
||||
return true;
|
||||
}
|
||||
|
||||
YOLOv6::~YOLOv6() {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
if (use_cuda_preprocessing_) {
|
||||
CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
|
||||
CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
|
||||
CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
|
||||
}
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
}
|
||||
|
||||
bool YOLOv6::Preprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info) {
|
||||
// process after image load
|
||||
@@ -147,6 +160,65 @@ bool YOLOv6::Preprocess(Mat* mat, FDTensor* output,
|
||||
return true;
|
||||
}
|
||||
|
||||
void YOLOv6::UseCudaPreprocessing(int max_image_size) {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
use_cuda_preprocessing_ = true;
|
||||
is_scale_up = true;
|
||||
if (input_img_cuda_buffer_host_ == nullptr) {
|
||||
// prepare input data cache in GPU pinned memory
|
||||
CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3));
|
||||
// prepare input data cache in GPU device memory
|
||||
CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
|
||||
CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size[0] * size[1] * sizeof(float)));
|
||||
}
|
||||
#else
|
||||
FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
|
||||
<< std::endl;
|
||||
use_cuda_preprocessing_ = false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool YOLOv6::CudaPreprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info) {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) {
|
||||
FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Record the shape of image and the shape of preprocessed image
|
||||
(*im_info)["input_shape"] = {static_cast<float>(mat->Height()),
|
||||
static_cast<float>(mat->Width())};
|
||||
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
|
||||
static_cast<float>(mat->Width())};
|
||||
|
||||
cudaStream_t stream;
|
||||
CUDA_CHECK(cudaStreamCreate(&stream));
|
||||
int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
|
||||
memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
|
||||
CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
|
||||
input_img_cuda_buffer_host_,
|
||||
src_img_buf_size, cudaMemcpyHostToDevice, stream));
|
||||
utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
|
||||
mat->Height(), input_tensor_cuda_buffer_device_,
|
||||
size[0], size[1], padding_value, stream);
|
||||
cudaStreamSynchronize(stream);
|
||||
cudaStreamDestroy(stream);
|
||||
|
||||
// Record output shape of preprocessed image
|
||||
(*im_info)["output_shape"] = {static_cast<float>(size[0]), static_cast<float>(size[1])};
|
||||
|
||||
output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
|
||||
input_tensor_cuda_buffer_device_);
|
||||
output->device = Device::GPU;
|
||||
output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c
|
||||
return true;
|
||||
#else
|
||||
FDERROR << "CUDA src code was not enabled." << std::endl;
|
||||
return false;
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
}
|
||||
|
||||
bool YOLOv6::Postprocess(
|
||||
FDTensor& infer_result, DetectionResult* result,
|
||||
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||
@@ -225,9 +297,16 @@ bool YOLOv6::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
||||
im_info["output_shape"] = {static_cast<float>(mat.Height()),
|
||||
static_cast<float>(mat.Width())};
|
||||
|
||||
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
if (use_cuda_preprocessing_) {
|
||||
if (!CudaPreprocess(&mat, &input_tensors[0], &im_info)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||
|
@@ -30,12 +30,16 @@ class FASTDEPLOY_DECL YOLOv6 : public FastDeployModel {
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const ModelFormat& model_format = ModelFormat::ONNX);
|
||||
|
||||
~YOLOv6();
|
||||
|
||||
std::string ModelName() const { return "YOLOv6"; }
|
||||
|
||||
virtual bool Predict(cv::Mat* im, DetectionResult* result,
|
||||
float conf_threshold = 0.25,
|
||||
float nms_iou_threshold = 0.5);
|
||||
|
||||
void UseCudaPreprocessing(int max_img_size = 3840 * 2160);
|
||||
|
||||
// tuple of (width, height)
|
||||
std::vector<int> size;
|
||||
// padding value, size should be same with Channels
|
||||
@@ -60,6 +64,9 @@ class FASTDEPLOY_DECL YOLOv6 : public FastDeployModel {
|
||||
bool Preprocess(Mat* mat, FDTensor* outputs,
|
||||
std::map<std::string, std::array<float, 2>>* im_info);
|
||||
|
||||
bool CudaPreprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info);
|
||||
|
||||
bool Postprocess(FDTensor& infer_result, DetectionResult* result,
|
||||
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||
float conf_threshold, float nms_iou_threshold);
|
||||
@@ -78,6 +85,14 @@ class FASTDEPLOY_DECL YOLOv6 : public FastDeployModel {
|
||||
// value will
|
||||
// auto check by fastdeploy after the internal Runtime already initialized.
|
||||
bool is_dynamic_input_;
|
||||
// CUDA host buffer for input image
|
||||
uint8_t* input_img_cuda_buffer_host_ = nullptr;
|
||||
// CUDA device buffer for input image
|
||||
uint8_t* input_img_cuda_buffer_device_ = nullptr;
|
||||
// CUDA device buffer for TRT input tensor
|
||||
float* input_tensor_cuda_buffer_device_ = nullptr;
|
||||
// Whether to use CUDA preprocessing
|
||||
bool use_cuda_preprocessing_ = false;
|
||||
};
|
||||
|
||||
} // namespace detection
|
||||
|
@@ -27,6 +27,10 @@ void BindYOLOv6(pybind11::module& m) {
|
||||
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
||||
return res;
|
||||
})
|
||||
.def("use_cuda_preprocessing",
|
||||
[](vision::detection::YOLOv6& self, int max_image_size) {
|
||||
self.UseCudaPreprocessing(max_image_size);
|
||||
})
|
||||
.def_readwrite("size", &vision::detection::YOLOv6::size)
|
||||
.def_readwrite("padding_value", &vision::detection::YOLOv6::padding_value)
|
||||
.def_readwrite("is_mini_pad", &vision::detection::YOLOv6::is_mini_pad)
|
||||
|
@@ -15,6 +15,9 @@
|
||||
#include "fastdeploy/vision/detection/contrib/yolov7.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
#include "fastdeploy/vision/utils/cuda_utils.h"
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
@@ -106,6 +109,16 @@ bool YOLOv7::Initialize() {
|
||||
return true;
|
||||
}
|
||||
|
||||
YOLOv7::~YOLOv7() {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
if (use_cuda_preprocessing_) {
|
||||
CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
|
||||
CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
|
||||
CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
|
||||
}
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
}
|
||||
|
||||
bool YOLOv7::Preprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info) {
|
||||
// process after image load
|
||||
@@ -145,6 +158,65 @@ bool YOLOv7::Preprocess(Mat* mat, FDTensor* output,
|
||||
return true;
|
||||
}
|
||||
|
||||
void YOLOv7::UseCudaPreprocessing(int max_image_size) {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
use_cuda_preprocessing_ = true;
|
||||
is_scale_up = true;
|
||||
if (input_img_cuda_buffer_host_ == nullptr) {
|
||||
// prepare input data cache in GPU pinned memory
|
||||
CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3));
|
||||
// prepare input data cache in GPU device memory
|
||||
CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
|
||||
CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size[0] * size[1] * sizeof(float)));
|
||||
}
|
||||
#else
|
||||
FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
|
||||
<< std::endl;
|
||||
use_cuda_preprocessing_ = false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool YOLOv7::CudaPreprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info) {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) {
|
||||
FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Record the shape of image and the shape of preprocessed image
|
||||
(*im_info)["input_shape"] = {static_cast<float>(mat->Height()),
|
||||
static_cast<float>(mat->Width())};
|
||||
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
|
||||
static_cast<float>(mat->Width())};
|
||||
|
||||
cudaStream_t stream;
|
||||
CUDA_CHECK(cudaStreamCreate(&stream));
|
||||
int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
|
||||
memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
|
||||
CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
|
||||
input_img_cuda_buffer_host_,
|
||||
src_img_buf_size, cudaMemcpyHostToDevice, stream));
|
||||
utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
|
||||
mat->Height(), input_tensor_cuda_buffer_device_,
|
||||
size[0], size[1], padding_value, stream);
|
||||
cudaStreamSynchronize(stream);
|
||||
cudaStreamDestroy(stream);
|
||||
|
||||
// Record output shape of preprocessed image
|
||||
(*im_info)["output_shape"] = {static_cast<float>(size[0]), static_cast<float>(size[1])};
|
||||
|
||||
output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
|
||||
input_tensor_cuda_buffer_device_);
|
||||
output->device = Device::GPU;
|
||||
output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c
|
||||
return true;
|
||||
#else
|
||||
FDERROR << "CUDA src code was not enabled." << std::endl;
|
||||
return false;
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
}
|
||||
|
||||
bool YOLOv7::Postprocess(
|
||||
FDTensor& infer_result, DetectionResult* result,
|
||||
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||
@@ -227,9 +299,16 @@ bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,
|
||||
im_info["output_shape"] = {static_cast<float>(mat.Height()),
|
||||
static_cast<float>(mat.Width())};
|
||||
|
||||
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
if (use_cuda_preprocessing_) {
|
||||
if (!CudaPreprocess(&mat, &input_tensors[0], &im_info)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||
|
@@ -27,12 +27,16 @@ class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const ModelFormat& model_format = ModelFormat::ONNX);
|
||||
|
||||
~YOLOv7();
|
||||
|
||||
virtual std::string ModelName() const { return "yolov7"; }
|
||||
|
||||
virtual bool Predict(cv::Mat* im, DetectionResult* result,
|
||||
float conf_threshold = 0.25,
|
||||
float nms_iou_threshold = 0.5);
|
||||
|
||||
void UseCudaPreprocessing(int max_img_size = 3840 * 2160);
|
||||
|
||||
// tuple of (width, height)
|
||||
std::vector<int> size;
|
||||
// padding value, size should be same with Channels
|
||||
@@ -56,6 +60,9 @@ class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
|
||||
bool Preprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info);
|
||||
|
||||
bool CudaPreprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info);
|
||||
|
||||
bool Postprocess(FDTensor& infer_result, DetectionResult* result,
|
||||
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||
float conf_threshold, float nms_iou_threshold);
|
||||
@@ -71,6 +78,14 @@ class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
|
||||
// value will
|
||||
// auto check by fastdeploy after the internal Runtime already initialized.
|
||||
bool is_dynamic_input_;
|
||||
// CUDA host buffer for input image
|
||||
uint8_t* input_img_cuda_buffer_host_ = nullptr;
|
||||
// CUDA device buffer for input image
|
||||
uint8_t* input_img_cuda_buffer_device_ = nullptr;
|
||||
// CUDA device buffer for TRT input tensor
|
||||
float* input_tensor_cuda_buffer_device_ = nullptr;
|
||||
// Whether to use CUDA preprocessing
|
||||
bool use_cuda_preprocessing_ = false;
|
||||
};
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
|
@@ -27,6 +27,10 @@ void BindYOLOv7(pybind11::module& m) {
|
||||
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
|
||||
return res;
|
||||
})
|
||||
.def("use_cuda_preprocessing",
|
||||
[](vision::detection::YOLOv7& self, int max_image_size) {
|
||||
self.UseCudaPreprocessing(max_image_size);
|
||||
})
|
||||
.def_readwrite("size", &vision::detection::YOLOv7::size)
|
||||
.def_readwrite("padding_value", &vision::detection::YOLOv7::padding_value)
|
||||
.def_readwrite("is_mini_pad", &vision::detection::YOLOv7::is_mini_pad)
|
||||
|
@@ -15,6 +15,9 @@
|
||||
#include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
#include "fastdeploy/vision/utils/cuda_utils.h"
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
@@ -119,6 +122,16 @@ bool YOLOv7End2EndTRT::Initialize() {
|
||||
return true;
|
||||
}
|
||||
|
||||
YOLOv7End2EndTRT::~YOLOv7End2EndTRT() {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
if (use_cuda_preprocessing_) {
|
||||
CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
|
||||
CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
|
||||
CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
|
||||
}
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
}
|
||||
|
||||
bool YOLOv7End2EndTRT::Preprocess(
|
||||
Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info) {
|
||||
@@ -150,6 +163,65 @@ bool YOLOv7End2EndTRT::Preprocess(
|
||||
return true;
|
||||
}
|
||||
|
||||
void YOLOv7End2EndTRT::UseCudaPreprocessing(int max_image_size) {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
use_cuda_preprocessing_ = true;
|
||||
is_scale_up = true;
|
||||
if (input_img_cuda_buffer_host_ == nullptr) {
|
||||
// prepare input data cache in GPU pinned memory
|
||||
CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3));
|
||||
// prepare input data cache in GPU device memory
|
||||
CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
|
||||
CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size[0] * size[1] * sizeof(float)));
|
||||
}
|
||||
#else
|
||||
FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
|
||||
<< std::endl;
|
||||
use_cuda_preprocessing_ = false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool YOLOv7End2EndTRT::CudaPreprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info) {
|
||||
#ifdef ENABLE_CUDA_PREPROCESS
|
||||
if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) {
|
||||
FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Record the shape of image and the shape of preprocessed image
|
||||
(*im_info)["input_shape"] = {static_cast<float>(mat->Height()),
|
||||
static_cast<float>(mat->Width())};
|
||||
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
|
||||
static_cast<float>(mat->Width())};
|
||||
|
||||
cudaStream_t stream;
|
||||
CUDA_CHECK(cudaStreamCreate(&stream));
|
||||
int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
|
||||
memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
|
||||
CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
|
||||
input_img_cuda_buffer_host_,
|
||||
src_img_buf_size, cudaMemcpyHostToDevice, stream));
|
||||
utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
|
||||
mat->Height(), input_tensor_cuda_buffer_device_,
|
||||
size[0], size[1], padding_value, stream);
|
||||
cudaStreamSynchronize(stream);
|
||||
cudaStreamDestroy(stream);
|
||||
|
||||
// Record output shape of preprocessed image
|
||||
(*im_info)["output_shape"] = {static_cast<float>(size[0]), static_cast<float>(size[1])};
|
||||
|
||||
output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
|
||||
input_tensor_cuda_buffer_device_);
|
||||
output->device = Device::GPU;
|
||||
output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c
|
||||
return true;
|
||||
#else
|
||||
FDERROR << "CUDA src code was not enabled." << std::endl;
|
||||
return false;
|
||||
#endif // ENABLE_CUDA_PREPROCESS
|
||||
}
|
||||
|
||||
bool YOLOv7End2EndTRT::Postprocess(
|
||||
std::vector<FDTensor>& infer_results, DetectionResult* result,
|
||||
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||
@@ -242,9 +314,16 @@ bool YOLOv7End2EndTRT::Predict(cv::Mat* im, DetectionResult* result,
|
||||
im_info["output_shape"] = {static_cast<float>(mat.Height()),
|
||||
static_cast<float>(mat.Width())};
|
||||
|
||||
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
if (use_cuda_preprocessing_) {
|
||||
if (!CudaPreprocess(&mat, &input_tensors[0], &im_info)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
|
||||
FDERROR << "Failed to preprocess input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
input_tensors[0].name = InputInfoOfRuntime(0).name;
|
||||
|
@@ -28,11 +28,15 @@ class FASTDEPLOY_DECL YOLOv7End2EndTRT : public FastDeployModel {
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const ModelFormat& model_format = ModelFormat::ONNX);
|
||||
|
||||
~YOLOv7End2EndTRT();
|
||||
|
||||
virtual std::string ModelName() const { return "yolov7end2end_trt"; }
|
||||
|
||||
virtual bool Predict(cv::Mat* im, DetectionResult* result,
|
||||
float conf_threshold = 0.25);
|
||||
|
||||
void UseCudaPreprocessing(int max_img_size = 3840 * 2160);
|
||||
|
||||
// tuple of (width, height)
|
||||
std::vector<int> size;
|
||||
// padding value, size should be same with Channels
|
||||
@@ -54,6 +58,9 @@ class FASTDEPLOY_DECL YOLOv7End2EndTRT : public FastDeployModel {
|
||||
bool Preprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info);
|
||||
|
||||
bool CudaPreprocess(Mat* mat, FDTensor* output,
|
||||
std::map<std::string, std::array<float, 2>>* im_info);
|
||||
|
||||
bool Postprocess(std::vector<FDTensor>& infer_results,
|
||||
DetectionResult* result,
|
||||
const std::map<std::string, std::array<float, 2>>& im_info,
|
||||
@@ -65,6 +72,14 @@ class FASTDEPLOY_DECL YOLOv7End2EndTRT : public FastDeployModel {
|
||||
int stride = 32);
|
||||
|
||||
bool is_dynamic_input_;
|
||||
// CUDA host buffer for input image
|
||||
uint8_t* input_img_cuda_buffer_host_ = nullptr;
|
||||
// CUDA device buffer for input image
|
||||
uint8_t* input_img_cuda_buffer_device_ = nullptr;
|
||||
// CUDA device buffer for TRT input tensor
|
||||
float* input_tensor_cuda_buffer_device_ = nullptr;
|
||||
// Whether to use CUDA preprocessing
|
||||
bool use_cuda_preprocessing_ = false;
|
||||
};
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
|
@@ -28,6 +28,10 @@ void BindYOLOv7End2EndTRT(pybind11::module& m) {
|
||||
self.Predict(&mat, &res, conf_threshold);
|
||||
return res;
|
||||
})
|
||||
.def("use_cuda_preprocessing",
|
||||
[](vision::detection::YOLOv7End2EndTRT& self, int max_image_size) {
|
||||
self.UseCudaPreprocessing(max_image_size);
|
||||
})
|
||||
.def_readwrite("size", &vision::detection::YOLOv7End2EndTRT::size)
|
||||
.def_readwrite("padding_value",
|
||||
&vision::detection::YOLOv7End2EndTRT::padding_value)
|
||||
|
42
fastdeploy/vision/utils/cuda_utils.h
Normal file
42
fastdeploy/vision/utils/cuda_utils.h
Normal file
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#ifndef CUDA_CHECK
|
||||
#define CUDA_CHECK(callstr)\
|
||||
{\
|
||||
cudaError_t error_code = callstr;\
|
||||
if (error_code != cudaSuccess) {\
|
||||
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":";\
|
||||
std::cerr << __LINE__;\
|
||||
assert(0);\
|
||||
}\
|
||||
}
|
||||
#endif // CUDA_CHECK
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace utils {
|
||||
void CudaYoloPreprocess(uint8_t* src, int src_width, int src_height,
|
||||
float* dst, int dst_width, int dst_height,
|
||||
const std::vector<float> padding_value,
|
||||
cudaStream_t stream);
|
||||
} // namespace utils
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
146
fastdeploy/vision/utils/yolo_preprocess.cu
Normal file
146
fastdeploy/vision/utils/yolo_preprocess.cu
Normal file
@@ -0,0 +1,146 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Part of the following code in this file refs to
|
||||
// https://github.com/wang-xinyu/tensorrtx/blob/yolov5-v6.0/yolov5/preprocess.cu
|
||||
//
|
||||
// Copyright (c) 2022 tensorrtx
|
||||
// Licensed under The MIT License
|
||||
// \file preprocess.cu
|
||||
// \brief
|
||||
// \author Qi Liu, Xinyu Wang
|
||||
|
||||
#include "fastdeploy/vision/utils/cuda_utils.h"
|
||||
#include <opencv2/opencv.hpp>
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace utils {
|
||||
|
||||
struct AffineMatrix {
|
||||
float value[6];
|
||||
};
|
||||
|
||||
__global__ void YoloPreprocessCudaKernel(
|
||||
uint8_t* src, int src_line_size, int src_width,
|
||||
int src_height, float* dst, int dst_width,
|
||||
int dst_height, uint8_t padding_color_b,
|
||||
uint8_t padding_color_g, uint8_t padding_color_r,
|
||||
AffineMatrix d2s, int edge) {
|
||||
int position = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (position >= edge) return;
|
||||
|
||||
float m_x1 = d2s.value[0];
|
||||
float m_y1 = d2s.value[1];
|
||||
float m_z1 = d2s.value[2];
|
||||
float m_x2 = d2s.value[3];
|
||||
float m_y2 = d2s.value[4];
|
||||
float m_z2 = d2s.value[5];
|
||||
|
||||
int dx = position % dst_width;
|
||||
int dy = position / dst_width;
|
||||
float src_x = m_x1 * dx + m_y1 * dy + m_z1 + 0.5f;
|
||||
float src_y = m_x2 * dx + m_y2 * dy + m_z2 + 0.5f;
|
||||
float c0, c1, c2;
|
||||
|
||||
if (src_x <= -1 || src_x >= src_width || src_y <= -1 || src_y >= src_height) {
|
||||
// out of range
|
||||
c0 = padding_color_b;
|
||||
c1 = padding_color_g;
|
||||
c2 = padding_color_r;
|
||||
} else {
|
||||
int y_low = floorf(src_y);
|
||||
int x_low = floorf(src_x);
|
||||
int y_high = y_low + 1;
|
||||
int x_high = x_low + 1;
|
||||
|
||||
uint8_t const_value[] = {padding_color_b, padding_color_g, padding_color_r};
|
||||
float ly = src_y - y_low;
|
||||
float lx = src_x - x_low;
|
||||
float hy = 1 - ly;
|
||||
float hx = 1 - lx;
|
||||
float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
||||
uint8_t* v1 = const_value;
|
||||
uint8_t* v2 = const_value;
|
||||
uint8_t* v3 = const_value;
|
||||
uint8_t* v4 = const_value;
|
||||
|
||||
if (y_low >= 0) {
|
||||
if (x_low >= 0) v1 = src + y_low * src_line_size + x_low * 3;
|
||||
if (x_high < src_width) v2 = src + y_low * src_line_size + x_high * 3;
|
||||
}
|
||||
|
||||
if (y_high < src_height) {
|
||||
if (x_low >= 0) v3 = src + y_high * src_line_size + x_low * 3;
|
||||
if (x_high < src_width) v4 = src + y_high * src_line_size + x_high * 3;
|
||||
}
|
||||
|
||||
c0 = w1 * v1[0] + w2 * v2[0] + w3 * v3[0] + w4 * v4[0];
|
||||
c1 = w1 * v1[1] + w2 * v2[1] + w3 * v3[1] + w4 * v4[1];
|
||||
c2 = w1 * v1[2] + w2 * v2[2] + w3 * v3[2] + w4 * v4[2];
|
||||
}
|
||||
|
||||
// bgr to rgb
|
||||
float t = c2;
|
||||
c2 = c0;
|
||||
c0 = t;
|
||||
|
||||
// normalization
|
||||
c0 = c0 / 255.0f;
|
||||
c1 = c1 / 255.0f;
|
||||
c2 = c2 / 255.0f;
|
||||
|
||||
// rgbrgbrgb to rrrgggbbb
|
||||
int area = dst_width * dst_height;
|
||||
float* pdst_c0 = dst + dy * dst_width + dx;
|
||||
float* pdst_c1 = pdst_c0 + area;
|
||||
float* pdst_c2 = pdst_c1 + area;
|
||||
*pdst_c0 = c0;
|
||||
*pdst_c1 = c1;
|
||||
*pdst_c2 = c2;
|
||||
}
|
||||
|
||||
void CudaYoloPreprocess(
|
||||
uint8_t* src, int src_width, int src_height,
|
||||
float* dst, int dst_width, int dst_height,
|
||||
const std::vector<float> padding_value, cudaStream_t stream) {
|
||||
AffineMatrix s2d, d2s;
|
||||
float scale = std::min(dst_height / (float)src_height, dst_width / (float)src_width);
|
||||
|
||||
s2d.value[0] = scale;
|
||||
s2d.value[1] = 0;
|
||||
s2d.value[2] = -scale * src_width * 0.5 + dst_width * 0.5;
|
||||
s2d.value[3] = 0;
|
||||
s2d.value[4] = scale;
|
||||
s2d.value[5] = -scale * src_height * 0.5 + dst_height * 0.5;
|
||||
|
||||
cv::Mat m2x3_s2d(2, 3, CV_32F, s2d.value);
|
||||
cv::Mat m2x3_d2s(2, 3, CV_32F, d2s.value);
|
||||
cv::invertAffineTransform(m2x3_s2d, m2x3_d2s);
|
||||
|
||||
memcpy(d2s.value, m2x3_d2s.ptr<float>(0), sizeof(d2s.value));
|
||||
|
||||
int jobs = dst_height * dst_width;
|
||||
int threads = 256;
|
||||
int blocks = ceil(jobs / (float)threads);
|
||||
YoloPreprocessCudaKernel<<<blocks, threads, 0, stream>>>(
|
||||
src, src_width * 3, src_width,
|
||||
src_height, dst, dst_width,
|
||||
dst_height, padding_value[0], padding_value[1], padding_value[2], d2s, jobs);
|
||||
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -37,6 +37,9 @@ class YOLOv5(FastDeployModel):
|
||||
return self._model.predict(input_image, conf_threshold,
|
||||
nms_iou_threshold)
|
||||
|
||||
def use_cuda_preprocessing(self, max_image_size=3840 * 2160):
|
||||
return self._model.use_cuda_preprocessing(max_image_size)
|
||||
|
||||
@staticmethod
|
||||
def preprocess(input_image,
|
||||
size=[640, 640],
|
||||
|
@@ -37,6 +37,9 @@ class YOLOv5Lite(FastDeployModel):
|
||||
return self._model.predict(input_image, conf_threshold,
|
||||
nms_iou_threshold)
|
||||
|
||||
def use_cuda_preprocessing(self, max_image_size=3840 * 2160):
|
||||
return self._model.use_cuda_preprocessing(max_image_size)
|
||||
|
||||
# 一些跟YOLOv5Lite模型有关的属性封装
|
||||
# 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持)
|
||||
@property
|
||||
|
@@ -37,6 +37,9 @@ class YOLOv6(FastDeployModel):
|
||||
return self._model.predict(input_image, conf_threshold,
|
||||
nms_iou_threshold)
|
||||
|
||||
def use_cuda_preprocessing(self, max_image_size=3840 * 2160):
|
||||
return self._model.use_cuda_preprocessing(max_image_size)
|
||||
|
||||
# 一些跟YOLOv6模型有关的属性封装
|
||||
# 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持)
|
||||
@property
|
||||
|
@@ -37,6 +37,9 @@ class YOLOv7(FastDeployModel):
|
||||
return self._model.predict(input_image, conf_threshold,
|
||||
nms_iou_threshold)
|
||||
|
||||
def use_cuda_preprocessing(self, max_image_size=3840 * 2160):
|
||||
return self._model.use_cuda_preprocessing(max_image_size)
|
||||
|
||||
# 一些跟YOLOv7模型有关的属性封装
|
||||
# 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持)
|
||||
@property
|
||||
|
@@ -36,6 +36,9 @@ class YOLOv7End2EndTRT(FastDeployModel):
|
||||
def predict(self, input_image, conf_threshold=0.25):
|
||||
return self._model.predict(input_image, conf_threshold)
|
||||
|
||||
def use_cuda_preprocessing(self, max_image_size=3840 * 2160):
|
||||
return self._model.use_cuda_preprocessing(max_image_size)
|
||||
|
||||
# 一些跟模型有关的属性封装
|
||||
# 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持)
|
||||
@property
|
||||
|
Reference in New Issue
Block a user