[Backend] TRT backend & PP-Infer backend support pinned memory (#403)
* TRT backend use pinned memory
* refine fd tensor pinned memory logic
* TRT enable pinned memory configurable
* paddle inference support pinned memory
* pinned memory pybindings

Co-authored-by: Jason <jiangjiajun@baidu.com>
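Usage sketch (illustrative only, not part of the diff below): enabling pinned memory through RuntimeOption. EnablePinnedMemory()/DisablePinnedMemory() are what this commit adds; SetModelPath(), UseGpu(), UseTrtBackend() and Runtime::Init() are assumed from the existing FastDeploy C++ API rather than taken from this diff.

    #include "fastdeploy/runtime.h"

    int main() {
      fastdeploy::RuntimeOption option;
      // SetModelPath / UseGpu / UseTrtBackend / Init are assumed from the
      // existing FastDeploy API; they are not part of this commit.
      option.SetModelPath("model.pdmodel", "model.pdiparams");
      option.UseGpu(0);
      option.UseTrtBackend();       // the Paddle Inference backend honors the flag as well
      option.EnablePinnedMemory();  // added here: output copies land in cudaMallocHost'ed buffers

      fastdeploy::Runtime runtime;
      if (!runtime.Init(option)) {  // Init signature assumed
        return -1;
      }
      return 0;
    }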
@@ -19,6 +19,7 @@
 namespace fastdeploy {
 
 void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
+  option_ = option;
   if (option.use_gpu) {
     config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id);
     if (option.enable_trt) {
@@ -190,6 +191,7 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
   outputs->resize(outputs_desc_.size());
   for (size_t i = 0; i < outputs_desc_.size(); ++i) {
     auto handle = predictor_->GetOutputHandle(outputs_desc_[i].name);
+    (*outputs)[i].is_pinned_memory = option_.enable_pinned_memory;
     CopyTensorToCpu(handle, &((*outputs)[i]));
   }
   return true;
@@ -53,6 +53,7 @@ struct PaddleBackendOption {
   int gpu_mem_init_size = 100;
   // gpu device id
   int gpu_id = 0;
+  bool enable_pinned_memory = false;
 
   std::vector<std::string> delete_pass_names = {};
 };
@@ -105,6 +106,7 @@ class PaddleBackend : public BaseBackend {
                     std::map<std::string, std::vector<int>>* opt_shape) const;
   void SetTRTDynamicShapeToConfig(const PaddleBackendOption& option);
 #endif
+  PaddleBackendOption option_;
   paddle_infer::Config config_;
   std::shared_ptr<paddle_infer::Predictor> predictor_;
   std::vector<TensorInfo> inputs_desc_;
@@ -67,7 +67,7 @@ void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
   std::vector<int64_t> shape;
   auto tmp_shape = tensor->shape();
   shape.assign(tmp_shape.begin(), tmp_shape.end());
-  fd_tensor->Allocate(shape, fd_dtype, tensor->name());
+  fd_tensor->Resize(shape, fd_dtype, tensor->name());
   if (fd_tensor->dtype == FDDataType::FP32) {
     tensor->CopyToCpu(static_cast<float*>(fd_tensor->MutableData()));
     return;
@@ -306,17 +306,21 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
 
   SetInputs(inputs);
   AllocateOutputsBuffer(outputs);
+
   if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
     FDERROR << "Failed to Infer with TensorRT." << std::endl;
     return false;
   }
   for (size_t i = 0; i < outputs->size(); ++i) {
     FDASSERT(cudaMemcpyAsync((*outputs)[i].Data(),
-                             outputs_buffer_[(*outputs)[i].name].data(),
+                             outputs_device_buffer_[(*outputs)[i].name].data(),
                              (*outputs)[i].Nbytes(), cudaMemcpyDeviceToHost,
                              stream_) == 0,
              "[ERROR] Error occurs while copy memory from GPU to CPU.");
   }
+  FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
+           "[ERROR] Error occurs while sync cuda stream.");
+
   return true;
 }
 
@@ -332,10 +336,10 @@ void TrtBackend::GetInputOutputInfo() {
     auto dtype = engine_->getBindingDataType(i);
     if (engine_->bindingIsInput(i)) {
       inputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
-      inputs_buffer_[name] = FDDeviceBuffer(dtype);
+      inputs_device_buffer_[name] = FDDeviceBuffer(dtype);
     } else {
       outputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
-      outputs_buffer_[name] = FDDeviceBuffer(dtype);
+      outputs_device_buffer_[name] = FDDeviceBuffer(dtype);
     }
   }
   bindings_.resize(num_binds);
@@ -357,30 +361,31 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
                  "please use INT32 input");
       } else {
         // no copy
-        inputs_buffer_[item.name].SetExternalData(dims, item.Data());
+        inputs_device_buffer_[item.name].SetExternalData(dims, item.Data());
       }
     } else {
       // Allocate input buffer memory
-      inputs_buffer_[item.name].resize(dims);
+      inputs_device_buffer_[item.name].resize(dims);
 
       // copy from cpu to gpu
       if (item.dtype == FDDataType::INT64) {
         int64_t* data = static_cast<int64_t*>(const_cast<void*>(item.Data()));
         std::vector<int32_t> casted_data(data, data + item.Numel());
-        FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(),
+        FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
                                  static_cast<void*>(casted_data.data()),
                                  item.Nbytes() / 2, cudaMemcpyHostToDevice,
                                  stream_) == 0,
                  "Error occurs while copy memory from CPU to GPU.");
       } else {
-        FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(), item.Data(),
+        FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
+                                 item.Data(),
                                  item.Nbytes(), cudaMemcpyHostToDevice,
                                  stream_) == 0,
                  "Error occurs while copy memory from CPU to GPU.");
       }
     }
     // binding input buffer
-    bindings_[idx] = inputs_buffer_[item.name].data();
+    bindings_[idx] = inputs_device_buffer_[item.name].data();
   }
 }
 
@@ -399,15 +404,19 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
         "Cannot find output: %s of tensorrt network from the original model.",
         outputs_desc_[i].name.c_str());
     auto ori_idx = iter->second;
+
     // set user's outputs info
     std::vector<int64_t> shape(output_dims.d,
                                output_dims.d + output_dims.nbDims);
+    (*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
     (*outputs)[ori_idx].Resize(shape, GetFDDataType(outputs_desc_[i].dtype),
                                outputs_desc_[i].name);
+
     // Allocate output buffer memory
-    outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
+    outputs_device_buffer_[outputs_desc_[i].name].resize(output_dims);
+
     // binding output buffer
-    bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
+    bindings_[idx] = outputs_device_buffer_[outputs_desc_[i].name].data();
   }
 }
 
@@ -70,6 +70,7 @@ struct TrtBackendOption {
   std::map<std::string, std::vector<int32_t>> min_shape;
   std::map<std::string, std::vector<int32_t>> opt_shape;
   std::string serialize_file = "";
+  bool enable_pinned_memory = false;
 
   // inside parameter, maybe remove next version
   bool remove_multiclass_nms_ = false;
@@ -118,8 +119,8 @@ class TrtBackend : public BaseBackend {
   std::vector<void*> bindings_;
   std::vector<TrtValueInfo> inputs_desc_;
   std::vector<TrtValueInfo> outputs_desc_;
-  std::map<std::string, FDDeviceBuffer> inputs_buffer_;
-  std::map<std::string, FDDeviceBuffer> outputs_buffer_;
+  std::map<std::string, FDDeviceBuffer> inputs_device_buffer_;
+  std::map<std::string, FDDeviceBuffer> outputs_device_buffer_;
 
   std::string calibration_str_;
 
@@ -206,6 +206,8 @@ class FDGenericBuffer {
 };
 
 using FDDeviceBuffer = FDGenericBuffer<FDDeviceAllocator, FDDeviceFree>;
+using FDDeviceHostBuffer = FDGenericBuffer<FDDeviceHostAllocator,
+                                           FDDeviceHostFree>;
 
 class FDTrtLogger : public nvinfer1::ILogger {
  public:
@@ -34,6 +34,12 @@ bool FDDeviceAllocator::operator()(void** ptr, size_t size) const {
 
 void FDDeviceFree::operator()(void* ptr) const { cudaFree(ptr); }
 
+bool FDDeviceHostAllocator::operator()(void** ptr, size_t size) const {
+  return cudaMallocHost(ptr, size) == cudaSuccess;
+}
+
+void FDDeviceHostFree::operator()(void* ptr) const { cudaFreeHost(ptr); }
+
 #endif
 
 }  // namespace fastdeploy
@@ -45,6 +45,16 @@ class FASTDEPLOY_DECL FDDeviceFree {
   void operator()(void* ptr) const;
 };
 
+class FASTDEPLOY_DECL FDDeviceHostAllocator {
+ public:
+  bool operator()(void** ptr, size_t size) const;
+};
+
+class FASTDEPLOY_DECL FDDeviceHostFree {
+ public:
+  void operator()(void* ptr) const;
+};
+
 #endif
 
 }  // namespace fastdeploy
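For context, a minimal sketch (not part of the diff) of how the new pinned-host allocator pair is meant to be used, mirroring what FDTensor::ReallocFn and FDTensor::FreeFn do further down in this commit. It requires a -DWITH_GPU=ON build, and the include path is an assumption since the diff does not show file names.

    #include <cstddef>

    #include "fastdeploy/core/allocate.h"  // assumed header path for the allocator functors

    void PinnedHostBufferSketch(size_t nbytes) {
      void* buffer = nullptr;
      // FDDeviceHostAllocator wraps cudaMallocHost: the buffer is page-locked host
      // memory, so cudaMemcpyAsync to/from the GPU can run without a staging copy.
      if (fastdeploy::FDDeviceHostAllocator()(&buffer, nbytes)) {
        // ... fill the buffer, or let the backend copy results into it ...
        fastdeploy::FDDeviceHostFree()(buffer);  // wraps cudaFreeHost
      }
    }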
@@ -207,9 +207,27 @@ bool FDTensor::ReallocFn(size_t nbytes) {
              "-DWITH_GPU=ON,"
              "so this is an unexpected problem happend.");
 #endif
+  } else {
+    if (is_pinned_memory) {
+#ifdef WITH_GPU
+      size_t original_nbytes = Nbytes();
+      if (nbytes > original_nbytes) {
+        if (buffer_ != nullptr) {
+          FDDeviceHostFree()(buffer_);
+        }
+        FDDeviceHostAllocator()(&buffer_, nbytes);
+      }
+      return buffer_ != nullptr;
+#else
+      FDASSERT(false,
+               "The FastDeploy FDTensor allocator didn't compile under "
+               "-DWITH_GPU=ON,"
+               "so this is an unexpected problem happend.");
+#endif
+    }
+    buffer_ = realloc(buffer_, nbytes);
+    return buffer_ != nullptr;
   }
-  buffer_ = realloc(buffer_, nbytes);
-  return buffer_ != nullptr;
 }
 
 void FDTensor::FreeFn() {
@@ -220,7 +238,13 @@ void FDTensor::FreeFn() {
     FDDeviceFree()(buffer_);
 #endif
   } else {
-    FDHostFree()(buffer_);
+    if (is_pinned_memory) {
+#ifdef WITH_GPU
+      FDDeviceHostFree()(buffer_);
+#endif
+    } else {
+      FDHostFree()(buffer_);
+    }
   }
   buffer_ = nullptr;
 }
@@ -231,7 +255,6 @@ void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes) {
 #ifdef WITH_GPU
     FDASSERT(cudaMemcpy(dst, src, nbytes, cudaMemcpyDeviceToDevice) == 0,
              "[ERROR] Error occurs while copy memory from GPU to GPU");
-
 #else
     FDASSERT(false,
              "The FastDeploy didn't compile under -DWITH_GPU=ON, so copying "
@@ -239,7 +262,19 @@ void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes) {
              "an unexpected problem happend.");
 #endif
   } else {
-    std::memcpy(dst, src, nbytes);
+    if (is_pinned_memory) {
+#ifdef WITH_GPU
+      FDASSERT(cudaMemcpy(dst, src, nbytes, cudaMemcpyHostToHost) == 0,
+               "[ERROR] Error occurs while copy memory from host to host");
+#else
+      FDASSERT(false,
+               "The FastDeploy didn't compile under -DWITH_GPU=ON, so copying "
+               "gpu buffer is "
+               "an unexpected problem happend.");
+#endif
+    } else {
+      std::memcpy(dst, src, nbytes);
+    }
   }
 }
 
@@ -40,6 +40,10 @@ struct FASTDEPLOY_DECL FDTensor {
   // so we can skip data transfer, which may improve the efficience
   Device device = Device::CPU;
 
+  // Whether the data buffer is in pinned memory, which is allocated
+  // with cudaMallocHost()
+  bool is_pinned_memory = false;
+
   // if the external data is not on CPU, we use this temporary buffer
   // to transfer data to CPU at some cases we need to visit the
   // other devices' data
@@ -44,6 +44,8 @@ void BindRuntime(pybind11::module& m) {
       .def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16)
      .def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16)
      .def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile)
+      .def("enable_pinned_memory", &RuntimeOption::EnablePinnedMemory)
+      .def("disable_pinned_memory", &RuntimeOption::DisablePinnedMemory)
      .def("enable_paddle_trt_collect_shape", &RuntimeOption::EnablePaddleTrtCollectShape)
      .def("disable_paddle_trt_collect_shape", &RuntimeOption::DisablePaddleTrtCollectShape)
      .def_readwrite("model_file", &RuntimeOption::model_file)
@@ -200,6 +202,7 @@ void BindRuntime(pybind11::module& m) {
       .def("numel", &FDTensor::Numel)
      .def("nbytes", &FDTensor::Nbytes)
      .def_readwrite("name", &FDTensor::name)
+      .def_readwrite("is_pinned_memory", &FDTensor::is_pinned_memory)
      .def_readonly("shape", &FDTensor::shape)
      .def_readonly("dtype", &FDTensor::dtype)
      .def_readonly("device", &FDTensor::device);
@@ -356,6 +356,10 @@ void RuntimeOption::EnableTrtFP16() { trt_enable_fp16 = true; }
 
 void RuntimeOption::DisableTrtFP16() { trt_enable_fp16 = false; }
 
+void RuntimeOption::EnablePinnedMemory() { enable_pinned_memory = true; }
+
+void RuntimeOption::DisablePinnedMemory() { enable_pinned_memory = false; }
+
 void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) {
   trt_serialize_file = cache_file_path;
 }
@@ -503,6 +507,7 @@ void Runtime::CreatePaddleBackend() {
   pd_option.gpu_id = option.device_id;
   pd_option.delete_pass_names = option.pd_delete_pass_names;
   pd_option.cpu_thread_num = option.cpu_thread_num;
+  pd_option.enable_pinned_memory = option.enable_pinned_memory;
 #ifdef ENABLE_TRT_BACKEND
   if (pd_option.use_gpu && option.pd_enable_trt) {
     pd_option.enable_trt = true;
@@ -516,6 +521,7 @@ void Runtime::CreatePaddleBackend() {
     trt_option.min_shape = option.trt_min_shape;
     trt_option.opt_shape = option.trt_opt_shape;
     trt_option.serialize_file = option.trt_serialize_file;
+    trt_option.enable_pinned_memory = option.enable_pinned_memory;
     pd_option.trt_option = trt_option;
   }
 #endif
@@ -606,6 +612,7 @@ void Runtime::CreateTrtBackend() {
   trt_option.min_shape = option.trt_min_shape;
   trt_option.opt_shape = option.trt_opt_shape;
   trt_option.serialize_file = option.trt_serialize_file;
+  trt_option.enable_pinned_memory = option.enable_pinned_memory;
 
   // TODO(jiangjiajun): inside usage, maybe remove this later
   trt_option.remove_multiclass_nms_ = option.remove_multiclass_nms_;
@@ -204,6 +204,15 @@ struct FASTDEPLOY_DECL RuntimeOption {
    */
   void SetTrtCacheFile(const std::string& cache_file_path);
 
+  /**
+   * @brief Enable pinned memory. Pinned memory can be utilized to speedup the data transfer between CPU and GPU. Currently it's only suppurted in TRT backend and Paddle Inference backend.
+   */
+  void EnablePinnedMemory();
+
+  /**
+   * @brief Disable pinned memory
+   */
+  void DisablePinnedMemory();
 
   /**
    * @brief Enable to collect shape in paddle trt backend
@@ -223,6 +232,8 @@ struct FASTDEPLOY_DECL RuntimeOption {
 
   Device device = Device::CPU;
 
+  bool enable_pinned_memory = false;
+
   // ======Only for ORT Backend========
   // -1 means use default value by ort
   // 0: ORT_DISABLE_ALL 1: ORT_ENABLE_BASIC 2: ORT_ENABLE_EXTENDED 3:
@@ -319,6 +319,16 @@ class RuntimeOption:
         """
         return self._option.disable_trt_fp16()
 
+    def enable_pinned_memory(self):
+        """Enable pinned memory. Pinned memory can be utilized to speedup the data transfer between CPU and GPU. Currently it's only suppurted in TRT backend and Paddle Inference backend.
+        """
+        return self._option.enable_pinned_memory()
+
+    def disable_pinned_memory(self):
+        """Disable pinned memory.
+        """
+        return self._option.disable_pinned_memory()
+
     def enable_paddle_to_trt(self):
         """While using TensorRT backend, enable_paddle_to_trt() will change to use Paddle Inference backend, and use its integrated TensorRT instead.
         """