mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Backend] TRT cast GPU input from int64 to int32, output from int32 to int64, and Windows support building CUDA files (#426)
* TRT cast int64 to int32 * windows cmake build cuda src * fix windows cmake error when build cuda src * add a notice in windows gpu build doc * cmake add cuda std=11 * TRT cast output from int32 to int64 * nits * trt get original input output dtype
This commit is contained in:
@@ -92,16 +92,6 @@ if(BUILD_ON_JETSON)
|
||||
set(ENABLE_ORT_BACKEND ON)
|
||||
endif()
|
||||
|
||||
# Whether to build CUDA source files in fastdeploy
|
||||
# CUDA source files include CUDA preprocessing, TRT plugins, etc.
|
||||
if(WITH_GPU AND UNIX)
|
||||
set(BUILD_CUDA_SRC ON)
|
||||
enable_language(CUDA)
|
||||
set(CUDA_PROPAGATE_HOST_FLAGS FALSE)
|
||||
else()
|
||||
set(BUILD_CUDA_SRC OFF)
|
||||
endif()
|
||||
|
||||
# config GIT_URL with github mirrors to speed up dependent repos clone
|
||||
option(GIT_URL "Git URL to clone dependent repos" ${GIT_URL})
|
||||
if(NOT GIT_URL)
|
||||
@@ -177,6 +167,7 @@ configure_file(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/core/config.h.
|
||||
configure_file(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/main.cc.in ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/main.cc)
|
||||
file(GLOB_RECURSE ALL_DEPLOY_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*.cc)
|
||||
file(GLOB_RECURSE FDTENSOR_FUNC_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cc)
|
||||
file(GLOB_RECURSE FDTENSOR_FUNC_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cu)
|
||||
file(GLOB_RECURSE DEPLOY_ORT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/ort/*.cc)
|
||||
file(GLOB_RECURSE DEPLOY_PADDLE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/paddle/*.cc)
|
||||
file(GLOB_RECURSE DEPLOY_POROS_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/poros/*.cc)
|
||||
@@ -320,6 +311,18 @@ if(WITH_GPU)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Whether to build CUDA source files in fastdeploy
|
||||
# CUDA source files include CUDA preprocessing, TRT plugins, etc.
|
||||
if(WITH_GPU)
|
||||
set(BUILD_CUDA_SRC ON)
|
||||
enable_language(CUDA)
|
||||
set(CMAKE_CUDA_STANDARD 11)
|
||||
set(CUDA_PROPAGATE_HOST_FLAGS FALSE)
|
||||
list(APPEND ALL_DEPLOY_SRCS ${FDTENSOR_FUNC_CUDA_SRCS})
|
||||
else()
|
||||
set(BUILD_CUDA_SRC OFF)
|
||||
endif()
|
||||
|
||||
if(ENABLE_TRT_BACKEND)
|
||||
if(APPLE OR ANDROID OR IOS)
|
||||
message(FATAL_ERROR "Cannot enable tensorrt backend in mac/ios/android os, please set -DENABLE_TRT_BACKEND=OFF.")
|
||||
@@ -463,7 +466,7 @@ endif()
|
||||
set_target_properties(${LIBRARY_NAME} PROPERTIES VERSION ${FASTDEPLOY_VERSION})
|
||||
if(MSVC)
|
||||
# disable warnings for dll export
|
||||
target_compile_options(${LIBRARY_NAME} PRIVATE /wd4251)
|
||||
target_compile_options(${LIBRARY_NAME} PRIVATE "$<$<BUILD_INTERFACE:$<COMPILE_LANGUAGE:CXX>>:/wd4251>$<$<BUILD_INTERFACE:$<COMPILE_LANGUAGE:CUDA>>:-Xcompiler=/wd4251>")
|
||||
endif()
|
||||
target_link_libraries(${LIBRARY_NAME} ${DEPEND_LIBS})
|
||||
|
||||
|
@@ -48,6 +48,8 @@ Windows编译需要满足条件
|
||||
- cuda >= 11.2
|
||||
- cudnn >= 8.2
|
||||
|
||||
注意:安装CUDA时需要勾选`Visual Studio Integration`, 或者手动将`C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2\extras\visual_studio_integration\MSBuildExtensions\`文件夹下的4个文件复制到`C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations\`文件夹。否则执行cmake命令时可能会遇到`No CUDA toolset found`报错。
|
||||
|
||||
在Windows菜单中,找到`x64 Native Tools Command Prompt for VS 2019`打开,执行如下命令
|
||||
|
||||
```bat
|
||||
|
@@ -51,6 +51,8 @@ Prerequisite for Compiling on Windows:
|
||||
- cuda >= 11.2
|
||||
- cudnn >= 8.2
|
||||
|
||||
Notice: Make sure `Visual Studio Integration` is installed during CUDA installation, or manually copy the 4 files under `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2\extras\visual_studio_integration\MSBuildExtensions\` into `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations\`. Otherwise, you may run into `No CUDA toolset found` error during cmake.
|
||||
|
||||
Launch the x64 Native Tools Command Prompt for VS 2019 from the Windows Start Menu and run the following commands:
|
||||
|
||||
```
|
||||
|
@@ -13,8 +13,10 @@
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/backends/tensorrt/trt_backend.h"
|
||||
#include "fastdeploy/function/cuda_cast.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "NvInferRuntime.h"
|
||||
#include "fastdeploy/utils/utils.h"
|
||||
@@ -234,6 +236,7 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
|
||||
inputs_desc_[i].name = name;
|
||||
inputs_desc_[i].shape.assign(shape.begin(), shape.end());
|
||||
inputs_desc_[i].dtype = ReaderDtypeToTrtDtype(onnx_reader.inputs[i].dtype);
|
||||
inputs_desc_[i].original_dtype = ReaderDtypeToFDDtype(onnx_reader.inputs[i].dtype);
|
||||
auto info = ShapeRangeInfo(shape);
|
||||
info.name = name;
|
||||
auto iter_min = option.min_shape.find(name);
|
||||
@@ -256,6 +259,8 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
|
||||
outputs_desc_[i].shape.assign(shape.begin(), shape.end());
|
||||
outputs_desc_[i].dtype =
|
||||
ReaderDtypeToTrtDtype(onnx_reader.outputs[i].dtype);
|
||||
outputs_desc_[i].original_dtype =
|
||||
ReaderDtypeToFDDtype(onnx_reader.outputs[i].dtype);
|
||||
}
|
||||
|
||||
if (option_.external_stream_) {
|
||||
@@ -315,9 +320,29 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
FDERROR << "Failed to Infer with TensorRT." << std::endl;
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < outputs->size(); ++i) {
|
||||
// if the final output tensor's dtype is different from the model output tensor's dtype,
|
||||
// then we need cast the data to the final output's dtype
|
||||
auto model_output_dtype = GetFDDataType(outputs_device_buffer_[(*outputs)[i].name].dtype());
|
||||
if ((*outputs)[i].dtype != model_output_dtype) {
|
||||
FDTensor output_tensor;
|
||||
output_tensor.SetExternalData((*outputs)[i].shape, model_output_dtype,
|
||||
outputs_device_buffer_[(*outputs)[i].name].data(),
|
||||
Device::GPU);
|
||||
|
||||
casted_output_tensors_[(*outputs)[i].name].Resize((*outputs)[i].shape, (*outputs)[i].dtype,
|
||||
(*outputs)[i].name, Device::GPU);
|
||||
CudaCast(output_tensor, &casted_output_tensors_[(*outputs)[i].name], stream_);
|
||||
} else {
|
||||
casted_output_tensors_[(*outputs)[i].name].SetExternalData(
|
||||
(*outputs)[i].shape, model_output_dtype,
|
||||
outputs_device_buffer_[(*outputs)[i].name].data(),
|
||||
Device::GPU);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < outputs->size(); ++i) {
|
||||
FDASSERT(cudaMemcpyAsync((*outputs)[i].Data(),
|
||||
outputs_device_buffer_[(*outputs)[i].name].data(),
|
||||
casted_output_tensors_[(*outputs)[i].name].Data(),
|
||||
(*outputs)[i].Nbytes(), cudaMemcpyDeviceToHost,
|
||||
stream_) == 0,
|
||||
"[ERROR] Error occurs while copy memory from GPU to CPU.");
|
||||
@@ -329,6 +354,17 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
}
|
||||
|
||||
void TrtBackend::GetInputOutputInfo() {
|
||||
// Read the original dtypes from inputs_desc_ and outputs_desc_
|
||||
std::unordered_map<std::string, FDDataType> inputs_original_dtype_map;
|
||||
std::unordered_map<std::string, FDDataType> outputs_original_dtype_map;
|
||||
for (size_t i = 0; i < inputs_desc_.size(); ++i) {
|
||||
inputs_original_dtype_map[inputs_desc_[i].name] = inputs_desc_[i].original_dtype;
|
||||
}
|
||||
for (size_t i = 0; i < outputs_desc_.size(); ++i) {
|
||||
outputs_original_dtype_map[outputs_desc_[i].name] = outputs_desc_[i].original_dtype;
|
||||
}
|
||||
|
||||
// Re-read the tensor infos from TRT model and write into inputs_desc_ and outputs_desc_
|
||||
std::vector<TrtValueInfo>().swap(inputs_desc_);
|
||||
std::vector<TrtValueInfo>().swap(outputs_desc_);
|
||||
inputs_desc_.clear();
|
||||
@@ -339,11 +375,14 @@ void TrtBackend::GetInputOutputInfo() {
|
||||
auto shape = ToVec(engine_->getBindingDimensions(i));
|
||||
auto dtype = engine_->getBindingDataType(i);
|
||||
if (engine_->bindingIsInput(i)) {
|
||||
inputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
|
||||
auto original_dtype = inputs_original_dtype_map.count(name) ? inputs_original_dtype_map[name] : GetFDDataType(dtype);
|
||||
inputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype, original_dtype});
|
||||
inputs_device_buffer_[name] = FDDeviceBuffer(dtype);
|
||||
} else {
|
||||
outputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
|
||||
auto original_dtype = outputs_original_dtype_map.count(name) ? outputs_original_dtype_map[name] : GetFDDataType(dtype);
|
||||
outputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype, original_dtype});
|
||||
outputs_device_buffer_[name] = FDDeviceBuffer(dtype);
|
||||
casted_output_tensors_[name] = FDTensor();
|
||||
}
|
||||
}
|
||||
bindings_.resize(num_binds);
|
||||
@@ -358,11 +397,12 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
|
||||
|
||||
if (item.device == Device::GPU) {
|
||||
if (item.dtype == FDDataType::INT64) {
|
||||
// TODO(liqi): cast int64 to int32
|
||||
// TRT don't support INT64
|
||||
FDASSERT(false,
|
||||
"TRT don't support INT64 input on GPU, "
|
||||
"please use INT32 input");
|
||||
inputs_device_buffer_[item.name].resize(dims);
|
||||
FDTensor input_tensor;
|
||||
input_tensor.SetExternalData(item.shape, FDDataType::INT32,
|
||||
inputs_device_buffer_[item.name].data(),
|
||||
Device::GPU);
|
||||
CudaCast(item, &input_tensor, stream_);
|
||||
} else {
|
||||
// no copy
|
||||
inputs_device_buffer_[item.name].SetExternalData(dims, item.Data());
|
||||
@@ -413,7 +453,7 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
|
||||
std::vector<int64_t> shape(output_dims.d,
|
||||
output_dims.d + output_dims.nbDims);
|
||||
(*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
|
||||
(*outputs)[ori_idx].Resize(shape, GetFDDataType(outputs_desc_[i].dtype),
|
||||
(*outputs)[ori_idx].Resize(shape, outputs_desc_[i].original_dtype,
|
||||
outputs_desc_[i].name);
|
||||
|
||||
// Allocate output buffer memory
|
||||
@@ -629,7 +669,7 @@ TensorInfo TrtBackend::GetInputInfo(int index) {
|
||||
info.name = inputs_desc_[index].name;
|
||||
info.shape.assign(inputs_desc_[index].shape.begin(),
|
||||
inputs_desc_[index].shape.end());
|
||||
info.dtype = GetFDDataType(inputs_desc_[index].dtype);
|
||||
info.dtype = inputs_desc_[index].original_dtype;
|
||||
return info;
|
||||
}
|
||||
|
||||
@@ -649,7 +689,7 @@ TensorInfo TrtBackend::GetOutputInfo(int index) {
|
||||
info.name = outputs_desc_[index].name;
|
||||
info.shape.assign(outputs_desc_[index].shape.begin(),
|
||||
outputs_desc_[index].shape.end());
|
||||
info.dtype = GetFDDataType(outputs_desc_[index].dtype);
|
||||
info.dtype = outputs_desc_[index].original_dtype;
|
||||
return info;
|
||||
}
|
||||
|
||||
|
@@ -57,7 +57,8 @@ namespace fastdeploy {
|
||||
struct TrtValueInfo {
|
||||
std::string name;
|
||||
std::vector<int> shape;
|
||||
nvinfer1::DataType dtype;
|
||||
nvinfer1::DataType dtype; // dtype of TRT model
|
||||
FDDataType original_dtype; // dtype of original ONNX/Paddle model
|
||||
};
|
||||
|
||||
struct TrtBackendOption {
|
||||
@@ -141,6 +142,13 @@ class TrtBackend : public BaseBackend {
|
||||
// Also will update the range information while inferencing
|
||||
std::map<std::string, ShapeRangeInfo> shape_range_info_;
|
||||
|
||||
// If the final output tensor's dtype is different from the
|
||||
// model output tensor's dtype, then we need cast the data
|
||||
// to the final output's dtype.
|
||||
// E.g. When trt model output tensor is int32, but final tensor is int64
|
||||
// This map stores the casted tensors.
|
||||
std::map<std::string, FDTensor> casted_output_tensors_;
|
||||
|
||||
void GetInputOutputInfo();
|
||||
bool CreateTrtEngineFromOnnx(const std::string& onnx_model_buffer);
|
||||
bool BuildTrtEngine();
|
||||
|
@@ -104,6 +104,26 @@ nvinfer1::DataType ReaderDtypeToTrtDtype(int reader_dtype) {
|
||||
return nvinfer1::DataType::kFLOAT;
|
||||
}
|
||||
|
||||
FDDataType ReaderDtypeToFDDtype(int reader_dtype) {
|
||||
if (reader_dtype == 0) {
|
||||
return FDDataType::FP32;
|
||||
} else if (reader_dtype == 1) {
|
||||
return FDDataType::FP64;
|
||||
} else if (reader_dtype == 2) {
|
||||
return FDDataType::UINT8;
|
||||
} else if (reader_dtype == 3) {
|
||||
return FDDataType::INT8;
|
||||
} else if (reader_dtype == 4) {
|
||||
return FDDataType::INT32;
|
||||
} else if (reader_dtype == 5) {
|
||||
return FDDataType::INT64;
|
||||
} else if (reader_dtype == 6) {
|
||||
return FDDataType::FP16;
|
||||
}
|
||||
FDASSERT(false, "Received unexpected data type of %d", reader_dtype);
|
||||
return FDDataType::FP32;
|
||||
}
|
||||
|
||||
std::vector<int> ToVec(const nvinfer1::Dims& dim) {
|
||||
std::vector<int> out(dim.d, dim.d + dim.nbDims);
|
||||
return out;
|
||||
|
@@ -55,6 +55,8 @@ FDDataType GetFDDataType(const nvinfer1::DataType& dtype);
|
||||
|
||||
nvinfer1::DataType ReaderDtypeToTrtDtype(int reader_dtype);
|
||||
|
||||
FDDataType ReaderDtypeToFDDtype(int reader_dtype);
|
||||
|
||||
std::vector<int> ToVec(const nvinfer1::Dims& dim);
|
||||
|
||||
template <typename T>
|
||||
@@ -153,6 +155,11 @@ class FDGenericBuffer {
|
||||
//!
|
||||
size_t nbBytes() const { return this->size() * TrtDataTypeSize(mType); }
|
||||
|
||||
//!
|
||||
//! \brief Returns the dtype of the buffer.
|
||||
//!
|
||||
nvinfer1::DataType dtype() const { return mType; }
|
||||
|
||||
//!
|
||||
//! \brief Set user memory buffer for TRT Buffer
|
||||
//!
|
||||
|
45
fastdeploy/function/cuda_cast.cu
Normal file
45
fastdeploy/function/cuda_cast.cu
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/function/cuda_cast.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
template <typename T_IN, typename T_OUT>
|
||||
__global__ void CudaCastKernel(const T_IN* in, T_OUT* out, int edge) {
|
||||
int position = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (position >= edge) return;
|
||||
out[position] = (T_OUT)in[position];
|
||||
}
|
||||
|
||||
void CudaCast(const FDTensor& in, FDTensor* out, cudaStream_t stream) {
|
||||
int jobs = in.Numel();
|
||||
int threads = 256;
|
||||
int blocks = ceil(jobs / (float)threads);
|
||||
if (in.dtype == FDDataType::INT64 && out->dtype == FDDataType::INT32) {
|
||||
CudaCastKernel<int64_t, int32_t><<<blocks, threads, 0, stream>>>(
|
||||
reinterpret_cast<int64_t*>(const_cast<void*>(in.Data())),
|
||||
reinterpret_cast<int32_t*>(out->MutableData()),
|
||||
jobs);
|
||||
} else if (in.dtype == FDDataType::INT32 && out->dtype == FDDataType::INT64) {
|
||||
CudaCastKernel<int32_t, int64_t><<<blocks, threads, 0, stream>>>(
|
||||
reinterpret_cast<int32_t*>(const_cast<void*>(in.Data())),
|
||||
reinterpret_cast<int64_t*>(out->MutableData()),
|
||||
jobs);
|
||||
} else {
|
||||
FDASSERT(false, "CudaCast only support input INT64, output INT32.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
29
fastdeploy/function/cuda_cast.h
Normal file
29
fastdeploy/function/cuda_cast.h
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fastdeploy/core/fd_tensor.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
/** Cast the type of the data in GPU buffer.
|
||||
@param in The input tensor.
|
||||
@param out The output tensor
|
||||
@param stream CUDA stream
|
||||
*/
|
||||
FASTDEPLOY_DECL void CudaCast(const FDTensor& in, FDTensor* out,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace fastdeploy
|
Reference in New Issue
Block a user