[Backend] TRT backend & PP-Infer backend support pinned memory (#403)

* TRT backend uses pinned memory

* Refine FDTensor pinned memory logic

* Make TRT pinned memory usage configurable

* Paddle Inference backend supports pinned memory

* Add pinned memory pybindings

Co-authored-by: Jason <jiangjiajun@baidu.com>
Wang Xinyu authored on 2022-10-21 18:51:36 +08:00; committed by GitHub
parent 8dbc1f1d10, commit 43d86114d8
14 changed files with 120 additions and 18 deletions
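
Note on the mechanism behind this change: pinned (page-locked) host memory is allocated with cudaMallocHost rather than malloc, so the GPU can DMA into it directly and cudaMemcpyAsync becomes truly asynchronous; with ordinary pageable memory the driver stages the transfer and the copy is effectively synchronous. A minimal standalone sketch of that pattern (plain CUDA runtime API, not FastDeploy code):

#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const size_t nbytes = 1 << 20;
  float* device_ptr = nullptr;
  float* pinned_host_ptr = nullptr;
  cudaStream_t stream;

  if (cudaStreamCreate(&stream) != cudaSuccess ||
      cudaMalloc(&device_ptr, nbytes) != cudaSuccess ||
      cudaMallocHost(&pinned_host_ptr, nbytes) != cudaSuccess) {  // page-locked host memory
    std::fprintf(stderr, "allocation failed\n");
    return 1;
  }

  // Asynchronous device-to-host copy into the pinned buffer; the host must
  // synchronize the stream before it may safely read pinned_host_ptr.
  cudaMemcpyAsync(pinned_host_ptr, device_ptr, nbytes, cudaMemcpyDeviceToHost,
                  stream);
  cudaStreamSynchronize(stream);

  std::printf("copied %zu bytes through pinned memory\n", nbytes);
  cudaFreeHost(pinned_host_ptr);
  cudaFree(device_ptr);
  cudaStreamDestroy(stream);
  return 0;
}

Because the destination can now be pinned, the backends must synchronize the stream explicitly before handing output data back to the caller, as seen in the Infer() hunk below.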

@@ -306,17 +306,21 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
   SetInputs(inputs);
   AllocateOutputsBuffer(outputs);
   if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
     FDERROR << "Failed to Infer with TensorRT." << std::endl;
     return false;
   }
   for (size_t i = 0; i < outputs->size(); ++i) {
     FDASSERT(cudaMemcpyAsync((*outputs)[i].Data(),
-                             outputs_buffer_[(*outputs)[i].name].data(),
+                             outputs_device_buffer_[(*outputs)[i].name].data(),
                              (*outputs)[i].Nbytes(), cudaMemcpyDeviceToHost,
                              stream_) == 0,
              "[ERROR] Error occurs while copy memory from GPU to CPU.");
   }
+  FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
+           "[ERROR] Error occurs while sync cuda stream.");
   return true;
 }
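
The hunk above queues one cudaMemcpyAsync per output and performs a single cudaStreamSynchronize after the loop. A hedged sketch of that pattern in isolation (HostTensor and DeviceBuffer are illustrative stand-ins, not FastDeploy types):

#include <cuda_runtime.h>
#include <cstddef>
#include <vector>

struct HostTensor {    // stand-in for FDTensor (host memory, possibly pinned)
  void* data;
  size_t nbytes;
};

struct DeviceBuffer {  // stand-in for FDDeviceBuffer (GPU memory)
  void* data;
};

// Queue one async device-to-host copy per output on the same stream, then
// synchronize once so every copy has finished before the host reads results.
bool CopyOutputsToHost(const std::vector<DeviceBuffer>& device_outputs,
                       std::vector<HostTensor>* host_outputs,
                       cudaStream_t stream) {
  for (size_t i = 0; i < host_outputs->size(); ++i) {
    if (cudaMemcpyAsync((*host_outputs)[i].data, device_outputs[i].data,
                        (*host_outputs)[i].nbytes, cudaMemcpyDeviceToHost,
                        stream) != cudaSuccess) {
      return false;
    }
  }
  return cudaStreamSynchronize(stream) == cudaSuccess;
}

Whether the copies actually overlap depends on each output tensor being pinned, which is controlled by the is_pinned_memory flag set in AllocateOutputsBuffer further down.
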
@@ -332,10 +336,10 @@ void TrtBackend::GetInputOutputInfo() {
     auto dtype = engine_->getBindingDataType(i);
     if (engine_->bindingIsInput(i)) {
       inputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
-      inputs_buffer_[name] = FDDeviceBuffer(dtype);
+      inputs_device_buffer_[name] = FDDeviceBuffer(dtype);
     } else {
       outputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
-      outputs_buffer_[name] = FDDeviceBuffer(dtype);
+      outputs_device_buffer_[name] = FDDeviceBuffer(dtype);
     }
   }
   bindings_.resize(num_binds);
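
The rename from inputs_buffer_/outputs_buffer_ to inputs_device_buffer_/outputs_device_buffer_ makes explicit that these maps hold GPU-side staging memory, as opposed to the (possibly pinned) host tensors handed back to the user. For orientation, a minimal stand-in for what such a device buffer manages (the real FDDeviceBuffer is defined elsewhere in the repo and differs in detail; copy/move semantics are omitted here):

#include <cuda_runtime.h>
#include <cstddef>

class SimpleDeviceBuffer {
 public:
  SimpleDeviceBuffer() = default;
  ~SimpleDeviceBuffer() { cudaFree(ptr_); }

  // Grow-only resize, mirroring the resize(dims) calls in the diff.
  void resize(size_t nbytes) {
    if (nbytes <= capacity_) return;
    cudaFree(ptr_);
    ptr_ = nullptr;
    cudaMalloc(&ptr_, nbytes);
    capacity_ = nbytes;
  }

  void* data() const { return ptr_; }

 private:
  void* ptr_ = nullptr;
  size_t capacity_ = 0;
};
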
@@ -357,30 +361,31 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
                  "please use INT32 input");
       } else {
         // no copy
-        inputs_buffer_[item.name].SetExternalData(dims, item.Data());
+        inputs_device_buffer_[item.name].SetExternalData(dims, item.Data());
       }
     } else {
       // Allocate input buffer memory
-      inputs_buffer_[item.name].resize(dims);
+      inputs_device_buffer_[item.name].resize(dims);
       // copy from cpu to gpu
       if (item.dtype == FDDataType::INT64) {
         int64_t* data = static_cast<int64_t*>(const_cast<void*>(item.Data()));
         std::vector<int32_t> casted_data(data, data + item.Numel());
-        FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(),
+        FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
                                  static_cast<void*>(casted_data.data()),
                                  item.Nbytes() / 2, cudaMemcpyHostToDevice,
                                  stream_) == 0,
                  "Error occurs while copy memory from CPU to GPU.");
       } else {
-        FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(), item.Data(),
+        FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
+                                 item.Data(),
                                  item.Nbytes(), cudaMemcpyHostToDevice,
                                  stream_) == 0,
                  "Error occurs while copy memory from CPU to GPU.");
       }
     }
     // binding input buffer
-    bindings_[idx] = inputs_buffer_[item.name].data();
+    bindings_[idx] = inputs_device_buffer_[item.name].data();
   }
 }
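
SetInputs above narrows INT64 host inputs to INT32 before the host-to-device copy because this TensorRT version has no INT64 binding type, which is also why only item.Nbytes() / 2 bytes are transferred. A hedged standalone sketch of that step (the function name is illustrative, not a FastDeploy API):

#include <cuda_runtime.h>
#include <cstddef>
#include <cstdint>
#include <vector>

bool CopyInt64InputAsInt32(const int64_t* host_data, size_t numel,
                           void* device_dst, cudaStream_t stream) {
  // Narrow on the host; values outside the int32 range would be truncated.
  std::vector<int32_t> casted(host_data, host_data + numel);
  // numel * sizeof(int32_t) corresponds to Nbytes() / 2 in the diff above.
  // The CUDA runtime stages pageable sources such as this local vector before
  // cudaMemcpyAsync returns, so letting `casted` go out of scope here is safe.
  return cudaMemcpyAsync(device_dst, casted.data(), numel * sizeof(int32_t),
                         cudaMemcpyHostToDevice, stream) == cudaSuccess;
}
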
@@ -399,15 +404,19 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
              "Cannot find output: %s of tensorrt network from the original model.",
              outputs_desc_[i].name.c_str());
     auto ori_idx = iter->second;
     // set user's outputs info
     std::vector<int64_t> shape(output_dims.d,
                                output_dims.d + output_dims.nbDims);
+    (*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
     (*outputs)[ori_idx].Resize(shape, GetFDDataType(outputs_desc_[i].dtype),
                                outputs_desc_[i].name);
     // Allocate output buffer memory
-    outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
+    outputs_device_buffer_[outputs_desc_[i].name].resize(output_dims);
     // binding output buffer
-    bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
+    bindings_[idx] = outputs_device_buffer_[outputs_desc_[i].name].data();
   }
 }
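
AllocateOutputsBuffer now tags each output tensor with option_.enable_pinned_memory before resizing it, so the tensor's own allocator can presumably pick page-locked host memory. A hedged sketch of what honoring such a flag could look like (SimpleTensor is illustrative; FDTensor's real allocation logic lives in the core library and may differ):

#include <cuda_runtime.h>
#include <cstddef>
#include <cstdlib>

struct SimpleTensor {  // stand-in for FDTensor
  void* data = nullptr;
  size_t nbytes = 0;
  bool is_pinned_memory = false;  // assumed to be set before the first resize

  // Allocate the host buffer as pinned or pageable memory based on the flag.
  void ResizeBytes(size_t new_nbytes) {
    FreeBuffer();
    if (is_pinned_memory) {
      // Page-locked allocation, so later async device-to-host copies into
      // this buffer do not fall back to a staged, synchronous transfer.
      cudaMallocHost(&data, new_nbytes);
    } else {
      data = std::malloc(new_nbytes);
    }
    nbytes = new_nbytes;
  }

  void FreeBuffer() {
    if (data == nullptr) return;
    if (is_pinned_memory) {
      cudaFreeHost(data);
    } else {
      std::free(data);
    }
    data = nullptr;
    nbytes = 0;
  }

  ~SimpleTensor() { FreeBuffer(); }
};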