diff --git a/CMakeLists.txt b/CMakeLists.txt
index c57251d42..b194e0ba5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -92,16 +92,6 @@ if(BUILD_ON_JETSON)
   set(ENABLE_ORT_BACKEND ON)
 endif()
 
-# Whether to build CUDA source files in fastdeploy
-# CUDA source files include CUDA preprocessing, TRT plugins, etc.
-if(WITH_GPU AND UNIX)
-  set(BUILD_CUDA_SRC ON)
-  enable_language(CUDA)
-  set(CUDA_PROPAGATE_HOST_FLAGS FALSE)
-else()
-  set(BUILD_CUDA_SRC OFF)
-endif()
-
 # config GIT_URL with github mirrors to speed up dependent repos clone
 option(GIT_URL "Git URL to clone dependent repos" ${GIT_URL})
 if(NOT GIT_URL)
@@ -177,6 +167,7 @@ configure_file(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/core/config.h.
 configure_file(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/main.cc.in ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/main.cc)
 file(GLOB_RECURSE ALL_DEPLOY_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*.cc)
 file(GLOB_RECURSE FDTENSOR_FUNC_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cc)
+file(GLOB_RECURSE FDTENSOR_FUNC_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cu)
 file(GLOB_RECURSE DEPLOY_ORT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/ort/*.cc)
 file(GLOB_RECURSE DEPLOY_PADDLE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/paddle/*.cc)
 file(GLOB_RECURSE DEPLOY_POROS_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/poros/*.cc)
@@ -320,6 +311,18 @@ if(WITH_GPU)
   endif()
 endif()
 
+# Whether to build CUDA source files in fastdeploy
+# CUDA source files include CUDA preprocessing, TRT plugins, etc.
+if(WITH_GPU)
+  set(BUILD_CUDA_SRC ON)
+  enable_language(CUDA)
+  set(CMAKE_CUDA_STANDARD 11)
+  set(CUDA_PROPAGATE_HOST_FLAGS FALSE)
+  list(APPEND ALL_DEPLOY_SRCS ${FDTENSOR_FUNC_CUDA_SRCS})
+else()
+  set(BUILD_CUDA_SRC OFF)
+endif()
+
 if(ENABLE_TRT_BACKEND)
   if(APPLE OR ANDROID OR IOS)
     message(FATAL_ERROR "Cannot enable tensorrt backend in mac/ios/android os, please set -DENABLE_TRT_BACKEND=OFF.")
@@ -463,7 +466,7 @@ endif()
 set_target_properties(${LIBRARY_NAME} PROPERTIES VERSION ${FASTDEPLOY_VERSION})
 if(MSVC)
   # disable warnings for dll export
-  target_compile_options(${LIBRARY_NAME} PRIVATE /wd4251)
+  target_compile_options(${LIBRARY_NAME} PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:/wd4251>$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=/wd4251>")
 endif()
 
 target_link_libraries(${LIBRARY_NAME} ${DEPEND_LIBS})
diff --git a/docs/cn/build_and_install/gpu.md b/docs/cn/build_and_install/gpu.md
index 0972f3498..ae1b46c4c 100644
--- a/docs/cn/build_and_install/gpu.md
+++ b/docs/cn/build_and_install/gpu.md
@@ -14,7 +14,7 @@ FastDeploy当前在GPU环境支持Paddle Inference、ONNX Runtime和TensorRT,
 
 ## C++ SDK编译安装
 
-### Linux 
+### Linux
 
 Linux上编译需满足
 - gcc/g++ >= 5.4(推荐8.2)
@@ -48,6 +48,8 @@ Windows编译需要满足条件
 - cuda >= 11.2
 - cudnn >= 8.2
 
+注意:安装CUDA时需要勾选`Visual Studio Integration`, 或者手动将`C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2\extras\visual_studio_integration\MSBuildExtensions\`文件夹下的4个文件复制到`C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations\`文件夹。否则执行cmake命令时可能会遇到`No CUDA toolset found`报错。
+
 在Windows菜单中,找到`x64 Native Tools Command Prompt for VS 2019`打开,执行如下命令
 
 ```bat
diff --git a/docs/en/build_and_install/gpu.md b/docs/en/build_and_install/gpu.md
index 9d3c642aa..7d3ad9674 100644
--- a/docs/en/build_and_install/gpu.md
+++ b/docs/en/build_and_install/gpu.md
@@ -10,7 +10,7 @@ FastDeploy supports Paddle Inference, ONNX Runtime and TensorRT in the GPU envir
 | TensorRT | Windows(x64)<br>Linux(x64) | Paddle/ONNX | Support GPU only, and compilation switch is `ENABLE_TRT_BACKEND`. The default is OFF |
 | OpenVINO | Windows(x64)<br>Linux(x64) | Paddle/ONNX | Support CPU only, and compilation switch is `ENABLE_OPENVINO_BACKEND`. The default is OFF |
 
-Note: 
+Note:
 
 When the environment is GPU, please set `WITH_GPU` as ON and specify `CUDA_DIRECTORY`. If TensorRT integration is needed, please specify `TRT_DIRECTORY` as well.
 
@@ -51,6 +51,8 @@ Prerequisite for Compiling on Windows:
 - cuda >= 11.2
 - cudnn >= 8.2
 
+Notice: Make sure `Visual Studio Integration` is selected when installing CUDA, or manually copy the 4 files under `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2\extras\visual_studio_integration\MSBuildExtensions\` into `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations\`. Otherwise, you may run into a `No CUDA toolset found` error when running cmake.
+
 Launch the x64 Native Tools Command Prompt for VS 2019 from the Windows Start Menu and run the following commands:
 
 ```
diff --git a/fastdeploy/backends/tensorrt/trt_backend.cc b/fastdeploy/backends/tensorrt/trt_backend.cc
index 563901254..ca1078559 100755
--- a/fastdeploy/backends/tensorrt/trt_backend.cc
+++ b/fastdeploy/backends/tensorrt/trt_backend.cc
@@ -13,8 +13,10 @@
 // limitations under the License.
 
 #include "fastdeploy/backends/tensorrt/trt_backend.h"
+#include "fastdeploy/function/cuda_cast.h"
 
 #include <cstring>
+#include <unordered_map>
 
 #include "NvInferRuntime.h"
 #include "fastdeploy/utils/utils.h"
@@ -234,6 +236,7 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
     inputs_desc_[i].name = name;
     inputs_desc_[i].shape.assign(shape.begin(), shape.end());
     inputs_desc_[i].dtype = ReaderDtypeToTrtDtype(onnx_reader.inputs[i].dtype);
+    inputs_desc_[i].original_dtype = ReaderDtypeToFDDtype(onnx_reader.inputs[i].dtype);
     auto info = ShapeRangeInfo(shape);
     info.name = name;
     auto iter_min = option.min_shape.find(name);
@@ -256,6 +259,8 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
     outputs_desc_[i].shape.assign(shape.begin(), shape.end());
     outputs_desc_[i].dtype =
         ReaderDtypeToTrtDtype(onnx_reader.outputs[i].dtype);
+    outputs_desc_[i].original_dtype =
+        ReaderDtypeToFDDtype(onnx_reader.outputs[i].dtype);
   }
 
   if (option_.external_stream_) {
@@ -315,9 +320,29 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
     FDERROR << "Failed to Infer with TensorRT." << std::endl;
     return false;
   }
+  for (size_t i = 0; i < outputs->size(); ++i) {
+    // If the final output tensor's dtype is different from the model output tensor's dtype,
+    // then we need to cast the data to the final output's dtype.
+    auto model_output_dtype = GetFDDataType(outputs_device_buffer_[(*outputs)[i].name].dtype());
+    if ((*outputs)[i].dtype != model_output_dtype) {
+      FDTensor output_tensor;
+      output_tensor.SetExternalData((*outputs)[i].shape, model_output_dtype,
+                                    outputs_device_buffer_[(*outputs)[i].name].data(),
+                                    Device::GPU);
+
+      casted_output_tensors_[(*outputs)[i].name].Resize((*outputs)[i].shape, (*outputs)[i].dtype,
+                                                        (*outputs)[i].name, Device::GPU);
+      CudaCast(output_tensor, &casted_output_tensors_[(*outputs)[i].name], stream_);
+    } else {
+      casted_output_tensors_[(*outputs)[i].name].SetExternalData(
+          (*outputs)[i].shape, model_output_dtype,
+          outputs_device_buffer_[(*outputs)[i].name].data(),
+          Device::GPU);
+    }
+  }
   for (size_t i = 0; i < outputs->size(); ++i) {
     FDASSERT(cudaMemcpyAsync((*outputs)[i].Data(),
-                             outputs_device_buffer_[(*outputs)[i].name].data(),
+                             casted_output_tensors_[(*outputs)[i].name].Data(),
                              (*outputs)[i].Nbytes(), cudaMemcpyDeviceToHost,
                              stream_) == 0,
              "[ERROR] Error occurs while copy memory from GPU to CPU.");
@@ -329,6 +354,17 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
 }
 
 void TrtBackend::GetInputOutputInfo() {
+  // Read the original dtypes from inputs_desc_ and outputs_desc_
+  std::unordered_map<std::string, FDDataType> inputs_original_dtype_map;
+  std::unordered_map<std::string, FDDataType> outputs_original_dtype_map;
+  for (size_t i = 0; i < inputs_desc_.size(); ++i) {
+    inputs_original_dtype_map[inputs_desc_[i].name] = inputs_desc_[i].original_dtype;
+  }
+  for (size_t i = 0; i < outputs_desc_.size(); ++i) {
+    outputs_original_dtype_map[outputs_desc_[i].name] = outputs_desc_[i].original_dtype;
+  }
+
+  // Re-read the tensor info from the TRT engine and write it back into inputs_desc_ and outputs_desc_
   std::vector<TrtValueInfo>().swap(inputs_desc_);
   std::vector<TrtValueInfo>().swap(outputs_desc_);
   inputs_desc_.clear();
@@ -339,11 +375,14 @@ void TrtBackend::GetInputOutputInfo() {
     auto shape = ToVec(engine_->getBindingDimensions(i));
     auto dtype = engine_->getBindingDataType(i);
     if (engine_->bindingIsInput(i)) {
-      inputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
+      auto original_dtype = inputs_original_dtype_map.count(name) ? inputs_original_dtype_map[name] : GetFDDataType(dtype);
+      inputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype, original_dtype});
       inputs_device_buffer_[name] = FDDeviceBuffer(dtype);
     } else {
-      outputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
+      auto original_dtype = outputs_original_dtype_map.count(name) ? outputs_original_dtype_map[name] : GetFDDataType(dtype);
+      outputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype, original_dtype});
       outputs_device_buffer_[name] = FDDeviceBuffer(dtype);
+      casted_output_tensors_[name] = FDTensor();
     }
   }
   bindings_.resize(num_binds);
@@ -358,11 +397,12 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
 
     if (item.device == Device::GPU) {
       if (item.dtype == FDDataType::INT64) {
-        // TODO(liqi): cast int64 to int32
-        // TRT don't support INT64
-        FDASSERT(false,
-                 "TRT don't support INT64 input on GPU, "
-                 "please use INT32 input");
+        inputs_device_buffer_[item.name].resize(dims);
+        FDTensor input_tensor;
+        input_tensor.SetExternalData(item.shape, FDDataType::INT32,
+                                     inputs_device_buffer_[item.name].data(),
+                                     Device::GPU);
+        CudaCast(item, &input_tensor, stream_);
       } else {
         // no copy
         inputs_device_buffer_[item.name].SetExternalData(dims, item.Data());
@@ -413,7 +453,7 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
     std::vector<int64_t> shape(output_dims.d,
                                output_dims.d + output_dims.nbDims);
     (*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
-    (*outputs)[ori_idx].Resize(shape, GetFDDataType(outputs_desc_[i].dtype),
+    (*outputs)[ori_idx].Resize(shape, outputs_desc_[i].original_dtype,
                                outputs_desc_[i].name);
 
     // Allocate output buffer memory
@@ -629,7 +669,7 @@ TensorInfo TrtBackend::GetInputInfo(int index) {
   info.name = inputs_desc_[index].name;
   info.shape.assign(inputs_desc_[index].shape.begin(),
                     inputs_desc_[index].shape.end());
-  info.dtype = GetFDDataType(inputs_desc_[index].dtype);
+  info.dtype = inputs_desc_[index].original_dtype;
   return info;
 }
 
@@ -649,7 +689,7 @@ TensorInfo TrtBackend::GetOutputInfo(int index) {
   info.name = outputs_desc_[index].name;
   info.shape.assign(outputs_desc_[index].shape.begin(),
                     outputs_desc_[index].shape.end());
-  info.dtype = GetFDDataType(outputs_desc_[index].dtype);
+  info.dtype = outputs_desc_[index].original_dtype;
   return info;
 }
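The new output path in `Infer()` above does the dtype fix-up on the device: when the engine produced a different dtype than the user-facing tensor expects, the buffer is first cast on the GPU into `casted_output_tensors_` and only then copied to the host with a single `cudaMemcpyAsync`. Below is a minimal, self-contained CUDA sketch of that cast-then-copy pattern; it is illustrative only, with made-up buffer names and sizes, and is not FastDeploy code.

```cpp
// Standalone sketch of the cast-then-copy pattern used by Infer() above.
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

__global__ void CastI32ToI64(const int32_t* in, int64_t* out, int n) {
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  if (i < n) out[i] = static_cast<int64_t>(in[i]);
}

int main() {
  const int n = 1024;
  int32_t* d_engine_out = nullptr;  // what the TRT engine actually wrote (INT32)
  int64_t* d_casted = nullptr;      // staging buffer in the user-facing dtype (INT64)
  cudaMalloc(&d_engine_out, n * sizeof(int32_t));
  cudaMalloc(&d_casted, n * sizeof(int64_t));
  cudaMemset(d_engine_out, 0, n * sizeof(int32_t));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // 1) Cast on the device ...
  CastI32ToI64<<<(n + 255) / 256, 256, 0, stream>>>(d_engine_out, d_casted, n);

  // 2) ... then do a single device-to-host copy of the already-casted buffer.
  std::vector<int64_t> host(n);
  cudaMemcpyAsync(host.data(), d_casted, n * sizeof(int64_t),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);

  cudaFree(d_engine_out);
  cudaFree(d_casted);
  cudaStreamDestroy(stream);
  return 0;
}
```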
diff --git a/fastdeploy/backends/tensorrt/trt_backend.h b/fastdeploy/backends/tensorrt/trt_backend.h
index 0aebba717..63782e9d0 100755
--- a/fastdeploy/backends/tensorrt/trt_backend.h
+++ b/fastdeploy/backends/tensorrt/trt_backend.h
@@ -57,7 +57,8 @@ namespace fastdeploy {
 struct TrtValueInfo {
   std::string name;
   std::vector<int> shape;
-  nvinfer1::DataType dtype;
+  nvinfer1::DataType dtype;   // dtype of the TRT model
+  FDDataType original_dtype;  // dtype of the original ONNX/Paddle model
 };
 
 struct TrtBackendOption {
@@ -141,6 +142,13 @@ class TrtBackend : public BaseBackend {
   // Also will update the range information while inferencing
   std::map<std::string, ShapeRangeInfo> shape_range_info_;
 
+  // If the final output tensor's dtype is different from the
+  // model output tensor's dtype, then we need to cast the data
+  // to the final output's dtype.
+  // E.g. when the TRT model's output tensor is INT32 but the final tensor is INT64.
+  // This map stores the casted tensors.
+  std::map<std::string, FDTensor> casted_output_tensors_;
+
   void GetInputOutputInfo();
   bool CreateTrtEngineFromOnnx(const std::string& onnx_model_buffer);
   bool BuildTrtEngine();
diff --git a/fastdeploy/backends/tensorrt/utils.cc b/fastdeploy/backends/tensorrt/utils.cc
index 1347b0a4a..f0322bc94 100644
--- a/fastdeploy/backends/tensorrt/utils.cc
+++ b/fastdeploy/backends/tensorrt/utils.cc
@@ -104,6 +104,26 @@ nvinfer1::DataType ReaderDtypeToTrtDtype(int reader_dtype) {
   return nvinfer1::DataType::kFLOAT;
 }
 
+FDDataType ReaderDtypeToFDDtype(int reader_dtype) {
+  if (reader_dtype == 0) {
+    return FDDataType::FP32;
+  } else if (reader_dtype == 1) {
+    return FDDataType::FP64;
+  } else if (reader_dtype == 2) {
+    return FDDataType::UINT8;
+  } else if (reader_dtype == 3) {
+    return FDDataType::INT8;
+  } else if (reader_dtype == 4) {
+    return FDDataType::INT32;
+  } else if (reader_dtype == 5) {
+    return FDDataType::INT64;
+  } else if (reader_dtype == 6) {
+    return FDDataType::FP16;
+  }
+  FDASSERT(false, "Received unexpected data type of %d", reader_dtype);
+  return FDDataType::FP32;
+}
+
 std::vector<int> ToVec(const nvinfer1::Dims& dim) {
   std::vector<int> out(dim.d, dim.d + dim.nbDims);
   return out;
diff --git a/fastdeploy/backends/tensorrt/utils.h b/fastdeploy/backends/tensorrt/utils.h
index 7f2e7344b..af62c445e 100644
--- a/fastdeploy/backends/tensorrt/utils.h
+++ b/fastdeploy/backends/tensorrt/utils.h
@@ -55,6 +55,8 @@ FDDataType GetFDDataType(const nvinfer1::DataType& dtype);
 
 nvinfer1::DataType ReaderDtypeToTrtDtype(int reader_dtype);
 
+FDDataType ReaderDtypeToFDDtype(int reader_dtype);
+
 std::vector<int> ToVec(const nvinfer1::Dims& dim);
 
 template <typename T>
@@ -153,6 +155,11 @@ class FDGenericBuffer {
   //!
   size_t nbBytes() const { return this->size() * TrtDataTypeSize(mType); }
 
+  //!
+  //! \brief Returns the dtype of the buffer.
+  //!
+  nvinfer1::DataType dtype() const { return mType; }
+
   //!
   //! \brief Set user memory buffer for TRT Buffer
   //!
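`ReaderDtypeToFDDtype` maps the ONNX reader's integer dtype codes (0 = FP32 through 6 = FP16) onto `FDDataType`, so the backend can remember the model's original dtype even when the TRT binding reports a different one. A hypothetical spot check, assuming the declarations sit in namespace `fastdeploy` like the rest of `utils.h`:

```cpp
// Hypothetical check, not part of the PR: spot-check the code-to-dtype table
// implemented by ReaderDtypeToFDDtype above.
#include <cassert>
#include "fastdeploy/backends/tensorrt/utils.h"

int main() {
  using fastdeploy::FDDataType;
  assert(fastdeploy::ReaderDtypeToFDDtype(0) == FDDataType::FP32);
  assert(fastdeploy::ReaderDtypeToFDDtype(4) == FDDataType::INT32);
  // INT64 is preserved on the FastDeploy side even though TRT has no INT64 binding type.
  assert(fastdeploy::ReaderDtypeToFDDtype(5) == FDDataType::INT64);
  return 0;
}
```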
+ +#include "fastdeploy/function/cuda_cast.h" + +namespace fastdeploy { + +template +__global__ void CudaCastKernel(const T_IN* in, T_OUT* out, int edge) { + int position = blockDim.x * blockIdx.x + threadIdx.x; + if (position >= edge) return; + out[position] = (T_OUT)in[position]; +} + +void CudaCast(const FDTensor& in, FDTensor* out, cudaStream_t stream) { + int jobs = in.Numel(); + int threads = 256; + int blocks = ceil(jobs / (float)threads); + if (in.dtype == FDDataType::INT64 && out->dtype == FDDataType::INT32) { + CudaCastKernel<<>>( + reinterpret_cast(const_cast(in.Data())), + reinterpret_cast(out->MutableData()), + jobs); + } else if (in.dtype == FDDataType::INT32 && out->dtype == FDDataType::INT64) { + CudaCastKernel<<>>( + reinterpret_cast(const_cast(in.Data())), + reinterpret_cast(out->MutableData()), + jobs); + } else { + FDASSERT(false, "CudaCast only support input INT64, output INT32."); + } +} + +} // namespace fastdeploy diff --git a/fastdeploy/function/cuda_cast.h b/fastdeploy/function/cuda_cast.h new file mode 100644 index 000000000..f467e78fe --- /dev/null +++ b/fastdeploy/function/cuda_cast.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fastdeploy/core/fd_tensor.h" + +namespace fastdeploy { + +/** Cast the type of the data in GPU buffer. + @param in The input tensor. + @param out The output tensor + @param stream CUDA stream +*/ +FASTDEPLOY_DECL void CudaCast(const FDTensor& in, FDTensor* out, + cudaStream_t stream); + +} // namespace fastdeploy