Add trt max workspace setting (#308)

* add trt max workspace setting

* fix set trt max workspace
This commit is contained in:
Jack Zhou
2022-09-30 09:54:34 +08:00
committed by GitHub
parent 4a4c37aa97
commit dd365fb721
5 changed files with 13 additions and 1 deletions

View File

@@ -97,6 +97,8 @@ nvinfer1::DataType ReaderDtypeToTrtDtype(int reader_dtype) {
} else if (reader_dtype == 5) {
// regard int64 as int32
return nvinfer1::DataType::kINT32;
} else if (reader_dtype == 6) {
return nvinfer1::DataType::kHALF;
}
FDASSERT(false, "Received unexpected data type of %d", reader_dtype);
return nvinfer1::DataType::kFLOAT;
@@ -135,4 +137,4 @@ nvinfer1::Dims ToDims(const std::vector<int64_t>& vec) {
return dims;
}
} // namespace fastdeploy
} // namespace fastdeploy

View File

@@ -36,6 +36,7 @@ void BindRuntime(pybind11::module& m) {
&RuntimeOption::SetPaddleMKLDNNCacheSize)
.def("set_lite_power_mode", &RuntimeOption::SetLitePowerMode)
.def("set_trt_input_shape", &RuntimeOption::SetTrtInputShape)
.def("set_trt_max_workspace_size", &RuntimeOption::SetTrtMaxWorkspaceSize)
.def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16)
.def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16)
.def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile)

View File

@@ -263,6 +263,10 @@ void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) {
trt_serialize_file = cache_file_path;
}
// Records the caller-supplied TensorRT max workspace size on this option
// object by storing it in `trt_max_workspace_size`. No validation is done.
// NOTE(review): the unit is not visible here; presumably bytes -- confirm
// against the code that consumes `trt_max_workspace_size`.
void RuntimeOption::SetTrtMaxWorkspaceSize(size_t max_workspace_size) {
  trt_max_workspace_size = max_workspace_size;
}
bool Runtime::Init(const RuntimeOption& _option) {
option = _option;
if (option.model_format == ModelFormat::AUTOREC) {

View File

@@ -106,6 +106,8 @@ struct FASTDEPLOY_DECL RuntimeOption {
void SetTrtCacheFile(const std::string& cache_file_path);
// Set the TensorRT max workspace size stored on this option object.
// NOTE(review): the unit is not visible here; presumably bytes -- confirm.
// NOTE(review): the out-of-line definition names this parameter
// `max_workspace_size`; the name used here matches the field it assigns to.
void SetTrtMaxWorkspaceSize(size_t trt_max_workspace_size);
Backend backend = Backend::UNKNOWN;
// for cpu inference and preprocess
// default will let the backend choose their own default value

View File

@@ -125,6 +125,9 @@ class RuntimeOption:
def disable_trt_fp16(self):
    """Forward the "disable TensorRT FP16" request to the wrapped option.

    :return: Whatever the wrapped ``_option.disable_trt_fp16()`` call returns.
    """
    wrapped = self._option
    return wrapped.disable_trt_fp16()
def set_trt_max_workspace_size(self, trt_max_workspace_size):
    """Forward the TensorRT max workspace size to the wrapped option object.

    :param trt_max_workspace_size: Value handed unchanged to the wrapped
        ``_option.set_trt_max_workspace_size`` (unit presumably bytes --
        TODO confirm against the C++ binding).
    :return: Whatever the wrapped call returns.
    """
    wrapped = self._option
    return wrapped.set_trt_max_workspace_size(trt_max_workspace_size)
def __repr__(self):
attrs = dir(self._option)
message = "RuntimeOption(\n"