[Backend] TRT backend & PP-Infer backend support pinned memory (#403)

* Use pinned memory in the TRT backend

* Refine FDTensor pinned memory logic

* Make pinned memory configurable in the TRT backend

* Support pinned memory in the Paddle Inference backend

* Add pinned memory Python bindings

Co-authored-by: Jason <jiangjiajun@baidu.com>
Authored by: Wang Xinyu
Date: 2022-10-21 18:51:36 +08:00
Committed by: GitHub
parent 8dbc1f1d10
commit 43d86114d8
14 changed files with 120 additions and 18 deletions

View File

@@ -19,6 +19,7 @@
 namespace fastdeploy {

 void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
+  option_ = option;
   if (option.use_gpu) {
     config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id);
     if (option.enable_trt) {
@@ -190,6 +191,7 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
   outputs->resize(outputs_desc_.size());
   for (size_t i = 0; i < outputs_desc_.size(); ++i) {
     auto handle = predictor_->GetOutputHandle(outputs_desc_[i].name);
+    (*outputs)[i].is_pinned_memory = option_.enable_pinned_memory;
     CopyTensorToCpu(handle, &((*outputs)[i]));
   }
   return true;

View File

@@ -53,6 +53,7 @@ struct PaddleBackendOption {
   int gpu_mem_init_size = 100;
   // gpu device id
   int gpu_id = 0;
+  bool enable_pinned_memory = false;
   std::vector<std::string> delete_pass_names = {};
 };
@@ -105,6 +106,7 @@ class PaddleBackend : public BaseBackend {
       std::map<std::string, std::vector<int>>* opt_shape) const;
   void SetTRTDynamicShapeToConfig(const PaddleBackendOption& option);
 #endif
+  PaddleBackendOption option_;
   paddle_infer::Config config_;
   std::shared_ptr<paddle_infer::Predictor> predictor_;
   std::vector<TensorInfo> inputs_desc_;

View File

@@ -67,7 +67,7 @@ void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
   std::vector<int64_t> shape;
   auto tmp_shape = tensor->shape();
   shape.assign(tmp_shape.begin(), tmp_shape.end());
-  fd_tensor->Allocate(shape, fd_dtype, tensor->name());
+  fd_tensor->Resize(shape, fd_dtype, tensor->name());
   if (fd_tensor->dtype == FDDataType::FP32) {
     tensor->CopyToCpu(static_cast<float*>(fd_tensor->MutableData()));
     return;

View File

@@ -306,17 +306,21 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
   SetInputs(inputs);
   AllocateOutputsBuffer(outputs);
   if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
     FDERROR << "Failed to Infer with TensorRT." << std::endl;
     return false;
   }
   for (size_t i = 0; i < outputs->size(); ++i) {
     FDASSERT(cudaMemcpyAsync((*outputs)[i].Data(),
-                             outputs_buffer_[(*outputs)[i].name].data(),
+                             outputs_device_buffer_[(*outputs)[i].name].data(),
                              (*outputs)[i].Nbytes(), cudaMemcpyDeviceToHost,
                              stream_) == 0,
              "[ERROR] Error occurs while copy memory from GPU to CPU.");
   }
+  FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
+           "[ERROR] Error occurs while sync cuda stream.");
   return true;
 }
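
The device-to-host copies above go through cudaMemcpyAsync on the backend's CUDA stream. Such a copy only overlaps with other work when the destination host buffer is pinned; with pageable host memory the CUDA runtime stages the transfer and the call is effectively synchronous, which is why the explicit cudaStreamSynchronize is added at the end. A minimal standalone sketch of the difference (plain CUDA, not FastDeploy code; sizes and names are illustrative):

#include <cuda_runtime.h>
#include <vector>

int main() {
  const size_t nbytes = 64 << 20;  // 64 MB, illustrative size
  void* device_buf = nullptr;
  cudaMalloc(&device_buf, nbytes);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Pageable host destination: the "async" copy is staged by the runtime
  // and generally does not return until the transfer is done.
  std::vector<char> pageable(nbytes);
  cudaMemcpyAsync(pageable.data(), device_buf, nbytes,
                  cudaMemcpyDeviceToHost, stream);

  // Pinned host destination: the copy is issued via DMA and can overlap
  // with other work on the stream; only the sync below waits for it.
  void* pinned = nullptr;
  cudaMallocHost(&pinned, nbytes);
  cudaMemcpyAsync(pinned, device_buf, nbytes, cudaMemcpyDeviceToHost, stream);

  cudaStreamSynchronize(stream);

  cudaFreeHost(pinned);
  cudaFree(device_buf);
  cudaStreamDestroy(stream);
  return 0;
}
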
@@ -332,10 +336,10 @@ void TrtBackend::GetInputOutputInfo() {
     auto dtype = engine_->getBindingDataType(i);
     if (engine_->bindingIsInput(i)) {
       inputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
-      inputs_buffer_[name] = FDDeviceBuffer(dtype);
+      inputs_device_buffer_[name] = FDDeviceBuffer(dtype);
     } else {
       outputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
-      outputs_buffer_[name] = FDDeviceBuffer(dtype);
+      outputs_device_buffer_[name] = FDDeviceBuffer(dtype);
     }
   }
   bindings_.resize(num_binds);
@@ -357,30 +361,31 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
                "please use INT32 input");
     } else {
       // no copy
-      inputs_buffer_[item.name].SetExternalData(dims, item.Data());
+      inputs_device_buffer_[item.name].SetExternalData(dims, item.Data());
     }
   } else {
     // Allocate input buffer memory
-    inputs_buffer_[item.name].resize(dims);
+    inputs_device_buffer_[item.name].resize(dims);
     // copy from cpu to gpu
     if (item.dtype == FDDataType::INT64) {
       int64_t* data = static_cast<int64_t*>(const_cast<void*>(item.Data()));
       std::vector<int32_t> casted_data(data, data + item.Numel());
-      FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(),
+      FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
                                static_cast<void*>(casted_data.data()),
                                item.Nbytes() / 2, cudaMemcpyHostToDevice,
                                stream_) == 0,
                "Error occurs while copy memory from CPU to GPU.");
     } else {
-      FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(), item.Data(),
+      FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
+                               item.Data(),
                                item.Nbytes(), cudaMemcpyHostToDevice,
                                stream_) == 0,
                "Error occurs while copy memory from CPU to GPU.");
     }
   }
   // binding input buffer
-  bindings_[idx] = inputs_buffer_[item.name].data();
+  bindings_[idx] = inputs_device_buffer_[item.name].data();
   }
 }
@@ -399,15 +404,19 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
              "Cannot find output: %s of tensorrt network from the original model.",
              outputs_desc_[i].name.c_str());
     auto ori_idx = iter->second;
     // set user's outputs info
     std::vector<int64_t> shape(output_dims.d,
                                output_dims.d + output_dims.nbDims);
+    (*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
     (*outputs)[ori_idx].Resize(shape, GetFDDataType(outputs_desc_[i].dtype),
                                outputs_desc_[i].name);
     // Allocate output buffer memory
-    outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
+    outputs_device_buffer_[outputs_desc_[i].name].resize(output_dims);
     // binding output buffer
-    bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
+    bindings_[idx] = outputs_device_buffer_[outputs_desc_[i].name].data();
   }
 }

View File

@@ -70,6 +70,7 @@ struct TrtBackendOption {
   std::map<std::string, std::vector<int32_t>> min_shape;
   std::map<std::string, std::vector<int32_t>> opt_shape;
   std::string serialize_file = "";
+  bool enable_pinned_memory = false;
   // inside parameter, maybe remove next version
   bool remove_multiclass_nms_ = false;
@@ -118,8 +119,8 @@ class TrtBackend : public BaseBackend {
   std::vector<void*> bindings_;
   std::vector<TrtValueInfo> inputs_desc_;
   std::vector<TrtValueInfo> outputs_desc_;
-  std::map<std::string, FDDeviceBuffer> inputs_buffer_;
-  std::map<std::string, FDDeviceBuffer> outputs_buffer_;
+  std::map<std::string, FDDeviceBuffer> inputs_device_buffer_;
+  std::map<std::string, FDDeviceBuffer> outputs_device_buffer_;
   std::string calibration_str_;

View File

@@ -206,6 +206,8 @@ class FDGenericBuffer {
 };

 using FDDeviceBuffer = FDGenericBuffer<FDDeviceAllocator, FDDeviceFree>;
+using FDDeviceHostBuffer = FDGenericBuffer<FDDeviceHostAllocator,
+                                           FDDeviceHostFree>;

 class FDTrtLogger : public nvinfer1::ILogger {
  public:

View File

@@ -34,6 +34,12 @@ bool FDDeviceAllocator::operator()(void** ptr, size_t size) const {
 void FDDeviceFree::operator()(void* ptr) const { cudaFree(ptr); }

+bool FDDeviceHostAllocator::operator()(void** ptr, size_t size) const {
+  return cudaMallocHost(ptr, size) == cudaSuccess;
+}
+
+void FDDeviceHostFree::operator()(void* ptr) const { cudaFreeHost(ptr); }
+
 #endif

 } // namespace fastdeploy

View File

@@ -45,6 +45,16 @@ class FASTDEPLOY_DECL FDDeviceFree {
   void operator()(void* ptr) const;
 };

+class FASTDEPLOY_DECL FDDeviceHostAllocator {
+ public:
+  bool operator()(void** ptr, size_t size) const;
+};
+
+class FASTDEPLOY_DECL FDDeviceHostFree {
+ public:
+  void operator()(void* ptr) const;
+};
+
 #endif

 } // namespace fastdeploy
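
The new allocator/free functor pair mirrors the existing device-side pair, so the same buffer template can be instantiated for pinned host memory (the FDDeviceHostBuffer alias added in the TRT utils header above). A trimmed-down, hypothetical sketch of that pattern (names such as GenericBuffer, PinnedAlloc and PinnedHostBuffer are illustrative, not the actual FastDeploy classes):

#include <cuda_runtime.h>
#include <cstddef>

struct PinnedAlloc {
  bool operator()(void** ptr, size_t size) const {
    return cudaMallocHost(ptr, size) == cudaSuccess;
  }
};

struct PinnedFree {
  void operator()(void* ptr) const { cudaFreeHost(ptr); }
};

// The buffer is parameterized on an allocate functor and a free functor,
// so device memory and pinned host memory can share one implementation.
template <typename AllocFunc, typename FreeFunc>
class GenericBuffer {
 public:
  GenericBuffer() = default;
  ~GenericBuffer() { FreeFunc()(data_); }

  // Grow-only resize: only reallocate when the requested size exceeds the
  // current capacity, since cudaMallocHost/cudaFreeHost are expensive.
  bool Resize(size_t nbytes) {
    if (nbytes <= capacity_) return true;
    FreeFunc()(data_);
    data_ = nullptr;
    if (!AllocFunc()(&data_, nbytes)) return false;
    capacity_ = nbytes;
    return true;
  }

  void* Data() const { return data_; }

 private:
  void* data_ = nullptr;
  size_t capacity_ = 0;
};

using PinnedHostBuffer = GenericBuffer<PinnedAlloc, PinnedFree>;
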

View File

@@ -207,9 +207,27 @@ bool FDTensor::ReallocFn(size_t nbytes) {
              "-DWITH_GPU=ON,"
              "so this is an unexpected problem happend.");
 #endif
+  } else {
+    if (is_pinned_memory) {
+#ifdef WITH_GPU
+      size_t original_nbytes = Nbytes();
+      if (nbytes > original_nbytes) {
+        if (buffer_ != nullptr) {
+          FDDeviceHostFree()(buffer_);
+        }
+        FDDeviceHostAllocator()(&buffer_, nbytes);
+      }
+      return buffer_ != nullptr;
+#else
+      FDASSERT(false,
+               "The FastDeploy FDTensor allocator didn't compile under "
+               "-DWITH_GPU=ON,"
+               "so this is an unexpected problem happend.");
+#endif
+    }
+    buffer_ = realloc(buffer_, nbytes);
+    return buffer_ != nullptr;
   }
-  buffer_ = realloc(buffer_, nbytes);
-  return buffer_ != nullptr;
 }

 void FDTensor::FreeFn() {
@@ -220,7 +238,13 @@ void FDTensor::FreeFn() {
     FDDeviceFree()(buffer_);
 #endif
   } else {
-    FDHostFree()(buffer_);
+    if (is_pinned_memory) {
+#ifdef WITH_GPU
+      FDDeviceHostFree()(buffer_);
+#endif
+    } else {
+      FDHostFree()(buffer_);
+    }
   }
   buffer_ = nullptr;
 }
@@ -231,7 +255,6 @@ void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes) {
 #ifdef WITH_GPU
     FDASSERT(cudaMemcpy(dst, src, nbytes, cudaMemcpyDeviceToDevice) == 0,
              "[ERROR] Error occurs while copy memory from GPU to GPU");
-
 #else
     FDASSERT(false,
              "The FastDeploy didn't compile under -DWITH_GPU=ON, so copying "
@@ -239,7 +262,19 @@ void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes) {
              "an unexpected problem happend.");
 #endif
   } else {
-    std::memcpy(dst, src, nbytes);
+    if (is_pinned_memory) {
+#ifdef WITH_GPU
+      FDASSERT(cudaMemcpy(dst, src, nbytes, cudaMemcpyHostToHost) == 0,
+               "[ERROR] Error occurs while copy memory from host to host");
+#else
+      FDASSERT(false,
+               "The FastDeploy didn't compile under -DWITH_GPU=ON, so copying "
+               "gpu buffer is "
+               "an unexpected problem happend.");
+#endif
+    } else {
+      std::memcpy(dst, src, nbytes);
+    }
   }
 }

View File

@@ -40,6 +40,10 @@ struct FASTDEPLOY_DECL FDTensor {
   // so we can skip data transfer, which may improve the efficience
   Device device = Device::CPU;

+  // Whether the data buffer is in pinned memory, which is allocated
+  // with cudaMallocHost()
+  bool is_pinned_memory = false;
+
   // if the external data is not on CPU, we use this temporary buffer
   // to transfer data to CPU at some cases we need to visit the
   // other devices' data
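
Because ReallocFn, FreeFn and CopyBuffer all branch on this flag, is_pinned_memory needs to be set before the tensor's first allocation and not flipped afterwards; otherwise a buffer obtained from cudaMallocHost could end up in realloc/free, or vice versa. A minimal usage sketch (the shape, name and header path are illustrative):

#include <vector>

#include "fastdeploy/core/fd_tensor.h"  // assumed header path

int main() {
  // Opt an FDTensor into pinned host memory before its first allocation,
  // so later GPU<->CPU copies into it can run asynchronously.
  fastdeploy::FDTensor tensor;
  tensor.is_pinned_memory = true;  // must be set before the first Resize
  tensor.Resize({1, 3, 224, 224}, fastdeploy::FDDataType::FP32, "input");
  // With -DWITH_GPU=ON the buffer now comes from cudaMallocHost();
  // in a CPU-only build the pinned path is not available.
  return 0;
}
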

View File

@@ -44,6 +44,8 @@ void BindRuntime(pybind11::module& m) {
       .def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16)
       .def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16)
       .def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile)
+      .def("enable_pinned_memory", &RuntimeOption::EnablePinnedMemory)
+      .def("disable_pinned_memory", &RuntimeOption::DisablePinnedMemory)
       .def("enable_paddle_trt_collect_shape", &RuntimeOption::EnablePaddleTrtCollectShape)
       .def("disable_paddle_trt_collect_shape", &RuntimeOption::DisablePaddleTrtCollectShape)
       .def_readwrite("model_file", &RuntimeOption::model_file)
@@ -200,6 +202,7 @@ void BindRuntime(pybind11::module& m) {
       .def("numel", &FDTensor::Numel)
       .def("nbytes", &FDTensor::Nbytes)
       .def_readwrite("name", &FDTensor::name)
+      .def_readwrite("is_pinned_memory", &FDTensor::is_pinned_memory)
       .def_readonly("shape", &FDTensor::shape)
       .def_readonly("dtype", &FDTensor::dtype)
       .def_readonly("device", &FDTensor::device);

View File

@@ -356,6 +356,10 @@ void RuntimeOption::EnableTrtFP16() { trt_enable_fp16 = true; }
 void RuntimeOption::DisableTrtFP16() { trt_enable_fp16 = false; }

+void RuntimeOption::EnablePinnedMemory() { enable_pinned_memory = true; }
+
+void RuntimeOption::DisablePinnedMemory() { enable_pinned_memory = false; }
+
 void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) {
   trt_serialize_file = cache_file_path;
 }
@@ -503,6 +507,7 @@ void Runtime::CreatePaddleBackend() {
   pd_option.gpu_id = option.device_id;
   pd_option.delete_pass_names = option.pd_delete_pass_names;
   pd_option.cpu_thread_num = option.cpu_thread_num;
+  pd_option.enable_pinned_memory = option.enable_pinned_memory;
 #ifdef ENABLE_TRT_BACKEND
   if (pd_option.use_gpu && option.pd_enable_trt) {
     pd_option.enable_trt = true;
@@ -516,6 +521,7 @@ void Runtime::CreatePaddleBackend() {
     trt_option.min_shape = option.trt_min_shape;
     trt_option.opt_shape = option.trt_opt_shape;
     trt_option.serialize_file = option.trt_serialize_file;
+    trt_option.enable_pinned_memory = option.enable_pinned_memory;
     pd_option.trt_option = trt_option;
   }
 #endif
@@ -606,6 +612,7 @@ void Runtime::CreateTrtBackend() {
   trt_option.min_shape = option.trt_min_shape;
   trt_option.opt_shape = option.trt_opt_shape;
   trt_option.serialize_file = option.trt_serialize_file;
+  trt_option.enable_pinned_memory = option.enable_pinned_memory;
   // TODO(jiangjiajun): inside usage, maybe remove this later
   trt_option.remove_multiclass_nms_ = option.remove_multiclass_nms_;

View File

@@ -204,6 +204,15 @@ struct FASTDEPLOY_DECL RuntimeOption {
    */
   void SetTrtCacheFile(const std::string& cache_file_path);

+  /**
+   * @brief Enable pinned memory. Pinned memory can be utilized to speed up the data transfer between CPU and GPU. Currently it's only supported in the TRT backend and the Paddle Inference backend.
+   */
+  void EnablePinnedMemory();
+  /**
+   * @brief Disable pinned memory
+   */
+  void DisablePinnedMemory();
+
   /**
    * @brief Enable to collect shape in paddle trt backend
@@ -223,6 +232,8 @@ struct FASTDEPLOY_DECL RuntimeOption {
   Device device = Device::CPU;

+  bool enable_pinned_memory = false;
+
   // ======Only for ORT Backend========
   // -1 means use default value by ort
   // 0: ORT_DISABLE_ALL 1: ORT_ENABLE_BASIC 2: ORT_ENABLE_EXTENDED 3:
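
A short usage sketch of the new option from C++ (a minimal example; the model files, the other RuntimeOption setters such as SetModelPath/UseGpu/UseTrtBackend, and the header path are assumptions outside this diff, only the pinned-memory switches are introduced here):

#include <vector>

#include "fastdeploy/runtime.h"  // assumed header path

int main() {
  fastdeploy::RuntimeOption option;
  option.SetModelPath("model.pdmodel", "model.pdiparams");  // illustrative model
  option.UseGpu(0);
  option.UseTrtBackend();          // also honored by the Paddle Inference backend
  option.EnablePinnedMemory();     // output FDTensors get cudaMallocHost buffers
  // option.DisablePinnedMemory(); // revert to pageable host memory

  fastdeploy::Runtime runtime;
  if (!runtime.Init(option)) return -1;

  std::vector<fastdeploy::FDTensor> inputs, outputs;
  // ... fill inputs, then run runtime.Infer(inputs, &outputs) ...
  return 0;
}
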

View File

@@ -319,6 +319,16 @@ class RuntimeOption:
        """
        return self._option.disable_trt_fp16()

+    def enable_pinned_memory(self):
+        """Enable pinned memory. Pinned memory can be utilized to speed up the data transfer between CPU and GPU. Currently it's only supported in the TRT backend and the Paddle Inference backend.
+        """
+        return self._option.enable_pinned_memory()
+
+    def disable_pinned_memory(self):
+        """Disable pinned memory.
+        """
+        return self._option.disable_pinned_memory()
+
     def enable_paddle_to_trt(self):
        """While using TensorRT backend, enable_paddle_to_trt() will change to use Paddle Inference backend, and use its integrated TensorRT instead.
        """