diff --git a/CMakeLists.txt b/CMakeLists.txt index 418189e58..042d5645b 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,6 +92,16 @@ if(BUILD_ON_JETSON) set(ENABLE_ORT_BACKEND ON) endif() +# Whether to build CUDA source files in fastdeploy +# CUDA source files include CUDA preprocessing, TRT plugins, etc. +if(WITH_GPU AND UNIX) + set(BUILD_CUDA_SRC ON) + enable_language(CUDA) + set(CUDA_PROPAGATE_HOST_FLAGS FALSE) +else() + set(BUILD_CUDA_SRC OFF) +endif() + # config GIT_URL with github mirrors to speed up dependent repos clone option(GIT_URL "Git URL to clone dependent repos" ${GIT_URL}) if(NOT GIT_URL) @@ -174,6 +184,7 @@ file(GLOB_RECURSE DEPLOY_TRT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastde file(GLOB_RECURSE DEPLOY_OPENVINO_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/openvino/*.cc) file(GLOB_RECURSE DEPLOY_LITE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/lite/*.cc) file(GLOB_RECURSE DEPLOY_VISION_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cc) +file(GLOB_RECURSE DEPLOY_VISION_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cu) file(GLOB_RECURSE DEPLOY_TEXT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/text/*.cc) file(GLOB_RECURSE DEPLOY_PYBIND_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/*.cc ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*_pybind.cc) list(REMOVE_ITEM ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS} ${DEPLOY_PADDLE_SRCS} ${DEPLOY_POROS_SRCS} ${DEPLOY_TRT_SRCS} ${DEPLOY_OPENVINO_SRCS} ${DEPLOY_LITE_SRCS} ${DEPLOY_VISION_SRCS} ${DEPLOY_TEXT_SRCS}) @@ -373,6 +384,10 @@ if(ENABLE_VISION) endif() add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp) list(APPEND DEPEND_LIBS yaml-cpp) + if(BUILD_CUDA_SRC) + add_definitions(-DENABLE_CUDA_PREPROCESS) + list(APPEND DEPLOY_VISION_SRCS ${DEPLOY_VISION_CUDA_SRCS}) + endif() list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_VISION_SRCS}) include_directories(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp/include) include(${PROJECT_SOURCE_DIR}/cmake/opencv.cmake) @@ -428,7 +443,13 @@ elseif(ANDROID) set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS_MINSIZEREL ${COMMON_LINK_FLAGS_REL}) elseif(MSVC) else() - set_target_properties(${LIBRARY_NAME} PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") + if(BUILD_CUDA_SRC) + set_target_properties(${LIBRARY_NAME} PROPERTIES CUDA_SEPARABLE_COMPILATION ON) + set_target_properties(${LIBRARY_NAME} PROPERTIES INTERFACE_COMPILE_OPTIONS + "$<$>:-fvisibility=hidden>$<$>:-Xcompiler=-fvisibility=hidden>") + else() + set_target_properties(${LIBRARY_NAME} PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") + endif() set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS "-Wl,--exclude-libs,ALL") set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS_RELEASE -s) endif() diff --git a/cmake/summary.cmake b/cmake/summary.cmake index 2deea28a7..fb4c534c6 100755 --- a/cmake/summary.cmake +++ b/cmake/summary.cmake @@ -51,6 +51,7 @@ function(fastdeploy_summary) message(STATUS " WITH_GPU : ${WITH_GPU}") message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}") message(STATUS " TRT_DRECTORY : ${TRT_DIRECTORY}") + message(STATUS " BUILD_CUDA_SRC : ${BUILD_CUDA_SRC}") endif() message(STATUS " ENABLE_VISION : ${ENABLE_VISION}") message(STATUS " ENABLE_TEXT : ${ENABLE_TEXT}") diff --git a/fastdeploy/fastdeploy_model.cc b/fastdeploy/fastdeploy_model.cc index 7b7868a6b..ced300cab 100644 --- a/fastdeploy/fastdeploy_model.cc +++ b/fastdeploy/fastdeploy_model.cc @@ -57,7 +57,7 @@ bool FastDeployModel::InitRuntime() { } if (is_supported) { - runtime_ = std::unique_ptr(new Runtime()); + runtime_ = std::shared_ptr(new Runtime()); if (!runtime_->Init(runtime_option)) { return false; } @@ -107,7 +107,7 @@ bool FastDeployModel::CreateCpuBackend() { continue; } runtime_option.backend = valid_cpu_backends[i]; - runtime_ = std::unique_ptr(new Runtime()); + runtime_ = std::shared_ptr(new Runtime()); if (!runtime_->Init(runtime_option)) { return false; } @@ -130,7 +130,7 @@ bool FastDeployModel::CreateGpuBackend() { continue; } runtime_option.backend = valid_gpu_backends[i]; - runtime_ = std::unique_ptr(new Runtime()); + runtime_ = std::shared_ptr(new Runtime()); if (!runtime_->Init(runtime_option)) { return false; } diff --git a/fastdeploy/fastdeploy_model.h b/fastdeploy/fastdeploy_model.h index 63c735557..394a48c5d 100644 --- a/fastdeploy/fastdeploy_model.h +++ b/fastdeploy/fastdeploy_model.h @@ -99,7 +99,7 @@ class FASTDEPLOY_DECL FastDeployModel { std::vector valid_external_backends; private: - std::unique_ptr runtime_; + std::shared_ptr runtime_; bool runtime_initialized_ = false; // whether to record inference time bool enable_record_time_of_runtime_ = false; diff --git a/fastdeploy/vision/detection/contrib/yolov5.cc b/fastdeploy/vision/detection/contrib/yolov5.cc index e678b3c14..ee8f0cbf7 100644 --- a/fastdeploy/vision/detection/contrib/yolov5.cc +++ b/fastdeploy/vision/detection/contrib/yolov5.cc @@ -16,6 +16,9 @@ #include "fastdeploy/utils/perf.h" #include "fastdeploy/vision/utils/utils.h" +#ifdef ENABLE_CUDA_PREPROCESS +#include "fastdeploy/vision/utils/cuda_utils.h" +#endif // ENABLE_CUDA_PREPROCESS namespace fastdeploy { namespace vision { @@ -104,9 +107,20 @@ bool YOLOv5::Initialize() { // if (!is_dynamic_input_) { // is_mini_pad_ = false; // } + return true; } +YOLOv5::~YOLOv5() { +#ifdef ENABLE_CUDA_PREPROCESS + if (use_cuda_preprocessing_) { + CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_)); + CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_)); + CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_)); + } +#endif // ENABLE_CUDA_PREPROCESS +} + bool YOLOv5::Preprocess(Mat* mat, FDTensor* output, std::map>* im_info, const std::vector& size, @@ -156,6 +170,69 @@ bool YOLOv5::Preprocess(Mat* mat, FDTensor* output, return true; } +void YOLOv5::UseCudaPreprocessing(int max_image_size) { +#ifdef ENABLE_CUDA_PREPROCESS + use_cuda_preprocessing_ = true; + is_scale_up_ = true; + if (input_img_cuda_buffer_host_ == nullptr) { + // prepare input data cache in GPU pinned memory + CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3)); + // prepare input data cache in GPU device memory + CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3)); + CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size_[0] * size_[1] * sizeof(float))); + } +#else + FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON." + << std::endl; + use_cuda_preprocessing_ = false; +#endif +} + +bool YOLOv5::CudaPreprocess(Mat* mat, FDTensor* output, + std::map>* im_info, + const std::vector& size, + const std::vector padding_value, + bool is_mini_pad, bool is_no_pad, bool is_scale_up, + int stride, float max_wh, bool multi_label) { +#ifdef ENABLE_CUDA_PREPROCESS + if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) { + FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl; + return false; + } + + // Record the shape of image and the shape of preprocessed image + (*im_info)["input_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + (*im_info)["output_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels(); + memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size); + CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_, + input_img_cuda_buffer_host_, + src_img_buf_size, cudaMemcpyHostToDevice, stream)); + utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(), + mat->Height(), input_tensor_cuda_buffer_device_, + size[0], size[1], padding_value, stream); + cudaStreamSynchronize(stream); + cudaStreamDestroy(stream); + + // Record output shape of preprocessed image + (*im_info)["output_shape"] = {static_cast(size[0]), static_cast(size[1])}; + + output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32, + input_tensor_cuda_buffer_device_); + output->device = Device::GPU; + output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c + return true; +#else + FDERROR << "CUDA src code was not enabled." << std::endl; + return false; +#endif // ENABLE_CUDA_PREPROCESS +} + bool YOLOv5::Postprocess( std::vector& infer_results, DetectionResult* result, const std::map>& im_info, @@ -262,11 +339,20 @@ bool YOLOv5::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold, std::map> im_info; - if (!Preprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_, - is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_, - multi_label_)) { - FDERROR << "Failed to preprocess input image." << std::endl; - return false; + if (use_cuda_preprocessing_) { + if (!CudaPreprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_, + is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_, + multi_label_)) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } + } else { + if (!Preprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_, + is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_, + multi_label_)) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } } input_tensors[0].name = InputInfoOfRuntime(0).name; diff --git a/fastdeploy/vision/detection/contrib/yolov5.h b/fastdeploy/vision/detection/contrib/yolov5.h index 05aae90b1..198a38d4b 100644 --- a/fastdeploy/vision/detection/contrib/yolov5.h +++ b/fastdeploy/vision/detection/contrib/yolov5.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include "fastdeploy/fastdeploy_model.h" #include "fastdeploy/vision/common/processors/transform.h" #include "fastdeploy/vision/common/result.h" @@ -27,6 +28,8 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel { const RuntimeOption& custom_option = RuntimeOption(), const ModelFormat& model_format = ModelFormat::ONNX); + ~YOLOv5(); + std::string ModelName() const { return "yolov5"; } virtual bool Predict(cv::Mat* im, DetectionResult* result, @@ -42,6 +45,17 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel { bool is_scale_up = false, int stride = 32, float max_wh = 7680.0, bool multi_label = true); + void UseCudaPreprocessing(int max_img_size = 3840 * 2160); + + bool CudaPreprocess(Mat* mat, FDTensor* output, + std::map>* im_info, + const std::vector& size = {640, 640}, + const std::vector padding_value = {114.0, 114.0, + 114.0}, + bool is_mini_pad = false, bool is_no_pad = false, + bool is_scale_up = false, int stride = 32, + float max_wh = 7680.0, bool multi_label = true); + static bool Postprocess( std::vector& infer_results, DetectionResult* result, const std::map>& im_info, @@ -85,6 +99,14 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel { // value will // auto check by fastdeploy after the internal Runtime already initialized. bool is_dynamic_input_; + // CUDA host buffer for input image + uint8_t* input_img_cuda_buffer_host_ = nullptr; + // CUDA device buffer for input image + uint8_t* input_img_cuda_buffer_device_ = nullptr; + // CUDA device buffer for TRT input tensor + float* input_tensor_cuda_buffer_device_ = nullptr; + // Whether to use CUDA preprocessing + bool use_cuda_preprocessing_ = false; }; } // namespace detection diff --git a/fastdeploy/vision/detection/contrib/yolov5_pybind.cc b/fastdeploy/vision/detection/contrib/yolov5_pybind.cc index 0f6d2a5c3..52d0d78c9 100644 --- a/fastdeploy/vision/detection/contrib/yolov5_pybind.cc +++ b/fastdeploy/vision/detection/contrib/yolov5_pybind.cc @@ -27,6 +27,10 @@ void BindYOLOv5(pybind11::module& m) { self.Predict(&mat, &res, conf_threshold, nms_iou_threshold); return res; }) + .def("use_cuda_preprocessing", + [](vision::detection::YOLOv5& self, int max_image_size) { + self.UseCudaPreprocessing(max_image_size); + }) .def_static("preprocess", [](pybind11::array& data, const std::vector& size, const std::vector padding_value, bool is_mini_pad, diff --git a/fastdeploy/vision/detection/contrib/yolov5lite.cc b/fastdeploy/vision/detection/contrib/yolov5lite.cc index 33c93bb8c..e5417b02c 100644 --- a/fastdeploy/vision/detection/contrib/yolov5lite.cc +++ b/fastdeploy/vision/detection/contrib/yolov5lite.cc @@ -15,6 +15,9 @@ #include "fastdeploy/vision/detection/contrib/yolov5lite.h" #include "fastdeploy/utils/perf.h" #include "fastdeploy/vision/utils/utils.h" +#ifdef ENABLE_CUDA_PREPROCESS +#include "fastdeploy/vision/utils/cuda_utils.h" +#endif // ENABLE_CUDA_PREPROCESS namespace fastdeploy { namespace vision { @@ -136,6 +139,16 @@ bool YOLOv5Lite::Initialize() { return true; } +YOLOv5Lite::~YOLOv5Lite() { +#ifdef ENABLE_CUDA_PREPROCESS + if (use_cuda_preprocessing_) { + CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_)); + CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_)); + CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_)); + } +#endif // ENABLE_CUDA_PREPROCESS +} + bool YOLOv5Lite::Preprocess( Mat* mat, FDTensor* output, std::map>* im_info) { @@ -176,6 +189,65 @@ bool YOLOv5Lite::Preprocess( return true; } +void YOLOv5Lite::UseCudaPreprocessing(int max_image_size) { +#ifdef ENABLE_CUDA_PREPROCESS + use_cuda_preprocessing_ = true; + is_scale_up = true; + if (input_img_cuda_buffer_host_ == nullptr) { + // prepare input data cache in GPU pinned memory + CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3)); + // prepare input data cache in GPU device memory + CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3)); + CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size[0] * size[1] * sizeof(float))); + } +#else + FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON." + << std::endl; + use_cuda_preprocessing_ = false; +#endif +} + +bool YOLOv5Lite::CudaPreprocess(Mat* mat, FDTensor* output, + std::map>* im_info) { +#ifdef ENABLE_CUDA_PREPROCESS + if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) { + FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl; + return false; + } + + // Record the shape of image and the shape of preprocessed image + (*im_info)["input_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + (*im_info)["output_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels(); + memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size); + CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_, + input_img_cuda_buffer_host_, + src_img_buf_size, cudaMemcpyHostToDevice, stream)); + utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(), + mat->Height(), input_tensor_cuda_buffer_device_, + size[0], size[1], padding_value, stream); + cudaStreamSynchronize(stream); + cudaStreamDestroy(stream); + + // Record output shape of preprocessed image + (*im_info)["output_shape"] = {static_cast(size[0]), static_cast(size[1])}; + + output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32, + input_tensor_cuda_buffer_device_); + output->device = Device::GPU; + output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c + return true; +#else + FDERROR << "CUDA src code was not enabled." << std::endl; + return false; +#endif // ENABLE_CUDA_PREPROCESS +} + bool YOLOv5Lite::PostprocessWithDecode( FDTensor& infer_result, DetectionResult* result, const std::map>& im_info, @@ -348,9 +420,16 @@ bool YOLOv5Lite::Predict(cv::Mat* im, DetectionResult* result, im_info["output_shape"] = {static_cast(mat.Height()), static_cast(mat.Width())}; - if (!Preprocess(&mat, &input_tensors[0], &im_info)) { - FDERROR << "Failed to preprocess input image." << std::endl; - return false; + if (use_cuda_preprocessing_) { + if (!CudaPreprocess(&mat, &input_tensors[0], &im_info)) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } + } else { + if (!Preprocess(&mat, &input_tensors[0], &im_info)) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } } input_tensors[0].name = InputInfoOfRuntime(0).name; diff --git a/fastdeploy/vision/detection/contrib/yolov5lite.h b/fastdeploy/vision/detection/contrib/yolov5lite.h index 0b8a88086..711880115 100644 --- a/fastdeploy/vision/detection/contrib/yolov5lite.h +++ b/fastdeploy/vision/detection/contrib/yolov5lite.h @@ -27,12 +27,16 @@ class FASTDEPLOY_DECL YOLOv5Lite : public FastDeployModel { const RuntimeOption& custom_option = RuntimeOption(), const ModelFormat& model_format = ModelFormat::ONNX); + ~YOLOv5Lite(); + virtual std::string ModelName() const { return "YOLOv5-Lite"; } virtual bool Predict(cv::Mat* im, DetectionResult* result, float conf_threshold = 0.45, float nms_iou_threshold = 0.25); + void UseCudaPreprocessing(int max_img_size = 3840 * 2160); + // tuple of (width, height) std::vector size; // padding value, size should be same with Channels @@ -79,7 +83,9 @@ class FASTDEPLOY_DECL YOLOv5Lite : public FastDeployModel { bool Preprocess(Mat* mat, FDTensor* output, std::map>* im_info); - + bool CudaPreprocess(Mat* mat, FDTensor* output, + std::map>* im_info); + bool Postprocess(FDTensor& infer_result, DetectionResult* result, const std::map>& im_info, float conf_threshold, float nms_iou_threshold); @@ -97,7 +103,7 @@ class FASTDEPLOY_DECL YOLOv5Lite : public FastDeployModel { const std::vector& color, bool _auto, bool scale_fill = false, bool scale_up = true, int stride = 32); - + // generate anchors for decodeing when ONNX file without decode module. void GenerateAnchors(const std::vector& size, const std::vector& downsample_strides, @@ -109,6 +115,14 @@ class FASTDEPLOY_DECL YOLOv5Lite : public FastDeployModel { // value will // auto check by fastdeploy after the internal Runtime already initialized. bool is_dynamic_input_; + // CUDA host buffer for input image + uint8_t* input_img_cuda_buffer_host_ = nullptr; + // CUDA device buffer for input image + uint8_t* input_img_cuda_buffer_device_ = nullptr; + // CUDA device buffer for TRT input tensor + float* input_tensor_cuda_buffer_device_ = nullptr; + // Whether to use CUDA preprocessing + bool use_cuda_preprocessing_ = false; }; } // namespace detection } // namespace vision diff --git a/fastdeploy/vision/detection/contrib/yolov5lite_pybind.cc b/fastdeploy/vision/detection/contrib/yolov5lite_pybind.cc index f74308abc..cdd20f169 100644 --- a/fastdeploy/vision/detection/contrib/yolov5lite_pybind.cc +++ b/fastdeploy/vision/detection/contrib/yolov5lite_pybind.cc @@ -28,6 +28,10 @@ void BindYOLOv5Lite(pybind11::module& m) { self.Predict(&mat, &res, conf_threshold, nms_iou_threshold); return res; }) + .def("use_cuda_preprocessing", + [](vision::detection::YOLOv5Lite& self, int max_image_size) { + self.UseCudaPreprocessing(max_image_size); + }) .def_readwrite("size", &vision::detection::YOLOv5Lite::size) .def_readwrite("padding_value", &vision::detection::YOLOv5Lite::padding_value) diff --git a/fastdeploy/vision/detection/contrib/yolov6.cc b/fastdeploy/vision/detection/contrib/yolov6.cc index 95d4010ed..02b771fa5 100644 --- a/fastdeploy/vision/detection/contrib/yolov6.cc +++ b/fastdeploy/vision/detection/contrib/yolov6.cc @@ -15,6 +15,9 @@ #include "fastdeploy/vision/detection/contrib/yolov6.h" #include "fastdeploy/utils/perf.h" #include "fastdeploy/vision/utils/utils.h" +#ifdef ENABLE_CUDA_PREPROCESS +#include "fastdeploy/vision/utils/cuda_utils.h" +#endif // ENABLE_CUDA_PREPROCESS namespace fastdeploy { @@ -108,6 +111,16 @@ bool YOLOv6::Initialize() { return true; } +YOLOv6::~YOLOv6() { +#ifdef ENABLE_CUDA_PREPROCESS + if (use_cuda_preprocessing_) { + CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_)); + CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_)); + CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_)); + } +#endif // ENABLE_CUDA_PREPROCESS +} + bool YOLOv6::Preprocess(Mat* mat, FDTensor* output, std::map>* im_info) { // process after image load @@ -147,6 +160,65 @@ bool YOLOv6::Preprocess(Mat* mat, FDTensor* output, return true; } +void YOLOv6::UseCudaPreprocessing(int max_image_size) { +#ifdef ENABLE_CUDA_PREPROCESS + use_cuda_preprocessing_ = true; + is_scale_up = true; + if (input_img_cuda_buffer_host_ == nullptr) { + // prepare input data cache in GPU pinned memory + CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3)); + // prepare input data cache in GPU device memory + CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3)); + CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size[0] * size[1] * sizeof(float))); + } +#else + FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON." + << std::endl; + use_cuda_preprocessing_ = false; +#endif +} + +bool YOLOv6::CudaPreprocess(Mat* mat, FDTensor* output, + std::map>* im_info) { +#ifdef ENABLE_CUDA_PREPROCESS + if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) { + FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl; + return false; + } + + // Record the shape of image and the shape of preprocessed image + (*im_info)["input_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + (*im_info)["output_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels(); + memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size); + CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_, + input_img_cuda_buffer_host_, + src_img_buf_size, cudaMemcpyHostToDevice, stream)); + utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(), + mat->Height(), input_tensor_cuda_buffer_device_, + size[0], size[1], padding_value, stream); + cudaStreamSynchronize(stream); + cudaStreamDestroy(stream); + + // Record output shape of preprocessed image + (*im_info)["output_shape"] = {static_cast(size[0]), static_cast(size[1])}; + + output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32, + input_tensor_cuda_buffer_device_); + output->device = Device::GPU; + output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c + return true; +#else + FDERROR << "CUDA src code was not enabled." << std::endl; + return false; +#endif // ENABLE_CUDA_PREPROCESS +} + bool YOLOv6::Postprocess( FDTensor& infer_result, DetectionResult* result, const std::map>& im_info, @@ -225,9 +297,16 @@ bool YOLOv6::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold, im_info["output_shape"] = {static_cast(mat.Height()), static_cast(mat.Width())}; - if (!Preprocess(&mat, &input_tensors[0], &im_info)) { - FDERROR << "Failed to preprocess input image." << std::endl; - return false; + if (use_cuda_preprocessing_) { + if (!CudaPreprocess(&mat, &input_tensors[0], &im_info)) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } + } else { + if (!Preprocess(&mat, &input_tensors[0], &im_info)) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } } input_tensors[0].name = InputInfoOfRuntime(0).name; diff --git a/fastdeploy/vision/detection/contrib/yolov6.h b/fastdeploy/vision/detection/contrib/yolov6.h index f951e65a1..68a224c84 100644 --- a/fastdeploy/vision/detection/contrib/yolov6.h +++ b/fastdeploy/vision/detection/contrib/yolov6.h @@ -30,12 +30,16 @@ class FASTDEPLOY_DECL YOLOv6 : public FastDeployModel { const RuntimeOption& custom_option = RuntimeOption(), const ModelFormat& model_format = ModelFormat::ONNX); + ~YOLOv6(); + std::string ModelName() const { return "YOLOv6"; } virtual bool Predict(cv::Mat* im, DetectionResult* result, float conf_threshold = 0.25, float nms_iou_threshold = 0.5); + void UseCudaPreprocessing(int max_img_size = 3840 * 2160); + // tuple of (width, height) std::vector size; // padding value, size should be same with Channels @@ -60,6 +64,9 @@ class FASTDEPLOY_DECL YOLOv6 : public FastDeployModel { bool Preprocess(Mat* mat, FDTensor* outputs, std::map>* im_info); + bool CudaPreprocess(Mat* mat, FDTensor* output, + std::map>* im_info); + bool Postprocess(FDTensor& infer_result, DetectionResult* result, const std::map>& im_info, float conf_threshold, float nms_iou_threshold); @@ -78,6 +85,14 @@ class FASTDEPLOY_DECL YOLOv6 : public FastDeployModel { // value will // auto check by fastdeploy after the internal Runtime already initialized. bool is_dynamic_input_; +// CUDA host buffer for input image + uint8_t* input_img_cuda_buffer_host_ = nullptr; + // CUDA device buffer for input image + uint8_t* input_img_cuda_buffer_device_ = nullptr; + // CUDA device buffer for TRT input tensor + float* input_tensor_cuda_buffer_device_ = nullptr; + // Whether to use CUDA preprocessing + bool use_cuda_preprocessing_ = false; }; } // namespace detection diff --git a/fastdeploy/vision/detection/contrib/yolov6_pybind.cc b/fastdeploy/vision/detection/contrib/yolov6_pybind.cc index 4414d6dcc..b4f2692bd 100644 --- a/fastdeploy/vision/detection/contrib/yolov6_pybind.cc +++ b/fastdeploy/vision/detection/contrib/yolov6_pybind.cc @@ -27,6 +27,10 @@ void BindYOLOv6(pybind11::module& m) { self.Predict(&mat, &res, conf_threshold, nms_iou_threshold); return res; }) + .def("use_cuda_preprocessing", + [](vision::detection::YOLOv6& self, int max_image_size) { + self.UseCudaPreprocessing(max_image_size); + }) .def_readwrite("size", &vision::detection::YOLOv6::size) .def_readwrite("padding_value", &vision::detection::YOLOv6::padding_value) .def_readwrite("is_mini_pad", &vision::detection::YOLOv6::is_mini_pad) diff --git a/fastdeploy/vision/detection/contrib/yolov7.cc b/fastdeploy/vision/detection/contrib/yolov7.cc index 51e7a605c..5df1b49be 100644 --- a/fastdeploy/vision/detection/contrib/yolov7.cc +++ b/fastdeploy/vision/detection/contrib/yolov7.cc @@ -15,6 +15,9 @@ #include "fastdeploy/vision/detection/contrib/yolov7.h" #include "fastdeploy/utils/perf.h" #include "fastdeploy/vision/utils/utils.h" +#ifdef ENABLE_CUDA_PREPROCESS +#include "fastdeploy/vision/utils/cuda_utils.h" +#endif // ENABLE_CUDA_PREPROCESS namespace fastdeploy { namespace vision { @@ -106,6 +109,16 @@ bool YOLOv7::Initialize() { return true; } +YOLOv7::~YOLOv7() { +#ifdef ENABLE_CUDA_PREPROCESS + if (use_cuda_preprocessing_) { + CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_)); + CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_)); + CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_)); + } +#endif // ENABLE_CUDA_PREPROCESS +} + bool YOLOv7::Preprocess(Mat* mat, FDTensor* output, std::map>* im_info) { // process after image load @@ -145,6 +158,65 @@ bool YOLOv7::Preprocess(Mat* mat, FDTensor* output, return true; } +void YOLOv7::UseCudaPreprocessing(int max_image_size) { +#ifdef ENABLE_CUDA_PREPROCESS + use_cuda_preprocessing_ = true; + is_scale_up = true; + if (input_img_cuda_buffer_host_ == nullptr) { + // prepare input data cache in GPU pinned memory + CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3)); + // prepare input data cache in GPU device memory + CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3)); + CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size[0] * size[1] * sizeof(float))); + } +#else + FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON." + << std::endl; + use_cuda_preprocessing_ = false; +#endif +} + +bool YOLOv7::CudaPreprocess(Mat* mat, FDTensor* output, + std::map>* im_info) { +#ifdef ENABLE_CUDA_PREPROCESS + if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) { + FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl; + return false; + } + + // Record the shape of image and the shape of preprocessed image + (*im_info)["input_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + (*im_info)["output_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels(); + memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size); + CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_, + input_img_cuda_buffer_host_, + src_img_buf_size, cudaMemcpyHostToDevice, stream)); + utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(), + mat->Height(), input_tensor_cuda_buffer_device_, + size[0], size[1], padding_value, stream); + cudaStreamSynchronize(stream); + cudaStreamDestroy(stream); + + // Record output shape of preprocessed image + (*im_info)["output_shape"] = {static_cast(size[0]), static_cast(size[1])}; + + output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32, + input_tensor_cuda_buffer_device_); + output->device = Device::GPU; + output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c + return true; +#else + FDERROR << "CUDA src code was not enabled." << std::endl; + return false; +#endif // ENABLE_CUDA_PREPROCESS +} + bool YOLOv7::Postprocess( FDTensor& infer_result, DetectionResult* result, const std::map>& im_info, @@ -227,9 +299,16 @@ bool YOLOv7::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold, im_info["output_shape"] = {static_cast(mat.Height()), static_cast(mat.Width())}; - if (!Preprocess(&mat, &input_tensors[0], &im_info)) { - FDERROR << "Failed to preprocess input image." << std::endl; - return false; + if (use_cuda_preprocessing_) { + if (!CudaPreprocess(&mat, &input_tensors[0], &im_info)) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } + } else { + if (!Preprocess(&mat, &input_tensors[0], &im_info)) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } } input_tensors[0].name = InputInfoOfRuntime(0).name; diff --git a/fastdeploy/vision/detection/contrib/yolov7.h b/fastdeploy/vision/detection/contrib/yolov7.h index 009050aba..872ff8dda 100644 --- a/fastdeploy/vision/detection/contrib/yolov7.h +++ b/fastdeploy/vision/detection/contrib/yolov7.h @@ -27,12 +27,16 @@ class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel { const RuntimeOption& custom_option = RuntimeOption(), const ModelFormat& model_format = ModelFormat::ONNX); + ~YOLOv7(); + virtual std::string ModelName() const { return "yolov7"; } virtual bool Predict(cv::Mat* im, DetectionResult* result, float conf_threshold = 0.25, float nms_iou_threshold = 0.5); + void UseCudaPreprocessing(int max_img_size = 3840 * 2160); + // tuple of (width, height) std::vector size; // padding value, size should be same with Channels @@ -56,6 +60,9 @@ class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel { bool Preprocess(Mat* mat, FDTensor* output, std::map>* im_info); + bool CudaPreprocess(Mat* mat, FDTensor* output, + std::map>* im_info); + bool Postprocess(FDTensor& infer_result, DetectionResult* result, const std::map>& im_info, float conf_threshold, float nms_iou_threshold); @@ -71,6 +78,14 @@ class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel { // value will // auto check by fastdeploy after the internal Runtime already initialized. bool is_dynamic_input_; + // CUDA host buffer for input image + uint8_t* input_img_cuda_buffer_host_ = nullptr; + // CUDA device buffer for input image + uint8_t* input_img_cuda_buffer_device_ = nullptr; + // CUDA device buffer for TRT input tensor + float* input_tensor_cuda_buffer_device_ = nullptr; + // Whether to use CUDA preprocessing + bool use_cuda_preprocessing_ = false; }; } // namespace detection } // namespace vision diff --git a/fastdeploy/vision/detection/contrib/yolov7_pybind.cc b/fastdeploy/vision/detection/contrib/yolov7_pybind.cc index 37f375e5f..d7ab99340 100644 --- a/fastdeploy/vision/detection/contrib/yolov7_pybind.cc +++ b/fastdeploy/vision/detection/contrib/yolov7_pybind.cc @@ -27,6 +27,10 @@ void BindYOLOv7(pybind11::module& m) { self.Predict(&mat, &res, conf_threshold, nms_iou_threshold); return res; }) + .def("use_cuda_preprocessing", + [](vision::detection::YOLOv7& self, int max_image_size) { + self.UseCudaPreprocessing(max_image_size); + }) .def_readwrite("size", &vision::detection::YOLOv7::size) .def_readwrite("padding_value", &vision::detection::YOLOv7::padding_value) .def_readwrite("is_mini_pad", &vision::detection::YOLOv7::is_mini_pad) diff --git a/fastdeploy/vision/detection/contrib/yolov7end2end_trt.cc b/fastdeploy/vision/detection/contrib/yolov7end2end_trt.cc index a47e7e203..fdefd2d12 100644 --- a/fastdeploy/vision/detection/contrib/yolov7end2end_trt.cc +++ b/fastdeploy/vision/detection/contrib/yolov7end2end_trt.cc @@ -15,6 +15,9 @@ #include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h" #include "fastdeploy/utils/perf.h" #include "fastdeploy/vision/utils/utils.h" +#ifdef ENABLE_CUDA_PREPROCESS +#include "fastdeploy/vision/utils/cuda_utils.h" +#endif // ENABLE_CUDA_PREPROCESS namespace fastdeploy { namespace vision { @@ -119,6 +122,16 @@ bool YOLOv7End2EndTRT::Initialize() { return true; } +YOLOv7End2EndTRT::~YOLOv7End2EndTRT() { +#ifdef ENABLE_CUDA_PREPROCESS + if (use_cuda_preprocessing_) { + CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_)); + CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_)); + CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_)); + } +#endif // ENABLE_CUDA_PREPROCESS +} + bool YOLOv7End2EndTRT::Preprocess( Mat* mat, FDTensor* output, std::map>* im_info) { @@ -150,6 +163,65 @@ bool YOLOv7End2EndTRT::Preprocess( return true; } +void YOLOv7End2EndTRT::UseCudaPreprocessing(int max_image_size) { +#ifdef ENABLE_CUDA_PREPROCESS + use_cuda_preprocessing_ = true; + is_scale_up = true; + if (input_img_cuda_buffer_host_ == nullptr) { + // prepare input data cache in GPU pinned memory + CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3)); + // prepare input data cache in GPU device memory + CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3)); + CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size[0] * size[1] * sizeof(float))); + } +#else + FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON." + << std::endl; + use_cuda_preprocessing_ = false; +#endif +} + +bool YOLOv7End2EndTRT::CudaPreprocess(Mat* mat, FDTensor* output, + std::map>* im_info) { +#ifdef ENABLE_CUDA_PREPROCESS + if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) { + FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl; + return false; + } + + // Record the shape of image and the shape of preprocessed image + (*im_info)["input_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + (*im_info)["output_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels(); + memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size); + CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_, + input_img_cuda_buffer_host_, + src_img_buf_size, cudaMemcpyHostToDevice, stream)); + utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(), + mat->Height(), input_tensor_cuda_buffer_device_, + size[0], size[1], padding_value, stream); + cudaStreamSynchronize(stream); + cudaStreamDestroy(stream); + + // Record output shape of preprocessed image + (*im_info)["output_shape"] = {static_cast(size[0]), static_cast(size[1])}; + + output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32, + input_tensor_cuda_buffer_device_); + output->device = Device::GPU; + output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c + return true; +#else + FDERROR << "CUDA src code was not enabled." << std::endl; + return false; +#endif // ENABLE_CUDA_PREPROCESS +} + bool YOLOv7End2EndTRT::Postprocess( std::vector& infer_results, DetectionResult* result, const std::map>& im_info, @@ -242,9 +314,16 @@ bool YOLOv7End2EndTRT::Predict(cv::Mat* im, DetectionResult* result, im_info["output_shape"] = {static_cast(mat.Height()), static_cast(mat.Width())}; - if (!Preprocess(&mat, &input_tensors[0], &im_info)) { - FDERROR << "Failed to preprocess input image." << std::endl; - return false; + if (use_cuda_preprocessing_) { + if (!CudaPreprocess(&mat, &input_tensors[0], &im_info)) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } + } else { + if (!Preprocess(&mat, &input_tensors[0], &im_info)) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } } input_tensors[0].name = InputInfoOfRuntime(0).name; diff --git a/fastdeploy/vision/detection/contrib/yolov7end2end_trt.h b/fastdeploy/vision/detection/contrib/yolov7end2end_trt.h index 8b97b8090..7398679dd 100644 --- a/fastdeploy/vision/detection/contrib/yolov7end2end_trt.h +++ b/fastdeploy/vision/detection/contrib/yolov7end2end_trt.h @@ -28,11 +28,15 @@ class FASTDEPLOY_DECL YOLOv7End2EndTRT : public FastDeployModel { const RuntimeOption& custom_option = RuntimeOption(), const ModelFormat& model_format = ModelFormat::ONNX); + ~YOLOv7End2EndTRT(); + virtual std::string ModelName() const { return "yolov7end2end_trt"; } virtual bool Predict(cv::Mat* im, DetectionResult* result, float conf_threshold = 0.25); + void UseCudaPreprocessing(int max_img_size = 3840 * 2160); + // tuple of (width, height) std::vector size; // padding value, size should be same with Channels @@ -54,6 +58,9 @@ class FASTDEPLOY_DECL YOLOv7End2EndTRT : public FastDeployModel { bool Preprocess(Mat* mat, FDTensor* output, std::map>* im_info); + bool CudaPreprocess(Mat* mat, FDTensor* output, + std::map>* im_info); + bool Postprocess(std::vector& infer_results, DetectionResult* result, const std::map>& im_info, @@ -65,6 +72,14 @@ class FASTDEPLOY_DECL YOLOv7End2EndTRT : public FastDeployModel { int stride = 32); bool is_dynamic_input_; + // CUDA host buffer for input image + uint8_t* input_img_cuda_buffer_host_ = nullptr; + // CUDA device buffer for input image + uint8_t* input_img_cuda_buffer_device_ = nullptr; + // CUDA device buffer for TRT input tensor + float* input_tensor_cuda_buffer_device_ = nullptr; + // Whether to use CUDA preprocessing + bool use_cuda_preprocessing_ = false; }; } // namespace detection } // namespace vision diff --git a/fastdeploy/vision/detection/contrib/yolov7end2end_trt_pybind.cc b/fastdeploy/vision/detection/contrib/yolov7end2end_trt_pybind.cc index 9a7aeb8dd..b22f16693 100644 --- a/fastdeploy/vision/detection/contrib/yolov7end2end_trt_pybind.cc +++ b/fastdeploy/vision/detection/contrib/yolov7end2end_trt_pybind.cc @@ -28,6 +28,10 @@ void BindYOLOv7End2EndTRT(pybind11::module& m) { self.Predict(&mat, &res, conf_threshold); return res; }) + .def("use_cuda_preprocessing", + [](vision::detection::YOLOv7End2EndTRT& self, int max_image_size) { + self.UseCudaPreprocessing(max_image_size); + }) .def_readwrite("size", &vision::detection::YOLOv7End2EndTRT::size) .def_readwrite("padding_value", &vision::detection::YOLOv7End2EndTRT::padding_value) diff --git a/fastdeploy/vision/utils/cuda_utils.h b/fastdeploy/vision/utils/cuda_utils.h new file mode 100644 index 000000000..0f0f414e9 --- /dev/null +++ b/fastdeploy/vision/utils/cuda_utils.h @@ -0,0 +1,42 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#ifndef CUDA_CHECK +#define CUDA_CHECK(callstr)\ + {\ + cudaError_t error_code = callstr;\ + if (error_code != cudaSuccess) {\ + std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":";\ + std::cerr << __LINE__;\ + assert(0);\ + }\ + } +#endif // CUDA_CHECK + +namespace fastdeploy { +namespace vision { +namespace utils { +void CudaYoloPreprocess(uint8_t* src, int src_width, int src_height, + float* dst, int dst_width, int dst_height, + const std::vector padding_value, + cudaStream_t stream); +} // namespace utils +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/utils/yolo_preprocess.cu b/fastdeploy/vision/utils/yolo_preprocess.cu new file mode 100644 index 000000000..9bae00d42 --- /dev/null +++ b/fastdeploy/vision/utils/yolo_preprocess.cu @@ -0,0 +1,146 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Part of the following code in this file refs to +// https://github.com/wang-xinyu/tensorrtx/blob/yolov5-v6.0/yolov5/preprocess.cu +// +// Copyright (c) 2022 tensorrtx +// Licensed under The MIT License +// \file preprocess.cu +// \brief +// \author Qi Liu, Xinyu Wang + +#include "fastdeploy/vision/utils/cuda_utils.h" +#include + +namespace fastdeploy { +namespace vision { +namespace utils { + +struct AffineMatrix { + float value[6]; +}; + +__global__ void YoloPreprocessCudaKernel( + uint8_t* src, int src_line_size, int src_width, + int src_height, float* dst, int dst_width, + int dst_height, uint8_t padding_color_b, + uint8_t padding_color_g, uint8_t padding_color_r, + AffineMatrix d2s, int edge) { + int position = blockDim.x * blockIdx.x + threadIdx.x; + if (position >= edge) return; + + float m_x1 = d2s.value[0]; + float m_y1 = d2s.value[1]; + float m_z1 = d2s.value[2]; + float m_x2 = d2s.value[3]; + float m_y2 = d2s.value[4]; + float m_z2 = d2s.value[5]; + + int dx = position % dst_width; + int dy = position / dst_width; + float src_x = m_x1 * dx + m_y1 * dy + m_z1 + 0.5f; + float src_y = m_x2 * dx + m_y2 * dy + m_z2 + 0.5f; + float c0, c1, c2; + + if (src_x <= -1 || src_x >= src_width || src_y <= -1 || src_y >= src_height) { + // out of range + c0 = padding_color_b; + c1 = padding_color_g; + c2 = padding_color_r; + } else { + int y_low = floorf(src_y); + int x_low = floorf(src_x); + int y_high = y_low + 1; + int x_high = x_low + 1; + + uint8_t const_value[] = {padding_color_b, padding_color_g, padding_color_r}; + float ly = src_y - y_low; + float lx = src_x - x_low; + float hy = 1 - ly; + float hx = 1 - lx; + float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + uint8_t* v1 = const_value; + uint8_t* v2 = const_value; + uint8_t* v3 = const_value; + uint8_t* v4 = const_value; + + if (y_low >= 0) { + if (x_low >= 0) v1 = src + y_low * src_line_size + x_low * 3; + if (x_high < src_width) v2 = src + y_low * src_line_size + x_high * 3; + } + + if (y_high < src_height) { + if (x_low >= 0) v3 = src + y_high * src_line_size + x_low * 3; + if (x_high < src_width) v4 = src + y_high * src_line_size + x_high * 3; + } + + c0 = w1 * v1[0] + w2 * v2[0] + w3 * v3[0] + w4 * v4[0]; + c1 = w1 * v1[1] + w2 * v2[1] + w3 * v3[1] + w4 * v4[1]; + c2 = w1 * v1[2] + w2 * v2[2] + w3 * v3[2] + w4 * v4[2]; + } + + // bgr to rgb + float t = c2; + c2 = c0; + c0 = t; + + // normalization + c0 = c0 / 255.0f; + c1 = c1 / 255.0f; + c2 = c2 / 255.0f; + + // rgbrgbrgb to rrrgggbbb + int area = dst_width * dst_height; + float* pdst_c0 = dst + dy * dst_width + dx; + float* pdst_c1 = pdst_c0 + area; + float* pdst_c2 = pdst_c1 + area; + *pdst_c0 = c0; + *pdst_c1 = c1; + *pdst_c2 = c2; +} + +void CudaYoloPreprocess( + uint8_t* src, int src_width, int src_height, + float* dst, int dst_width, int dst_height, + const std::vector padding_value, cudaStream_t stream) { + AffineMatrix s2d, d2s; + float scale = std::min(dst_height / (float)src_height, dst_width / (float)src_width); + + s2d.value[0] = scale; + s2d.value[1] = 0; + s2d.value[2] = -scale * src_width * 0.5 + dst_width * 0.5; + s2d.value[3] = 0; + s2d.value[4] = scale; + s2d.value[5] = -scale * src_height * 0.5 + dst_height * 0.5; + + cv::Mat m2x3_s2d(2, 3, CV_32F, s2d.value); + cv::Mat m2x3_d2s(2, 3, CV_32F, d2s.value); + cv::invertAffineTransform(m2x3_s2d, m2x3_d2s); + + memcpy(d2s.value, m2x3_d2s.ptr(0), sizeof(d2s.value)); + + int jobs = dst_height * dst_width; + int threads = 256; + int blocks = ceil(jobs / (float)threads); + YoloPreprocessCudaKernel<<>>( + src, src_width * 3, src_width, + src_height, dst, dst_width, + dst_height, padding_value[0], padding_value[1], padding_value[2], d2s, jobs); + +} + +} // namespace utils +} // namespace vision +} // namespace fastdeploy diff --git a/python/fastdeploy/vision/detection/contrib/yolov5.py b/python/fastdeploy/vision/detection/contrib/yolov5.py index 2f4d00b82..9fdd2b77e 100644 --- a/python/fastdeploy/vision/detection/contrib/yolov5.py +++ b/python/fastdeploy/vision/detection/contrib/yolov5.py @@ -37,6 +37,9 @@ class YOLOv5(FastDeployModel): return self._model.predict(input_image, conf_threshold, nms_iou_threshold) + def use_cuda_preprocessing(self, max_image_size=3840 * 2160): + return self._model.use_cuda_preprocessing(max_image_size) + @staticmethod def preprocess(input_image, size=[640, 640], diff --git a/python/fastdeploy/vision/detection/contrib/yolov5lite.py b/python/fastdeploy/vision/detection/contrib/yolov5lite.py index b113e565f..f50277740 100644 --- a/python/fastdeploy/vision/detection/contrib/yolov5lite.py +++ b/python/fastdeploy/vision/detection/contrib/yolov5lite.py @@ -37,6 +37,9 @@ class YOLOv5Lite(FastDeployModel): return self._model.predict(input_image, conf_threshold, nms_iou_threshold) + def use_cuda_preprocessing(self, max_image_size=3840 * 2160): + return self._model.use_cuda_preprocessing(max_image_size) + # 一些跟YOLOv5Lite模型有关的属性封装 # 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持) @property diff --git a/python/fastdeploy/vision/detection/contrib/yolov6.py b/python/fastdeploy/vision/detection/contrib/yolov6.py index 73ab26d94..5e09de4d6 100644 --- a/python/fastdeploy/vision/detection/contrib/yolov6.py +++ b/python/fastdeploy/vision/detection/contrib/yolov6.py @@ -37,6 +37,9 @@ class YOLOv6(FastDeployModel): return self._model.predict(input_image, conf_threshold, nms_iou_threshold) + def use_cuda_preprocessing(self, max_image_size=3840 * 2160): + return self._model.use_cuda_preprocessing(max_image_size) + # 一些跟YOLOv6模型有关的属性封装 # 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持) @property diff --git a/python/fastdeploy/vision/detection/contrib/yolov7.py b/python/fastdeploy/vision/detection/contrib/yolov7.py index f255548d9..a102fd83a 100644 --- a/python/fastdeploy/vision/detection/contrib/yolov7.py +++ b/python/fastdeploy/vision/detection/contrib/yolov7.py @@ -37,6 +37,9 @@ class YOLOv7(FastDeployModel): return self._model.predict(input_image, conf_threshold, nms_iou_threshold) + def use_cuda_preprocessing(self, max_image_size=3840 * 2160): + return self._model.use_cuda_preprocessing(max_image_size) + # 一些跟YOLOv7模型有关的属性封装 # 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持) @property diff --git a/python/fastdeploy/vision/detection/contrib/yolov7end2end_trt.py b/python/fastdeploy/vision/detection/contrib/yolov7end2end_trt.py index ea596020b..d0c2e90ac 100644 --- a/python/fastdeploy/vision/detection/contrib/yolov7end2end_trt.py +++ b/python/fastdeploy/vision/detection/contrib/yolov7end2end_trt.py @@ -36,6 +36,9 @@ class YOLOv7End2EndTRT(FastDeployModel): def predict(self, input_image, conf_threshold=0.25): return self._model.predict(input_image, conf_threshold) + def use_cuda_preprocessing(self, max_image_size=3840 * 2160): + return self._model.use_cuda_preprocessing(max_image_size) + # 一些跟模型有关的属性封装 # 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持) @property