mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Other] Remove some build options (#1090)
* remove some flags
* add gpu check in cmake
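This commit drops the BUILD_CUDA_SRC and ENABLE_CUDA_PREPROCESS build options: CUDA sources are now compiled whenever WITH_GPU is ON, and the source-level guards switch to that same macro. A minimal before/after sketch of the guard rename (illustrative, not a file from the repo):

// Before: CUDA preprocessing paths compiled only with BUILD_CUDA_SRC=ON,
// which made CMake define ENABLE_CUDA_PREPROCESS.
#ifdef ENABLE_CUDA_PREPROCESS
// ... CUDA preprocessing code ...
#endif  // ENABLE_CUDA_PREPROCESS

// After: the same code follows the library-wide GPU switch.
#ifdef WITH_GPU
// ... CUDA preprocessing code ...
#endif  // WITH_GPU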
CMakeLists.txt

@@ -184,8 +184,10 @@ add_definitions(-DFASTDEPLOY_LIB)
 configure_file(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/core/config.h.in ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/core/config.h)
 configure_file(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/main.cc.in ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/main.cc)
 file(GLOB_RECURSE ALL_DEPLOY_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*.cc)
-file(GLOB_RECURSE FDTENSOR_FUNC_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cu)
-file(GLOB_RECURSE DEPLOY_OP_CUDA_KERNEL_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/op_cuda_kernels/*.cu)
+if(WITH_GPU)
+  file(GLOB_RECURSE DEPLOY_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*.cu)
+  list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_CUDA_SRCS})
+endif()
 file(GLOB_RECURSE DEPLOY_ORT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/ort/*.cc)
 file(GLOB_RECURSE DEPLOY_PADDLE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/paddle/*.cc)
 file(GLOB_RECURSE DEPLOY_POROS_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/poros/*.cc)
@@ -197,7 +199,6 @@ file(GLOB_RECURSE DEPLOY_LITE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastd
 file(GLOB_RECURSE DEPLOY_VISION_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cc)
 file(GLOB_RECURSE DEPLOY_ENCRYPTION_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/encryption/*.cc)
 file(GLOB_RECURSE DEPLOY_PIPELINE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pipeline/*.cc)
-file(GLOB_RECURSE DEPLOY_VISION_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cu)
 file(GLOB_RECURSE DEPLOY_TEXT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/text/*.cc)
 file(GLOB_RECURSE DEPLOY_PYBIND_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/*.cc ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*_pybind.cc)
 list(REMOVE_ITEM ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS} ${DEPLOY_PADDLE_SRCS} ${DEPLOY_POROS_SRCS} ${DEPLOY_TRT_SRCS} ${DEPLOY_OPENVINO_SRCS} ${DEPLOY_LITE_SRCS} ${DEPLOY_VISION_SRCS} ${DEPLOY_TEXT_SRCS} ${DEPLOY_PIPELINE_SRCS} ${DEPLOY_RKNPU2_SRCS} ${DEPLOY_SOPHGO_SRCS} ${DEPLOY_ENCRYPTION_SRCS})
@@ -213,7 +214,8 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party/eigen)
 if(WIN32)
   add_definitions(-DEIGEN_STRONG_INLINE=inline)
 endif()
-# sw not support thread_local semantic
+
+# sw(sunway) not support thread_local semantic
 if(WITH_SW)
   add_definitions(-DEIGEN_AVOID_THREAD_LOCAL)
 endif()
@@ -224,9 +226,6 @@ if(ENABLE_ORT_BACKEND)
   list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS})
   include(${PROJECT_SOURCE_DIR}/cmake/onnxruntime.cmake)
   list(APPEND DEPEND_LIBS external_onnxruntime)
-  if(WITH_GPU)
-    list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_OP_CUDA_KERNEL_SRCS})
-  endif()
 endif()
 
 if(ENABLE_LITE_BACKEND)
@@ -329,20 +328,13 @@ if(WITH_GPU)
     find_library(CUDA_LIB cudart ${CUDA_DIRECTORY}/lib64)
   endif()
   list(APPEND DEPEND_LIBS ${CUDA_LIB})
-  endif()
-endif()
 
-# Whether to build CUDA source files in fastdeploy
-# CUDA source files include CUDA preprocessing, TRT plugins, etc.
-if(WITH_GPU)
-  set(BUILD_CUDA_SRC ON)
-  enable_language(CUDA)
-  message(STATUS "CUDA compiler: ${CMAKE_CUDA_COMPILER}, version: "
-          "${CMAKE_CUDA_COMPILER_ID} ${CMAKE_CUDA_COMPILER_VERSION}")
-  include(${PROJECT_SOURCE_DIR}/cmake/cuda.cmake)
-  list(APPEND ALL_DEPLOY_SRCS ${FDTENSOR_FUNC_CUDA_SRCS})
-else()
-  set(BUILD_CUDA_SRC OFF)
+  # build CUDA source files in fastdeploy, CUDA source files include CUDA preprocessing, TRT plugins, etc.
+  enable_language(CUDA)
+  message(STATUS "CUDA compiler: ${CMAKE_CUDA_COMPILER}, version: "
+          "${CMAKE_CUDA_COMPILER_ID} ${CMAKE_CUDA_COMPILER_VERSION}")
+  include(${PROJECT_SOURCE_DIR}/cmake/cuda.cmake)
+endif()
 endif()
 
 if(WITH_IPU)
@@ -383,7 +375,6 @@ if(ENABLE_TRT_BACKEND)
   find_library(TRT_ONNX_LIB nvonnxparser ${TRT_LIB_DIR} NO_DEFAULT_PATH)
   find_library(TRT_PLUGIN_LIB nvinfer_plugin ${TRT_LIB_DIR} NO_DEFAULT_PATH)
   list(APPEND DEPEND_LIBS ${TRT_INFER_LIB} ${TRT_ONNX_LIB} ${TRT_PLUGIN_LIB})
-  list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_OP_CUDA_KERNEL_SRCS})
 
   if(NOT BUILD_ON_JETSON AND TRT_DIRECTORY)
     if(NOT EXISTS "${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/tensorrt")
@@ -418,10 +409,6 @@ if(ENABLE_VISION)
   add_definitions(-DENABLE_VISION_VISUALIZE)
   add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp)
   list(APPEND DEPEND_LIBS yaml-cpp)
-  if(BUILD_CUDA_SRC)
-    add_definitions(-DENABLE_CUDA_PREPROCESS)
-    list(APPEND DEPLOY_VISION_SRCS ${DEPLOY_VISION_CUDA_SRCS})
-  endif()
   list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_VISION_SRCS})
   list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_PIPELINE_SRCS})
   include_directories(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp/include)
@@ -483,7 +470,7 @@ elseif(ANDROID)
   set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS_MINSIZEREL ${COMMON_LINK_FLAGS_REL})
 elseif(MSVC)
 else()
-  if(BUILD_CUDA_SRC)
+  if(WITH_GPU)
     set_target_properties(${LIBRARY_NAME} PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
     set_target_properties(${LIBRARY_NAME} PROPERTIES INTERFACE_COMPILE_OPTIONS
       "$<$<BUILD_INTERFACE:$<COMPILE_LANGUAGE:CXX>>:-fvisibility=hidden>$<$<BUILD_INTERFACE:$<COMPILE_LANGUAGE:CUDA>>:-Xcompiler=-fvisibility=hidden>")
cmake/summary.cmake

@@ -29,7 +29,6 @@ function(fastdeploy_summary)
   message(STATUS " CMAKE_MODULE_PATH : ${CMAKE_MODULE_PATH}")
   message(STATUS "")
   message(STATUS " FastDeploy version : ${FASTDEPLOY_VERSION}")
-  message(STATUS " Paddle2ONNX version : ${PADDLE2ONNX_VERSION}")
   message(STATUS " ENABLE_ORT_BACKEND : ${ENABLE_ORT_BACKEND}")
   message(STATUS " ENABLE_RKNPU2_BACKEND : ${ENABLE_RKNPU2_BACKEND}")
   message(STATUS " ENABLE_SOPHGO_BACKEND : ${ENABLE_SOPHGO_BACKEND}")
@@ -38,6 +37,7 @@ function(fastdeploy_summary)
   message(STATUS " ENABLE_POROS_BACKEND : ${ENABLE_POROS_BACKEND}")
   message(STATUS " ENABLE_TRT_BACKEND : ${ENABLE_TRT_BACKEND}")
   message(STATUS " ENABLE_OPENVINO_BACKEND : ${ENABLE_OPENVINO_BACKEND}")
+  message(STATUS " WITH_GPU : ${WITH_GPU}")
   message(STATUS " WITH_ASCEND : ${WITH_ASCEND}")
   message(STATUS " WITH_TIMVX : ${WITH_TIMVX}")
   message(STATUS " WITH_KUNLUNXIN : ${WITH_KUNLUNXIN}")
@@ -54,16 +54,12 @@ function(fastdeploy_summary)
     message(STATUS " OpenVINO version : ${OPENVINO_VERSION}")
   endif()
   if(WITH_GPU)
-    message(STATUS " WITH_GPU : ${WITH_GPU}")
     message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}")
     message(STATUS " TRT_DRECTORY : ${TRT_DIRECTORY}")
-    message(STATUS " BUILD_CUDA_SRC : ${BUILD_CUDA_SRC}")
   endif()
   message(STATUS " ENABLE_VISION : ${ENABLE_VISION}")
   message(STATUS " ENABLE_TEXT : ${ENABLE_TEXT}")
   message(STATUS " ENABLE_ENCRYPTION : ${ENABLE_ENCRYPTION}")
-  message(STATUS " ENABLE_DEBUG : ${ENABLE_DEBUG}")
-  message(STATUS " ENABLE_VISION_VISUALIZE : ${ENABLE_VISION_VISUALIZE}")
   if(ANDROID)
     message(STATUS " ANDROID_ABI : ${ANDROID_ABI}")
     message(STATUS " ANDROID_PLATFORM : ${ANDROID_PLATFORM}")
fastdeploy/function/cuda_cast.cu

@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#ifdef WITH_GPU
 #include "fastdeploy/function/cuda_cast.h"
-
 namespace fastdeploy {
 namespace function {
 template <typename T_IN, typename T_OUT>
@@ -30,13 +30,11 @@ void CudaCast(const FDTensor& in, FDTensor* out, cudaStream_t stream) {
   if (in.dtype == FDDataType::INT64 && out->dtype == FDDataType::INT32) {
     CudaCastKernel<int64_t, int32_t><<<blocks, threads, 0, stream>>>(
         reinterpret_cast<int64_t*>(const_cast<void*>(in.Data())),
-        reinterpret_cast<int32_t*>(out->MutableData()),
-        jobs);
+        reinterpret_cast<int32_t*>(out->MutableData()), jobs);
   } else if (in.dtype == FDDataType::INT32 && out->dtype == FDDataType::INT64) {
     CudaCastKernel<int32_t, int64_t><<<blocks, threads, 0, stream>>>(
         reinterpret_cast<int32_t*>(const_cast<void*>(in.Data())),
-        reinterpret_cast<int64_t*>(out->MutableData()),
-        jobs);
+        reinterpret_cast<int64_t*>(out->MutableData()), jobs);
   } else {
     FDASSERT(false, "CudaCast only support input INT64, output INT32.");
   }
@@ -44,3 +42,4 @@ void CudaCast(const FDTensor& in, FDTensor* out, cudaStream_t stream) {
 
 } // namespace function
 } // namespace fastdeploy
+#endif
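For reference, the dispatch pattern CudaCast uses above is a single templated elementwise kernel, instantiated per (input, output) dtype pair at the call site. A minimal self-contained sketch with assumed names (not the repo's exact code):

#include <cstdint>
#include <cuda_runtime.h>

// One thread per element; jobs is the element count.
template <typename T_IN, typename T_OUT>
__global__ void CastKernel(const T_IN* in, T_OUT* out, int jobs) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx < jobs) out[idx] = static_cast<T_OUT>(in[idx]);
}

void CastInt64ToInt32(const int64_t* in, int32_t* out, int jobs,
                      cudaStream_t stream) {
  int threads = 256;
  int blocks = (jobs + threads - 1) / threads;  // ceil(jobs / threads)
  CastKernel<int64_t, int32_t><<<blocks, threads, 0, stream>>>(in, out, jobs);
}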
fastdeploy/runtime/backends/op_cuda_kernels/adaptive_pool2d_kernel.cu

@@ -1,3 +1,19 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifdef WITH_GPU
+
 #include "adaptive_pool2d_kernel.h"
 
 namespace fastdeploy {
@@ -53,4 +69,5 @@ void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
       input, output, jobs, out_bc_offset, in_bc_offset, int(input_dims[2]),
       int(input_dims[3]), int(output_dims[2]), int(output_dims[3]), is_avg);
 }
 } // namespace fastdeploy
+#endif
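CudaAdaptivePool above receives the input and output heights/widths because adaptive pooling must map each output cell to a variable-sized input window. An illustrative sketch of that index math (standard adaptive pooling, not the repo's kernel):

#include <cmath>

// Output cell i over out_size cells covers the input range
// [floor(i*in_size/out_size), ceil((i+1)*in_size/out_size)), so every
// input element is assigned even when out_size does not divide in_size.
inline void AdaptiveRange(int i, int in_size, int out_size, int* begin,
                          int* end) {
  *begin = static_cast<int>(
      std::floor(static_cast<float>(i * in_size) / out_size));
  *end = static_cast<int>(
      std::ceil(static_cast<float>((i + 1) * in_size) / out_size));
}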
fastdeploy/vision/common/processors/normalize_and_permute.cu

@@ -12,14 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#ifdef WITH_GPU
 #include "fastdeploy/vision/common/processors/normalize_and_permute.h"
 
 namespace fastdeploy {
 namespace vision {
 
-__global__ void NormalizeAndPermuteKernel(
-    uint8_t* src, float* dst, const float* alpha, const float* beta,
-    int num_channel, bool swap_rb, int edge) {
+__global__ void NormalizeAndPermuteKernel(uint8_t* src, float* dst,
+                                          const float* alpha, const float* beta,
+                                          int num_channel, bool swap_rb,
+                                          int edge) {
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   if (idx >= edge) return;
 
@@ -38,8 +40,8 @@ bool NormalizeAndPermute::ImplByCuda(Mat* mat) {
   cv::Mat* im = mat->GetOpenCVMat();
   std::string buf_name = Name() + "_src";
   std::vector<int64_t> shape = {im->rows, im->cols, im->channels()};
-  FDTensor* src = UpdateAndGetReusedBuffer(shape, im->type(), buf_name,
-                                           Device::GPU);
+  FDTensor* src =
+      UpdateAndGetReusedBuffer(shape, im->type(), buf_name, Device::GPU);
   FDASSERT(cudaMemcpy(src->Data(), im->ptr(), src->Nbytes(),
                       cudaMemcpyHostToDevice) == 0,
            "Error occurs while copy memory from CPU to GPU.");
@@ -58,7 +60,7 @@ bool NormalizeAndPermute::ImplByCuda(Mat* mat) {
 
   buf_name = Name() + "_beta";
   FDTensor* beta = UpdateAndGetReusedBuffer({(int64_t)beta_.size()}, CV_32FC1,
                                             buf_name, Device::GPU);
   FDASSERT(cudaMemcpy(beta->Data(), beta_.data(), beta->Nbytes(),
                       cudaMemcpyHostToDevice) == 0,
            "Error occurs while copy memory from CPU to GPU.");
@@ -80,3 +82,4 @@ bool NormalizeAndPermute::ImplByCuda(Mat* mat) {
 
 } // namespace vision
 } // namespace fastdeploy
+#endif
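ImplByCuda stages host pixels into a reused GPU buffer and asserts on the copy result (the cudaMemcpy return compared against 0, i.e. cudaSuccess). A stripped-down sketch of that upload-and-check step, with plain asserts standing in for FDASSERT:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cuda_runtime.h>

// Copy host pixels into an already-allocated device buffer and fail
// loudly if the CUDA runtime reports an error.
void UploadPixels(const uint8_t* host_pixels, size_t nbytes,
                  void* device_buf) {
  cudaError_t err =
      cudaMemcpy(device_buf, host_pixels, nbytes, cudaMemcpyHostToDevice);
  assert(err == cudaSuccess &&
         "Error occurs while copy memory from CPU to GPU.");
}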
fastdeploy/vision/detection/contrib/yolov5lite.cc (53 changes, Executable file → Normal file)
@@ -13,11 +13,12 @@
 // limitations under the License.
 
 #include "fastdeploy/vision/detection/contrib/yolov5lite.h"
+
 #include "fastdeploy/utils/perf.h"
 #include "fastdeploy/vision/utils/utils.h"
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
 #include "fastdeploy/vision/utils/cuda_utils.h"
-#endif // ENABLE_CUDA_PREPROCESS
+#endif // WITH_GPU
 
 namespace fastdeploy {
 namespace vision {
@@ -89,8 +90,8 @@ YOLOv5Lite::YOLOv5Lite(const std::string& model_file,
                        const RuntimeOption& custom_option,
                        const ModelFormat& model_format) {
   if (model_format == ModelFormat::ONNX) {
     valid_cpu_backends = {Backend::ORT};
     valid_gpu_backends = {Backend::ORT, Backend::TRT};
   } else {
     valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
     valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
@@ -99,13 +100,13 @@ YOLOv5Lite::YOLOv5Lite(const std::string& model_file,
   runtime_option.model_format = model_format;
   runtime_option.model_file = model_file;
   runtime_option.params_file = params_file;
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   cudaSetDevice(runtime_option.device_id);
   cudaStream_t stream;
   CUDA_CHECK(cudaStreamCreate(&stream));
   cuda_stream_ = reinterpret_cast<void*>(stream);
   runtime_option.SetExternalStream(cuda_stream_);
-#endif // ENABLE_CUDA_PREPROCESS
+#endif // WITH_GPU
   initialized = Initialize();
 }
 
@@ -148,14 +149,14 @@ bool YOLOv5Lite::Initialize() {
 }
 
 YOLOv5Lite::~YOLOv5Lite() {
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   if (use_cuda_preprocessing_) {
    CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
    CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
    CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
    CUDA_CHECK(cudaStreamDestroy(reinterpret_cast<cudaStream_t>(cuda_stream_)));
   }
-#endif // ENABLE_CUDA_PREPROCESS
+#endif // WITH_GPU
 }
 
 bool YOLOv5Lite::Preprocess(
@@ -199,28 +200,33 @@ bool YOLOv5Lite::Preprocess(
 }
 
 void YOLOv5Lite::UseCudaPreprocessing(int max_image_size) {
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   use_cuda_preprocessing_ = true;
   is_scale_up = true;
   if (input_img_cuda_buffer_host_ == nullptr) {
     // prepare input data cache in GPU pinned memory
-    CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3));
+    CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_,
+                              max_image_size * 3));
     // prepare input data cache in GPU device memory
-    CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
-    CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size[0] * size[1] * sizeof(float)));
+    CUDA_CHECK(
+        cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
+    CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_,
+                          3 * size[0] * size[1] * sizeof(float)));
   }
 #else
-  FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
-            << std::endl;
+  FDWARNING << "The FastDeploy didn't compile with WITH_GPU=ON." << std::endl;
   use_cuda_preprocessing_ = false;
 #endif
 }
 
-bool YOLOv5Lite::CudaPreprocess(Mat* mat, FDTensor* output,
-    std::map<std::string, std::array<float, 2>>* im_info) {
-#ifdef ENABLE_CUDA_PREPROCESS
+bool YOLOv5Lite::CudaPreprocess(
+    Mat* mat, FDTensor* output,
+    std::map<std::string, std::array<float, 2>>* im_info) {
+#ifdef WITH_GPU
   if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) {
-    FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl;
+    FDERROR << "Preprocessing with CUDA is only available when the arguments "
+               "satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)."
+            << std::endl;
     return false;
   }
 
@@ -234,14 +240,15 @@ bool YOLOv5Lite::CudaPreprocess(Mat* mat, FDTensor* output,
   int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
   memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
   CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
-                             input_img_cuda_buffer_host_,
-                             src_img_buf_size, cudaMemcpyHostToDevice, stream));
+                             input_img_cuda_buffer_host_, src_img_buf_size,
+                             cudaMemcpyHostToDevice, stream));
   utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
                             mat->Height(), input_tensor_cuda_buffer_device_,
                             size[0], size[1], padding_value, stream);
 
   // Record output shape of preprocessed image
-  (*im_info)["output_shape"] = {static_cast<float>(size[0]), static_cast<float>(size[1])};
+  (*im_info)["output_shape"] = {static_cast<float>(size[0]),
+                                static_cast<float>(size[1])};
 
   output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
                           input_tensor_cuda_buffer_device_);
@@ -251,7 +258,7 @@ bool YOLOv5Lite::CudaPreprocess(Mat* mat, FDTensor* output,
 #else
   FDERROR << "CUDA src code was not enabled." << std::endl;
   return false;
-#endif // ENABLE_CUDA_PREPROCESS
+#endif // WITH_GPU
 }
 
 bool YOLOv5Lite::PostprocessWithDecode(
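With these guards, CUDA preprocessing now follows the library-wide WITH_GPU build: on a CPU-only build, UseCudaPreprocessing() just logs the warning above and falls back to CPU preprocessing. A hypothetical caller sketch (the model path is made up, and RuntimeOption::UseGpu / Initialized() are assumed to be the standard FastDeploy entry points):

#include "fastdeploy/vision.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.UseGpu(0);  // run inference on GPU 0
  // params_file is empty for ONNX; model_format defaults to ONNX here.
  fastdeploy::vision::detection::YOLOv5Lite model("yolov5lite.onnx", "",
                                                  option);
  // Reserve pinned host + device buffers for inputs up to 1920*1080 pixels
  // (the implementation allocates max_image_size * 3 bytes for each).
  model.UseCudaPreprocessing(1920 * 1080);
  return model.Initialized() ? 0 : 1;
}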
fastdeploy/vision/detection/contrib/yolov6.cc (21 changes, Executable file → Normal file)
@@ -16,9 +16,9 @@
 
 #include "fastdeploy/utils/perf.h"
 #include "fastdeploy/vision/utils/utils.h"
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
 #include "fastdeploy/vision/utils/cuda_utils.h"
-#endif // ENABLE_CUDA_PREPROCESS
+#endif // WITH_GPU
 
 namespace fastdeploy {
 
@@ -79,13 +79,13 @@ YOLOv6::YOLOv6(const std::string& model_file, const std::string& params_file,
   runtime_option.model_format = model_format;
   runtime_option.model_file = model_file;
   runtime_option.params_file = params_file;
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   cudaSetDevice(runtime_option.device_id);
   cudaStream_t stream;
   CUDA_CHECK(cudaStreamCreate(&stream));
   cuda_stream_ = reinterpret_cast<void*>(stream);
   runtime_option.SetExternalStream(cuda_stream_);
-#endif // ENABLE_CUDA_PREPROCESS
+#endif // WITH_GPU
   initialized = Initialize();
 }
 
@@ -123,14 +123,14 @@ bool YOLOv6::Initialize() {
 }
 
 YOLOv6::~YOLOv6() {
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   if (use_cuda_preprocessing_) {
    CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
    CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
    CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
    CUDA_CHECK(cudaStreamDestroy(reinterpret_cast<cudaStream_t>(cuda_stream_)));
   }
-#endif // ENABLE_CUDA_PREPROCESS
+#endif // WITH_GPU
 }
 
 bool YOLOv6::Preprocess(Mat* mat, FDTensor* output,
@@ -173,7 +173,7 @@ bool YOLOv6::Preprocess(Mat* mat, FDTensor* output,
 }
 
 void YOLOv6::UseCudaPreprocessing(int max_image_size) {
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   use_cuda_preprocessing_ = true;
   is_scale_up = true;
   if (input_img_cuda_buffer_host_ == nullptr) {
@@ -187,8 +187,7 @@ void YOLOv6::UseCudaPreprocessing(int max_image_size) {
                           3 * size[0] * size[1] * sizeof(float)));
   }
 #else
-  FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
-            << std::endl;
+  FDWARNING << "The FastDeploy didn't compile with WITH_GPU=ON." << std::endl;
   use_cuda_preprocessing_ = false;
 #endif
 }
@@ -196,7 +195,7 @@ void YOLOv6::UseCudaPreprocessing(int max_image_size) {
 bool YOLOv6::CudaPreprocess(
     Mat* mat, FDTensor* output,
     std::map<std::string, std::array<float, 2>>* im_info) {
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) {
     FDERROR << "Preprocessing with CUDA is only available when the arguments "
                "satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)."
@@ -232,7 +231,7 @@ bool YOLOv6::CudaPreprocess(
 #else
   FDERROR << "CUDA src code was not enabled." << std::endl;
   return false;
-#endif // ENABLE_CUDA_PREPROCESS
+#endif // WITH_GPU
 }
 
 bool YOLOv6::Postprocess(
fastdeploy/vision/detection/contrib/yolov7end2end_trt.cc (49 changes, Executable file → Normal file)
@@ -13,11 +13,12 @@
 // limitations under the License.
 
 #include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
+
 #include "fastdeploy/utils/perf.h"
 #include "fastdeploy/vision/utils/utils.h"
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
 #include "fastdeploy/vision/utils/cuda_utils.h"
-#endif // ENABLE_CUDA_PREPROCESS
+#endif // WITH_GPU
 
 namespace fastdeploy {
 namespace vision {
@@ -88,13 +89,13 @@ YOLOv7End2EndTRT::YOLOv7End2EndTRT(const std::string& model_file,
       runtime_option.backend = Backend::TRT;
     }
   }
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   cudaSetDevice(runtime_option.device_id);
   cudaStream_t stream;
   CUDA_CHECK(cudaStreamCreate(&stream));
   cuda_stream_ = reinterpret_cast<void*>(stream);
   runtime_option.SetExternalStream(cuda_stream_);
-#endif // ENABLE_CUDA_PREPROCESS
+#endif // WITH_GPU
   initialized = Initialize();
 }
 
@@ -131,14 +132,14 @@ bool YOLOv7End2EndTRT::Initialize() {
 }
 
 YOLOv7End2EndTRT::~YOLOv7End2EndTRT() {
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   if (use_cuda_preprocessing_) {
    CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
    CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
    CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
    CUDA_CHECK(cudaStreamDestroy(reinterpret_cast<cudaStream_t>(cuda_stream_)));
   }
-#endif // ENABLE_CUDA_PREPROCESS
+#endif // WITH_GPU
 }
 
 bool YOLOv7End2EndTRT::Preprocess(
@@ -173,28 +174,33 @@ bool YOLOv7End2EndTRT::Preprocess(
 }
 
 void YOLOv7End2EndTRT::UseCudaPreprocessing(int max_image_size) {
-#ifdef ENABLE_CUDA_PREPROCESS
+#ifdef WITH_GPU
   use_cuda_preprocessing_ = true;
   is_scale_up = true;
   if (input_img_cuda_buffer_host_ == nullptr) {
     // prepare input data cache in GPU pinned memory
-    CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3));
+    CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_,
+                              max_image_size * 3));
     // prepare input data cache in GPU device memory
-    CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
-    CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size[0] * size[1] * sizeof(float)));
+    CUDA_CHECK(
+        cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
+    CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_,
+                          3 * size[0] * size[1] * sizeof(float)));
   }
 #else
-  FDWARNING << "The FastDeploy didn't compile with BUILD_CUDA_SRC=ON."
-            << std::endl;
+  FDWARNING << "The FastDeploy didn't compile with WITH_GPU=ON." << std::endl;
   use_cuda_preprocessing_ = false;
 #endif
 }
 
-bool YOLOv7End2EndTRT::CudaPreprocess(Mat* mat, FDTensor* output,
-    std::map<std::string, std::array<float, 2>>* im_info) {
-#ifdef ENABLE_CUDA_PREPROCESS
+bool YOLOv7End2EndTRT::CudaPreprocess(
+    Mat* mat, FDTensor* output,
+    std::map<std::string, std::array<float, 2>>* im_info) {
+#ifdef WITH_GPU
   if (is_mini_pad != false || is_no_pad != false || is_scale_up != true) {
-    FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl;
+    FDERROR << "Preprocessing with CUDA is only available when the arguments "
+               "satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)."
+            << std::endl;
     return false;
   }
 
@@ -208,14 +214,15 @@ bool YOLOv7End2EndTRT::CudaPreprocess(Mat* mat, FDTensor* output,
   int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
   memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
   CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
-                             input_img_cuda_buffer_host_,
-                             src_img_buf_size, cudaMemcpyHostToDevice, stream));
+                             input_img_cuda_buffer_host_, src_img_buf_size,
+                             cudaMemcpyHostToDevice, stream));
   utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
                             mat->Height(), input_tensor_cuda_buffer_device_,
                             size[0], size[1], padding_value, stream);
 
   // Record output shape of preprocessed image
-  (*im_info)["output_shape"] = {static_cast<float>(size[0]), static_cast<float>(size[1])};
+  (*im_info)["output_shape"] = {static_cast<float>(size[0]),
+                                static_cast<float>(size[1])};
 
   output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
                           input_tensor_cuda_buffer_device_);
@@ -225,7 +232,7 @@ bool YOLOv7End2EndTRT::CudaPreprocess(Mat* mat, FDTensor* output,
 #else
   FDERROR << "CUDA src code was not enabled." << std::endl;
   return false;
-#endif // ENABLE_CUDA_PREPROCESS
+#endif // WITH_GPU
 }
 
 bool YOLOv7End2EndTRT::Postprocess(
fastdeploy/vision/utils/cuda_utils.cu

@@ -21,9 +21,11 @@
 // \brief
 // \author Qi Liu, Xinyu Wang
 
-#include "fastdeploy/vision/utils/cuda_utils.h"
+#ifdef WITH_GPU
 #include <opencv2/opencv.hpp>
 
+#include "fastdeploy/vision/utils/cuda_utils.h"
+
 namespace fastdeploy {
 namespace vision {
 namespace utils {
@@ -32,12 +34,11 @@ struct AffineMatrix {
   float value[6];
 };
 
 __global__ void YoloPreprocessCudaKernel(
-    uint8_t* src, int src_line_size, int src_width,
-    int src_height, float* dst, int dst_width,
-    int dst_height, uint8_t padding_color_b,
-    uint8_t padding_color_g, uint8_t padding_color_r,
-    AffineMatrix d2s, int edge) {
+    uint8_t* src, int src_line_size, int src_width, int src_height, float* dst,
+    int dst_width, int dst_height, uint8_t padding_color_b,
+    uint8_t padding_color_g, uint8_t padding_color_r, AffineMatrix d2s,
+    int edge) {
   int position = blockDim.x * blockIdx.x + threadIdx.x;
   if (position >= edge) return;
 
@@ -91,7 +92,7 @@ __global__ void YoloPreprocessCudaKernel(
     c2 = w1 * v1[2] + w2 * v2[2] + w3 * v3[2] + w4 * v4[2];
   }
 
   // bgr to rgb
   float t = c2;
   c2 = c0;
   c0 = t;
@@ -111,16 +112,17 @@
   *pdst_c2 = c2;
 }
 
-void CudaYoloPreprocess(
-    uint8_t* src, int src_width, int src_height,
-    float* dst, int dst_width, int dst_height,
-    const std::vector<float> padding_value, cudaStream_t stream) {
+void CudaYoloPreprocess(uint8_t* src, int src_width, int src_height, float* dst,
+                        int dst_width, int dst_height,
+                        const std::vector<float> padding_value,
+                        cudaStream_t stream) {
   AffineMatrix s2d, d2s;
-  float scale = std::min(dst_height / (float)src_height, dst_width / (float)src_width);
+  float scale =
+      std::min(dst_height / (float)src_height, dst_width / (float)src_width);
 
   s2d.value[0] = scale;
   s2d.value[1] = 0;
   s2d.value[2] = -scale * src_width * 0.5 + dst_width * 0.5;
   s2d.value[3] = 0;
   s2d.value[4] = scale;
   s2d.value[5] = -scale * src_height * 0.5 + dst_height * 0.5;
@@ -135,12 +137,11 @@ void CudaYoloPreprocess(
   int threads = 256;
   int blocks = ceil(jobs / (float)threads);
   YoloPreprocessCudaKernel<<<blocks, threads, 0, stream>>>(
-      src, src_width * 3, src_width,
-      src_height, dst, dst_width,
-      dst_height, padding_value[0], padding_value[1], padding_value[2], d2s, jobs);
+      src, src_width * 3, src_width, src_height, dst, dst_width, dst_height,
+      padding_value[0], padding_value[1], padding_value[2], d2s, jobs);
 
 }
 
 } // namespace utils
 } // namespace vision
 } // namespace fastdeploy
+#endif
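The kernel warps each destination pixel back to the source with the d2s (destination-to-source) matrix, while the host builds s2d as the letterbox transform above: scale = min(dst/src) in both axes, plus a centering shift. An illustrative sketch (assumed helper, mirroring the 2x3 AffineMatrix in this file) of how such a scale-plus-translation matrix inverts in closed form:

// For x' = s*x + tx the inverse is x = x'/s - tx/s, so the inverse
// matrix keeps the zero shear terms and divides scale and shift by s.
struct AffineMatrix2x3 {
  float value[6];  // row-major 2x3: [a, b, tx, c, d, ty]
};

void InvertScaleShift(const AffineMatrix2x3& s2d, AffineMatrix2x3* d2s) {
  float scale = s2d.value[0];  // value[4] holds the same uniform scale
  d2s->value[0] = 1.0f / scale;
  d2s->value[1] = 0.0f;
  d2s->value[2] = -s2d.value[2] / scale;  // tx' = -tx / s
  d2s->value[3] = 0.0f;
  d2s->value[4] = 1.0f / scale;
  d2s->value[5] = -s2d.value[5] / scale;  // ty' = -ty / s
}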