mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-08 01:50:27 +08:00
[Bug Fix] fix trt backend page-locked error (#2095)
* [Bug Fix] fix trt backend page-locked error * Update trt_backend.cc
This commit is contained in:
@@ -470,15 +470,31 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
|
|||||||
if (item.dtype == FDDataType::INT64) {
|
if (item.dtype == FDDataType::INT64) {
|
||||||
int64_t* data = static_cast<int64_t*>(const_cast<void*>(item.Data()));
|
int64_t* data = static_cast<int64_t*>(const_cast<void*>(item.Data()));
|
||||||
std::vector<int32_t> casted_data(data, data + item.Numel());
|
std::vector<int32_t> casted_data(data, data + item.Numel());
|
||||||
FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
|
// FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
|
||||||
|
// static_cast<void*>(casted_data.data()),
|
||||||
|
// item.Nbytes() / 2, cudaMemcpyHostToDevice,
|
||||||
|
// stream_) == 0,
|
||||||
|
// "Error occurs while copy memory from CPU to GPU.");
|
||||||
|
// WARN: For the cudaMemcpyHostToDevice direction, cudaMemcpyAsync needs page-locked host
|
||||||
|
// memory for any copy/compute overlap to occur. FDTensor does not currently guarantee that
|
||||||
|
// its host buffer is page-locked, so a blocking cudaMemcpy is used instead. Reference:
|
||||||
|
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creation-and-destruction
|
||||||
|
FDASSERT(cudaMemcpy(inputs_device_buffer_[item.name].data(),
|
||||||
static_cast<void*>(casted_data.data()),
|
static_cast<void*>(casted_data.data()),
|
||||||
item.Nbytes() / 2, cudaMemcpyHostToDevice,
|
item.Nbytes() / 2, cudaMemcpyHostToDevice) == 0,
|
||||||
stream_) == 0,
|
|
||||||
"Error occurs while copy memory from CPU to GPU.");
|
"Error occurs while copy memory from CPU to GPU.");
|
||||||
} else {
|
} else {
|
||||||
FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
|
// FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
|
||||||
|
// item.Data(), item.Nbytes(),
|
||||||
|
// cudaMemcpyHostToDevice, stream_) == 0,
|
||||||
|
// "Error occurs while copy memory from CPU to GPU.");
|
||||||
|
// WARN: For the cudaMemcpyHostToDevice direction, cudaMemcpyAsync needs page-locked host
|
||||||
|
// memory for any copy/compute overlap to occur. FDTensor does not currently guarantee that
|
||||||
|
// its host buffer is page-locked, so a blocking cudaMemcpy is used instead. Reference:
|
||||||
|
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creation-and-destruction
|
||||||
|
FDASSERT(cudaMemcpy(inputs_device_buffer_[item.name].data(),
|
||||||
item.Data(), item.Nbytes(),
|
item.Data(), item.Nbytes(),
|
||||||
cudaMemcpyHostToDevice, stream_) == 0,
|
cudaMemcpyHostToDevice) == 0,
|
||||||
"Error occurs while copy memory from CPU to GPU.");
|
"Error occurs while copy memory from CPU to GPU.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user