mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 09:07:10 +08:00
Add trt max workspace setting (#308)
* add trt max workspace setting * fix set trt max workspace
This commit is contained in:
@@ -97,6 +97,8 @@ nvinfer1::DataType ReaderDtypeToTrtDtype(int reader_dtype) {
|
|||||||
} else if (reader_dtype == 5) {
|
} else if (reader_dtype == 5) {
|
||||||
// regard int64 as int32
|
// regard int64 as int32
|
||||||
return nvinfer1::DataType::kINT32;
|
return nvinfer1::DataType::kINT32;
|
||||||
|
} else if (reader_dtype == 6) {
|
||||||
|
return nvinfer1::DataType::kHALF;
|
||||||
}
|
}
|
||||||
FDASSERT(false, "Received unexpected data type of %d", reader_dtype);
|
FDASSERT(false, "Received unexpected data type of %d", reader_dtype);
|
||||||
return nvinfer1::DataType::kFLOAT;
|
return nvinfer1::DataType::kFLOAT;
|
||||||
@@ -135,4 +137,4 @@ nvinfer1::Dims ToDims(const std::vector<int64_t>& vec) {
|
|||||||
return dims;
|
return dims;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
||||||
|
@@ -36,6 +36,7 @@ void BindRuntime(pybind11::module& m) {
|
|||||||
&RuntimeOption::SetPaddleMKLDNNCacheSize)
|
&RuntimeOption::SetPaddleMKLDNNCacheSize)
|
||||||
.def("set_lite_power_mode", &RuntimeOption::SetLitePowerMode)
|
.def("set_lite_power_mode", &RuntimeOption::SetLitePowerMode)
|
||||||
.def("set_trt_input_shape", &RuntimeOption::SetTrtInputShape)
|
.def("set_trt_input_shape", &RuntimeOption::SetTrtInputShape)
|
||||||
|
.def("set_trt_max_workspace_size", &RuntimeOption::SetTrtMaxWorkspaceSize)
|
||||||
.def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16)
|
.def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16)
|
||||||
.def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16)
|
.def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16)
|
||||||
.def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile)
|
.def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile)
|
||||||
|
@@ -263,6 +263,10 @@ void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) {
|
|||||||
trt_serialize_file = cache_file_path;
|
trt_serialize_file = cache_file_path;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void RuntimeOption::SetTrtMaxWorkspaceSize(size_t max_workspace_size) {
|
||||||
|
trt_max_workspace_size = max_workspace_size;
|
||||||
|
}
|
||||||
|
|
||||||
bool Runtime::Init(const RuntimeOption& _option) {
|
bool Runtime::Init(const RuntimeOption& _option) {
|
||||||
option = _option;
|
option = _option;
|
||||||
if (option.model_format == ModelFormat::AUTOREC) {
|
if (option.model_format == ModelFormat::AUTOREC) {
|
||||||
|
@@ -106,6 +106,8 @@ struct FASTDEPLOY_DECL RuntimeOption {
|
|||||||
|
|
||||||
void SetTrtCacheFile(const std::string& cache_file_path);
|
void SetTrtCacheFile(const std::string& cache_file_path);
|
||||||
|
|
||||||
|
void SetTrtMaxWorkspaceSize(size_t trt_max_workspace_size);
|
||||||
|
|
||||||
Backend backend = Backend::UNKNOWN;
|
Backend backend = Backend::UNKNOWN;
|
||||||
// for cpu inference and preprocess
|
// for cpu inference and preprocess
|
||||||
// default will let the backend choose their own default value
|
// default will let the backend choose their own default value
|
||||||
|
@@ -125,6 +125,9 @@ class RuntimeOption:
|
|||||||
def disable_trt_fp16(self):
|
def disable_trt_fp16(self):
|
||||||
return self._option.disable_trt_fp16()
|
return self._option.disable_trt_fp16()
|
||||||
|
|
||||||
|
def set_trt_max_workspace_size(self, trt_max_workspace_size):
|
||||||
|
return self._option.set_trt_max_workspace_size(trt_max_workspace_size)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
attrs = dir(self._option)
|
attrs = dir(self._option)
|
||||||
message = "RuntimeOption(\n"
|
message = "RuntimeOption(\n"
|
||||||
|
Reference in New Issue
Block a user