Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-31 20:02:53 +08:00)
[Other] Remove some build options (#1090)

* Remove some build flags (BUILD_CUDA_SRC and ENABLE_CUDA_PREPROCESS)
* Add GPU check in cmake
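The change collapses two derived switches (BUILD_CUDA_SRC in CMake, ENABLE_CUDA_PREPROCESS in the C++ sources) into the single user-facing WITH_GPU option, and gathers all .cu files with one recursive glob instead of per-module lists. A minimal sketch of the resulting CMake pattern, condensed from the hunks below:

    # All CUDA handling now keys off WITH_GPU alone.
    if(WITH_GPU)
      # one recursive glob replaces the per-module *.cu source lists
      file(GLOB_RECURSE DEPLOY_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*.cu)
      list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_CUDA_SRCS})
      enable_language(CUDA)
      include(${PROJECT_SOURCE_DIR}/cmake/cuda.cmake)
    endif()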

CMakeLists.txt
@@ -184,8 +184,10 @@ add_definitions(-DFASTDEPLOY_LIB)
 configure_file(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/core/config.h.in ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/core/config.h)
 configure_file(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/main.cc.in ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/main.cc)
 file(GLOB_RECURSE ALL_DEPLOY_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*.cc)
-file(GLOB_RECURSE FDTENSOR_FUNC_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cu)
-file(GLOB_RECURSE DEPLOY_OP_CUDA_KERNEL_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/op_cuda_kernels/*.cu)
+if(WITH_GPU)
+  file(GLOB_RECURSE DEPLOY_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*.cu)
+  list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_CUDA_SRCS})
+endif()
 file(GLOB_RECURSE DEPLOY_ORT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/ort/*.cc)
 file(GLOB_RECURSE DEPLOY_PADDLE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/paddle/*.cc)
 file(GLOB_RECURSE DEPLOY_POROS_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/poros/*.cc)
@@ -197,7 +199,6 @@ file(GLOB_RECURSE DEPLOY_LITE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastd
 file(GLOB_RECURSE DEPLOY_VISION_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cc)
 file(GLOB_RECURSE DEPLOY_ENCRYPTION_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/encryption/*.cc)
 file(GLOB_RECURSE DEPLOY_PIPELINE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pipeline/*.cc)
-file(GLOB_RECURSE DEPLOY_VISION_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cu)
 file(GLOB_RECURSE DEPLOY_TEXT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/text/*.cc)
 file(GLOB_RECURSE DEPLOY_PYBIND_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/*.cc ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*_pybind.cc)
 list(REMOVE_ITEM ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS} ${DEPLOY_PADDLE_SRCS} ${DEPLOY_POROS_SRCS} ${DEPLOY_TRT_SRCS} ${DEPLOY_OPENVINO_SRCS} ${DEPLOY_LITE_SRCS} ${DEPLOY_VISION_SRCS} ${DEPLOY_TEXT_SRCS} ${DEPLOY_PIPELINE_SRCS} ${DEPLOY_RKNPU2_SRCS} ${DEPLOY_SOPHGO_SRCS} ${DEPLOY_ENCRYPTION_SRCS})
@@ -213,7 +214,8 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party/eigen)
 if(WIN32)
   add_definitions(-DEIGEN_STRONG_INLINE=inline)
 endif()
-# sw not support thread_local semantic
+
+# sw(sunway) not support thread_local semantic
 if(WITH_SW)
   add_definitions(-DEIGEN_AVOID_THREAD_LOCAL)
 endif()
@@ -224,9 +226,6 @@ if(ENABLE_ORT_BACKEND)
   list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS})
   include(${PROJECT_SOURCE_DIR}/cmake/onnxruntime.cmake)
   list(APPEND DEPEND_LIBS external_onnxruntime)
-  if(WITH_GPU)
-    list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_OP_CUDA_KERNEL_SRCS})
-  endif()
 endif()
 
 if(ENABLE_LITE_BACKEND)
@@ -329,20 +328,13 @@ if(WITH_GPU)
     find_library(CUDA_LIB cudart ${CUDA_DIRECTORY}/lib64)
   endif()
   list(APPEND DEPEND_LIBS ${CUDA_LIB})
-  endif()
-endif()
 
-# Whether to build CUDA source files in fastdeploy
-# CUDA source files include CUDA preprocessing, TRT plugins, etc.
-if(WITH_GPU)
-  set(BUILD_CUDA_SRC ON)
-  enable_language(CUDA)
-  message(STATUS "CUDA compiler: ${CMAKE_CUDA_COMPILER}, version: "
-                 "${CMAKE_CUDA_COMPILER_ID} ${CMAKE_CUDA_COMPILER_VERSION}")
-  include(${PROJECT_SOURCE_DIR}/cmake/cuda.cmake)
-  list(APPEND ALL_DEPLOY_SRCS ${FDTENSOR_FUNC_CUDA_SRCS})
-else()
-  set(BUILD_CUDA_SRC OFF)
+    # build CUDA source files in fastdeploy, CUDA source files include CUDA preprocessing, TRT plugins, etc.
+    enable_language(CUDA)
+    message(STATUS "CUDA compiler: ${CMAKE_CUDA_COMPILER}, version: "
+                    "${CMAKE_CUDA_COMPILER_ID} ${CMAKE_CUDA_COMPILER_VERSION}")
+    include(${PROJECT_SOURCE_DIR}/cmake/cuda.cmake)
+  endif()
 endif()
 
 if(WITH_IPU)
@@ -383,7 +375,6 @@ if(ENABLE_TRT_BACKEND)
   find_library(TRT_ONNX_LIB nvonnxparser ${TRT_LIB_DIR} NO_DEFAULT_PATH)
   find_library(TRT_PLUGIN_LIB nvinfer_plugin ${TRT_LIB_DIR} NO_DEFAULT_PATH)
   list(APPEND DEPEND_LIBS ${TRT_INFER_LIB} ${TRT_ONNX_LIB} ${TRT_PLUGIN_LIB})
-  list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_OP_CUDA_KERNEL_SRCS})
 
   if(NOT BUILD_ON_JETSON AND TRT_DIRECTORY)
     if(NOT EXISTS "${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/tensorrt")
@@ -418,10 +409,6 @@ if(ENABLE_VISION)
   add_definitions(-DENABLE_VISION_VISUALIZE)
   add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp)
   list(APPEND DEPEND_LIBS yaml-cpp)
-  if(BUILD_CUDA_SRC)
-    add_definitions(-DENABLE_CUDA_PREPROCESS)
-    list(APPEND DEPLOY_VISION_SRCS ${DEPLOY_VISION_CUDA_SRCS})
-  endif()
   list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_VISION_SRCS})
   list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_PIPELINE_SRCS})
   include_directories(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp/include)
@@ -483,7 +470,7 @@ elseif(ANDROID)
   set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS_MINSIZEREL ${COMMON_LINK_FLAGS_REL})
 elseif(MSVC)
 else()
-  if(BUILD_CUDA_SRC)
+  if(WITH_GPU)
     set_target_properties(${LIBRARY_NAME} PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
     set_target_properties(${LIBRARY_NAME} PROPERTIES INTERFACE_COMPILE_OPTIONS
        "$<$<BUILD_INTERFACE:$<COMPILE_LANGUAGE:CXX>>:-fvisibility=hidden>$<$<BUILD_INTERFACE:$<COMPILE_LANGUAGE:CUDA>>:-Xcompiler=-fvisibility=hidden>")
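
With BUILD_CUDA_SRC removed, downstream builds toggle GPU support through WITH_GPU alone. An illustrative configure invocation (the backend flag and CUDA path are placeholders, not part of this commit):

    cmake .. -DWITH_GPU=ON -DENABLE_ORT_BACKEND=ON -DCUDA_DIRECTORY=/usr/local/cuda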

cmake/summary.cmake
@@ -29,7 +29,6 @@ function(fastdeploy_summary)
   message(STATUS "  CMAKE_MODULE_PATH         : ${CMAKE_MODULE_PATH}")
   message(STATUS "")
   message(STATUS "  FastDeploy version        : ${FASTDEPLOY_VERSION}")
-  message(STATUS "  Paddle2ONNX version       : ${PADDLE2ONNX_VERSION}")
   message(STATUS "  ENABLE_ORT_BACKEND        : ${ENABLE_ORT_BACKEND}")
   message(STATUS "  ENABLE_RKNPU2_BACKEND     : ${ENABLE_RKNPU2_BACKEND}")
   message(STATUS "  ENABLE_SOPHGO_BACKEND     : ${ENABLE_SOPHGO_BACKEND}")
@@ -38,6 +37,7 @@ function(fastdeploy_summary)
   message(STATUS "  ENABLE_POROS_BACKEND      : ${ENABLE_POROS_BACKEND}")
   message(STATUS "  ENABLE_TRT_BACKEND        : ${ENABLE_TRT_BACKEND}")
   message(STATUS "  ENABLE_OPENVINO_BACKEND   : ${ENABLE_OPENVINO_BACKEND}")
+  message(STATUS "  WITH_GPU                  : ${WITH_GPU}")
   message(STATUS "  WITH_ASCEND               : ${WITH_ASCEND}")
   message(STATUS "  WITH_TIMVX                : ${WITH_TIMVX}")
   message(STATUS "  WITH_KUNLUNXIN            : ${WITH_KUNLUNXIN}")
@@ -54,16 +54,12 @@ function(fastdeploy_summary)
     message(STATUS "  OpenVINO version          : ${OPENVINO_VERSION}")
   endif()
   if(WITH_GPU)
-    message(STATUS "  WITH_GPU                  : ${WITH_GPU}")
     message(STATUS "  CUDA_DIRECTORY            : ${CUDA_DIRECTORY}")
     message(STATUS "  TRT_DRECTORY              : ${TRT_DIRECTORY}")
-    message(STATUS "  BUILD_CUDA_SRC            : ${BUILD_CUDA_SRC}")
   endif()
   message(STATUS "  ENABLE_VISION             : ${ENABLE_VISION}")
   message(STATUS "  ENABLE_TEXT               : ${ENABLE_TEXT}")
   message(STATUS "  ENABLE_ENCRYPTION         : ${ENABLE_ENCRYPTION}")
-  message(STATUS "  ENABLE_DEBUG              : ${ENABLE_DEBUG}")
-  message(STATUS "  ENABLE_VISION_VISUALIZE   : ${ENABLE_VISION_VISUALIZE}")
   if(ANDROID)
     message(STATUS "  ANDROID_ABI               : ${ANDROID_ABI}")
     message(STATUS "  ANDROID_PLATFORM          : ${ANDROID_PLATFORM}")

fastdeploy/function/cuda_cast.cu
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#ifdef WITH_GPU
 #include "fastdeploy/function/cuda_cast.h"
-
 namespace fastdeploy {
 namespace function {
 template <typename T_IN, typename T_OUT>
@@ -30,13 +30,11 @@ void CudaCast(const FDTensor& in, FDTensor* out, cudaStream_t stream) {
   if (in.dtype == FDDataType::INT64 && out->dtype == FDDataType::INT32) {
     CudaCastKernel<int64_t, int32_t><<<blocks, threads, 0, stream>>>(
         reinterpret_cast<int64_t*>(const_cast<void*>(in.Data())),
-        reinterpret_cast<int32_t*>(out->MutableData()),
-        jobs);
+        reinterpret_cast<int32_t*>(out->MutableData()), jobs);
   } else if (in.dtype == FDDataType::INT32 && out->dtype == FDDataType::INT64) {
     CudaCastKernel<int32_t, int64_t><<<blocks, threads, 0, stream>>>(
         reinterpret_cast<int32_t*>(const_cast<void*>(in.Data())),
-        reinterpret_cast<int64_t*>(out->MutableData()),
-        jobs);
+        reinterpret_cast<int64_t*>(out->MutableData()), jobs);
   } else {
     FDASSERT(false, "CudaCast only support input INT64, output INT32.");
   }
@@ -44,3 +42,4 @@ void CudaCast(const FDTensor& in, FDTensor* out, cudaStream_t stream) {
 
 }  // namespace function
 }  // namespace fastdeploy
+#endif

fastdeploy/runtime/backends/op_cuda_kernels/adaptive_pool2d_kernel.cu
@@ -1,3 +1,19 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifdef WITH_GPU
+
 #include "adaptive_pool2d_kernel.h"
 
 namespace fastdeploy {
@@ -53,4 +69,5 @@ void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
       input, output, jobs, out_bc_offset, in_bc_offset, int(input_dims[2]),
       int(input_dims[3]), int(output_dims[2]), int(output_dims[3]), is_avg);
 }
-}  // namespace fastdeploy
+}  // namespace fastdeploy
+#endif

fastdeploy/vision/common/processors/normalize_and_permute.cu
@@ -12,14 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#ifdef WITH_GPU
 #include "fastdeploy/vision/common/processors/normalize_and_permute.h"
 
 namespace fastdeploy {
 namespace vision {
 
-__global__ void NormalizeAndPermuteKernel(
-    uint8_t* src, float* dst, const float* alpha, const float* beta,
-    int num_channel, bool swap_rb, int edge) {
+__global__ void NormalizeAndPermuteKernel(uint8_t* src, float* dst,
+                                          const float* alpha, const float* beta,
+                                          int num_channel, bool swap_rb,
+                                          int edge) {
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   if (idx >= edge) return;
 
@@ -38,8 +40,8 @@ bool NormalizeAndPermute::ImplByCuda(Mat* mat) {
   cv::Mat* im = mat->GetOpenCVMat();
   std::string buf_name = Name() + "_src";
   std::vector<int64_t> shape = {im->rows, im->cols, im->channels()};
-  FDTensor* src = UpdateAndGetReusedBuffer(shape, im->type(), buf_name,
-                                           Device::GPU);
+  FDTensor* src =
+      UpdateAndGetReusedBuffer(shape, im->type(), buf_name, Device::GPU);
   FDASSERT(cudaMemcpy(src->Data(), im->ptr(), src->Nbytes(),
                       cudaMemcpyHostToDevice) == 0,
            "Error occurs while copy memory from CPU to GPU.");
@@ -58,7 +60,7 @@ bool NormalizeAndPermute::ImplByCuda(Mat* mat) {
 
   buf_name = Name() + "_beta";
   FDTensor* beta = UpdateAndGetReusedBuffer({(int64_t)beta_.size()}, CV_32FC1,
-                                             buf_name, Device::GPU);
+                                            buf_name, Device::GPU);
   FDASSERT(cudaMemcpy(beta->Data(), beta_.data(), beta->Nbytes(),
                       cudaMemcpyHostToDevice) == 0,
            "Error occurs while copy memory from CPU to GPU.");
@@ -80,3 +82,4 @@ bool NormalizeAndPermute::ImplByCuda(Mat* mat) {
 
 }  // namespace vision
 }  // namespace fastdeploy
+#endif

fastdeploy/vision/detection/contrib/yolov5lite.cc (53 changes; Executable file → Normal file)
@@ -13,11 +13,12 @@
 // limitations under the License.
 
 #include "fastdeploy/vision/detection/contrib/yolov5lite.h"
+
 #include "fastdeploy/utils/perf.h"
 #include "fastdeploy/vision/utils/utils.h"
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
 #include "fastdeploy/vision/utils/cuda_utils.h"
-#endif  // ENABLE_CUDA_PREPROCESS
+#endif  // WITH_GPU
 
 namespace fastdeploy {
 namespace vision {
@@ -89,8 +90,8 @@ YOLOv5Lite::YOLOv5Lite(const std::string& model_file,
                        const RuntimeOption& custom_option,
                        const ModelFormat& model_format) {
   if (model_format == ModelFormat::ONNX) {
-    valid_cpu_backends = {Backend::ORT};  
-    valid_gpu_backends = {Backend::ORT, Backend::TRT};  
+    valid_cpu_backends = {Backend::ORT};
+    valid_gpu_backends = {Backend::ORT, Backend::TRT};
   } else {
     valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
     valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
@@ -99,13 +100,13 @@ YOLOv5Lite::YOLOv5Lite(const std::string& model_file,
   runtime_option.model_format = model_format;
   runtime_option.model_file = model_file;
   runtime_option.params_file = params_file;
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   cudaSetDevice(runtime_option.device_id);
   cudaStream_t stream;
   CUDA_CHECK(cudaStreamCreate(&stream));
   cuda_stream_ = reinterpret_cast<void*>(stream);
   runtime_option.SetExternalStream(cuda_stream_);
-#endif  // ENABLE_CUDA_PREPROCESS
+#endif  // WITH_GPU
   initialized = Initialize();
 }
 
@@ -148,14 +149,14 @@ bool YOLOv5Lite::Initialize() {
 }
 
 YOLOv5Lite::~YOLOv5Lite() {
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   if (use_cuda_preprocessing_) {
     CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
     CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
     CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
     CUDA_CHECK(cudaStreamDestroy(reinterpret_cast<cudaStream_t>(cuda_stream_)));
   }
-#endif  // ENABLE_CUDA_PREPROCESS
+#endif  // WITH_GPU
 }
 
 bool YOLOv5Lite::Preprocess(
@@ -199,28 +200,33 @@ bool YOLOv5Lite::Preprocess(
 }
 
 void YOLOv5Lite::UseCudaPreprocessing(int max_image_size) {
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   use_cuda_preprocessing_ = true;
   is_scale_up = true;
   if (input_img_cuda_buffer_host_ == nullptr) {
-    // prepare input data cache in GPU pinned memory 
-    CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3));
+    // prepare input data cache in GPU pinned memory
+    CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_,
+                              max_image_size * 3));
     // prepare input data cache in GPU device memory
-    CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
-    CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size[0] * size[1] * sizeof(float)));
+    CUDA_CHECK(
+        cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
+    CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_,
+                          3 * size[0] * size[1] * sizeof(float)));
   }
 #else
-  FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
-            << std::endl;
+  FDWARNING << "The FastDeploy didn't compile with WITH_GPU=ON." << std::endl;
   use_cuda_preprocessing_ = false;
 #endif
 }
 
-bool YOLOv5Lite::CudaPreprocess(Mat* mat, FDTensor* output,
-                                std::map<std::string, std::array<float, 2>>* im_info) {
-#ifdef ENABLE_CUDA_PREPROCESS
+bool YOLOv5Lite::CudaPreprocess(
+    Mat* mat, FDTensor* output,
+    std::map<std::string, std::array<float, 2>>* im_info) {
+#ifdef WITH_GPU
   if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) {
-    FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl;
+    FDERROR << "Preprocessing with CUDA is only available when the arguments "
+               "satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)."
+            << std::endl;
     return false;
   }
 
@@ -234,14 +240,15 @@ bool YOLOv5Lite::CudaPreprocess(Mat* mat, FDTensor* output,
   int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
   memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
   CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
-                             input_img_cuda_buffer_host_,
-                             src_img_buf_size, cudaMemcpyHostToDevice, stream));
+                             input_img_cuda_buffer_host_, src_img_buf_size,
+                             cudaMemcpyHostToDevice, stream));
   utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
                             mat->Height(), input_tensor_cuda_buffer_device_,
                             size[0], size[1], padding_value, stream);
 
   // Record output shape of preprocessed image
-  (*im_info)["output_shape"] = {static_cast<float>(size[0]), static_cast<float>(size[1])};
+  (*im_info)["output_shape"] = {static_cast<float>(size[0]),
+                                static_cast<float>(size[1])};
 
   output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
                           input_tensor_cuda_buffer_device_);
@@ -251,7 +258,7 @@ bool YOLOv5Lite::CudaPreprocess(Mat* mat, FDTensor* output,
 #else
   FDERROR << "CUDA src code was not enabled." << std::endl;
   return false;
-#endif  // ENABLE_CUDA_PREPROCESS
+#endif  // WITH_GPU
 }
 
 bool YOLOv5Lite::PostprocessWithDecode(

fastdeploy/vision/detection/contrib/yolov6.cc (21 changes; Executable file → Normal file)
@@ -16,9 +16,9 @@
 
 #include "fastdeploy/utils/perf.h"
 #include "fastdeploy/vision/utils/utils.h"
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
 #include "fastdeploy/vision/utils/cuda_utils.h"
-#endif  // ENABLE_CUDA_PREPROCESS
+#endif  // WITH_GPU
 
 namespace fastdeploy {
 
@@ -79,13 +79,13 @@ YOLOv6::YOLOv6(const std::string& model_file, const std::string& params_file,
   runtime_option.model_format = model_format;
   runtime_option.model_file = model_file;
   runtime_option.params_file = params_file;
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   cudaSetDevice(runtime_option.device_id);
   cudaStream_t stream;
   CUDA_CHECK(cudaStreamCreate(&stream));
   cuda_stream_ = reinterpret_cast<void*>(stream);
   runtime_option.SetExternalStream(cuda_stream_);
-#endif  // ENABLE_CUDA_PREPROCESS
+#endif  // WITH_GPU
   initialized = Initialize();
 }
 
@@ -123,14 +123,14 @@
 }
 
 YOLOv6::~YOLOv6() {
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   if (use_cuda_preprocessing_) {
     CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
    CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
     CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
     CUDA_CHECK(cudaStreamDestroy(reinterpret_cast<cudaStream_t>(cuda_stream_)));
   }
-#endif  // ENABLE_CUDA_PREPROCESS
+#endif  // WITH_GPU
 }
 
 bool YOLOv6::Preprocess(Mat* mat, FDTensor* output,
@@ -173,7 +173,7 @@ bool YOLOv6::Preprocess(Mat* mat, FDTensor* output,
 }
 
 void YOLOv6::UseCudaPreprocessing(int max_image_size) {
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   use_cuda_preprocessing_ = true;
   is_scale_up = true;
   if (input_img_cuda_buffer_host_ == nullptr) {
@@ -187,8 +187,7 @@ void YOLOv6::UseCudaPreprocessing(int max_image_size) {
                           3 * size[0] * size[1] * sizeof(float)));
   }
 #else
-  FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
-            << std::endl;
+  FDWARNING << "The FastDeploy didn't compile with WITH_GPU=ON." << std::endl;
   use_cuda_preprocessing_ = false;
 #endif
 }
@@ -196,7 +195,7 @@ void YOLOv6::UseCudaPreprocessing(int max_image_size) {
 bool YOLOv6::CudaPreprocess(
     Mat* mat, FDTensor* output,
     std::map<std::string, std::array<float, 2>>* im_info) {
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) {
     FDERROR << "Preprocessing with CUDA is only available when the arguments "
                "satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)."
@@ -232,7 +231,7 @@ bool YOLOv6::CudaPreprocess(
 #else
   FDERROR << "CUDA src code was not enabled." << std::endl;
   return false;
-#endif  // ENABLE_CUDA_PREPROCESS
+#endif  // WITH_GPU
 }
 
 bool YOLOv6::Postprocess(

fastdeploy/vision/detection/contrib/yolov7end2end_trt.cc (49 changes; Executable file → Normal file)
@@ -13,11 +13,12 @@
 // limitations under the License.
 
 #include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
+
 #include "fastdeploy/utils/perf.h"
 #include "fastdeploy/vision/utils/utils.h"
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
 #include "fastdeploy/vision/utils/cuda_utils.h"
-#endif  // ENABLE_CUDA_PREPROCESS
+#endif  // WITH_GPU
 
 namespace fastdeploy {
 namespace vision {
@@ -88,13 +89,13 @@
       runtime_option.backend = Backend::TRT;
     }
   }
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   cudaSetDevice(runtime_option.device_id);
   cudaStream_t stream;
   CUDA_CHECK(cudaStreamCreate(&stream));
   cuda_stream_ = reinterpret_cast<void*>(stream);
   runtime_option.SetExternalStream(cuda_stream_);
-#endif  // ENABLE_CUDA_PREPROCESS
+#endif  // WITH_GPU
   initialized = Initialize();
 }
 
@@ -131,14 +132,14 @@
 }
 
 YOLOv7End2EndTRT::~YOLOv7End2EndTRT() {
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   if (use_cuda_preprocessing_) {
     CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
     CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
     CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
     CUDA_CHECK(cudaStreamDestroy(reinterpret_cast<cudaStream_t>(cuda_stream_)));
   }
-#endif  // ENABLE_CUDA_PREPROCESS
+#endif  // WITH_GPU
 }
 
 bool YOLOv7End2EndTRT::Preprocess(
@@ -173,28 +174,33 @@ bool YOLOv7End2EndTRT::Preprocess(
 }
 
 void YOLOv7End2EndTRT::UseCudaPreprocessing(int max_image_size) {
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   use_cuda_preprocessing_ = true;
   is_scale_up = true;
   if (input_img_cuda_buffer_host_ == nullptr) {
-    // prepare input data cache in GPU pinned memory 
-    CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3));
+    // prepare input data cache in GPU pinned memory
+    CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_,
+                              max_image_size * 3));
     // prepare input data cache in GPU device memory
-    CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
-    CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size[0] * size[1] * sizeof(float)));
+    CUDA_CHECK(
+        cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
+    CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_,
+                          3 * size[0] * size[1] * sizeof(float)));
   }
 #else
-  FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
-            << std::endl;
+  FDWARNING << "The FastDeploy didn't compile with WITH_GPU=ON." << std::endl;
   use_cuda_preprocessing_ = false;
 #endif
 }
 
-bool YOLOv7End2EndTRT::CudaPreprocess(Mat* mat, FDTensor* output,
-                                      std::map<std::string, std::array<float, 2>>* im_info) {
-#ifdef ENABLE_CUDA_PREPROCESS
+bool YOLOv7End2EndTRT::CudaPreprocess(
+    Mat* mat, FDTensor* output,
+    std::map<std::string, std::array<float, 2>>* im_info) {
+#ifdef WITH_GPU
   if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) {
-    FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl;
+    FDERROR << "Preprocessing with CUDA is only available when the arguments "
+               "satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)."
+            << std::endl;
     return false;
   }
 
@@ -208,14 +214,15 @@ bool YOLOv7End2EndTRT::CudaPreprocess(Mat* mat, FDTensor* output,
   int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
   memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
   CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
-                             input_img_cuda_buffer_host_,
-                             src_img_buf_size, cudaMemcpyHostToDevice, stream));
+                             input_img_cuda_buffer_host_, src_img_buf_size,
+                             cudaMemcpyHostToDevice, stream));
   utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
                             mat->Height(), input_tensor_cuda_buffer_device_,
                             size[0], size[1], padding_value, stream);
 
   // Record output shape of preprocessed image
-  (*im_info)["output_shape"] = {static_cast<float>(size[0]), static_cast<float>(size[1])};
+  (*im_info)["output_shape"] = {static_cast<float>(size[0]),
+                                static_cast<float>(size[1])};
 
   output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
                           input_tensor_cuda_buffer_device_);
@@ -225,7 +232,7 @@ bool YOLOv7End2EndTRT::CudaPreprocess(Mat* mat, FDTensor* output,
 #else
   FDERROR << "CUDA src code was not enabled." << std::endl;
   return false;
-#endif  // ENABLE_CUDA_PREPROCESS
+#endif  // WITH_GPU
 }
 
 bool YOLOv7End2EndTRT::Postprocess(

fastdeploy/vision/utils/cuda_utils.cu
@@ -21,9 +21,11 @@
 // \brief
 // \author Qi Liu, Xinyu Wang
 
-#include "fastdeploy/vision/utils/cuda_utils.h"
+#ifdef WITH_GPU
 #include <opencv2/opencv.hpp>
+
+#include "fastdeploy/vision/utils/cuda_utils.h"
 
 namespace fastdeploy {
 namespace vision {
 namespace utils {
@@ -32,12 +34,11 @@
   float value[6];
 };
 
-__global__ void YoloPreprocessCudaKernel( 
-    uint8_t* src, int src_line_size, int src_width, 
-    int src_height, float* dst, int dst_width, 
-    int dst_height, uint8_t padding_color_b,
-    uint8_t padding_color_g, uint8_t padding_color_r,
-    AffineMatrix d2s, int edge) {
+__global__ void YoloPreprocessCudaKernel(
+    uint8_t* src, int src_line_size, int src_width, int src_height, float* dst,
+    int dst_width, int dst_height, uint8_t padding_color_b,
+    uint8_t padding_color_g, uint8_t padding_color_r, AffineMatrix d2s,
+    int edge) {
   int position = blockDim.x * blockIdx.x + threadIdx.x;
   if (position >= edge) return;
 
@@ -91,7 +92,7 @@ __global__ void YoloPreprocessCudaKernel(
     c2 = w1 * v1[2] + w2 * v2[2] + w3 * v3[2] + w4 * v4[2];
   }
 
-  // bgr to rgb 
+  // bgr to rgb
   float t = c2;
   c2 = c0;
   c0 = t;
@@ -111,16 +112,17 @@ __global__ void YoloPreprocessCudaKernel(
   *pdst_c2 = c2;
 }
 
-void CudaYoloPreprocess(
-    uint8_t* src, int src_width, int src_height,
-    float* dst, int dst_width, int dst_height,
-    const std::vector<float> padding_value, cudaStream_t stream) {
+void CudaYoloPreprocess(uint8_t* src, int src_width, int src_height, float* dst,
+                        int dst_width, int dst_height,
+                        const std::vector<float> padding_value,
+                        cudaStream_t stream) {
   AffineMatrix s2d, d2s;
-  float scale = std::min(dst_height / (float)src_height, dst_width / (float)src_width);
+  float scale =
+      std::min(dst_height / (float)src_height, dst_width / (float)src_width);
 
   s2d.value[0] = scale;
   s2d.value[1] = 0;
-  s2d.value[2] = -scale * src_width  * 0.5  + dst_width * 0.5;
+  s2d.value[2] = -scale * src_width * 0.5 + dst_width * 0.5;
   s2d.value[3] = 0;
   s2d.value[4] = scale;
   s2d.value[5] = -scale * src_height * 0.5 + dst_height * 0.5;
@@ -135,12 +137,11 @@ void CudaYoloPreprocess(
   int threads = 256;
   int blocks = ceil(jobs / (float)threads);
   YoloPreprocessCudaKernel<<<blocks, threads, 0, stream>>>(
-      src, src_width * 3, src_width,
-      src_height, dst, dst_width,
-      dst_height, padding_value[0], padding_value[1], padding_value[2], d2s, jobs);
-
+      src, src_width * 3, src_width, src_height, dst, dst_width, dst_height,
+      padding_value[0], padding_value[1], padding_value[2], d2s, jobs);
 }
 
 }  // namespace utils
 }  // namespace vision
 }  // namespace fastdeploy
+#endif