diff --git a/fastdeploy/backends/tensorrt/utils.cc b/fastdeploy/backends/tensorrt/utils.cc index 20c997ecd..1347b0a4a 100644 --- a/fastdeploy/backends/tensorrt/utils.cc +++ b/fastdeploy/backends/tensorrt/utils.cc @@ -97,6 +97,8 @@ nvinfer1::DataType ReaderDtypeToTrtDtype(int reader_dtype) { } else if (reader_dtype == 5) { // regard int64 as int32 return nvinfer1::DataType::kINT32; + } else if (reader_dtype == 6) { + return nvinfer1::DataType::kHALF; } FDASSERT(false, "Received unexpected data type of %d", reader_dtype); return nvinfer1::DataType::kFLOAT; @@ -135,4 +137,4 @@ nvinfer1::Dims ToDims(const std::vector& vec) { return dims; } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/fastdeploy/pybind/runtime.cc b/fastdeploy/pybind/runtime.cc index ddf9a9585..44647aa29 100644 --- a/fastdeploy/pybind/runtime.cc +++ b/fastdeploy/pybind/runtime.cc @@ -36,6 +36,7 @@ void BindRuntime(pybind11::module& m) { &RuntimeOption::SetPaddleMKLDNNCacheSize) .def("set_lite_power_mode", &RuntimeOption::SetLitePowerMode) .def("set_trt_input_shape", &RuntimeOption::SetTrtInputShape) + .def("set_trt_max_workspace_size", &RuntimeOption::SetTrtMaxWorkspaceSize) .def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16) .def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16) .def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile) diff --git a/fastdeploy/runtime.cc b/fastdeploy/runtime.cc index 67974fefe..430c70474 100644 --- a/fastdeploy/runtime.cc +++ b/fastdeploy/runtime.cc @@ -263,6 +263,10 @@ void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) { trt_serialize_file = cache_file_path; } +void RuntimeOption::SetTrtMaxWorkspaceSize(size_t max_workspace_size) { + trt_max_workspace_size = max_workspace_size; +} + bool Runtime::Init(const RuntimeOption& _option) { option = _option; if (option.model_format == ModelFormat::AUTOREC) { diff --git a/fastdeploy/runtime.h b/fastdeploy/runtime.h index b6f0affd5..4f616ccf4 100644 --- a/fastdeploy/runtime.h +++ b/fastdeploy/runtime.h @@ -106,6 +106,8 @@ struct FASTDEPLOY_DECL RuntimeOption { void SetTrtCacheFile(const std::string& cache_file_path); + void SetTrtMaxWorkspaceSize(size_t trt_max_workspace_size); + Backend backend = Backend::UNKNOWN; // for cpu inference and preprocess // default will let the backend choose their own default value diff --git a/python/fastdeploy/runtime.py b/python/fastdeploy/runtime.py index 8fb5e8ec7..abd1d4bac 100644 --- a/python/fastdeploy/runtime.py +++ b/python/fastdeploy/runtime.py @@ -125,6 +125,9 @@ class RuntimeOption: def disable_trt_fp16(self): return self._option.disable_trt_fp16() + def set_trt_max_workspace_size(self, trt_max_workspace_size): + return self._option.set_trt_max_workspace_size(trt_max_workspace_size) + def __repr__(self): attrs = dir(self._option) message = "RuntimeOption(\n"