diff --git a/csrcs/fastdeploy/backends/ort/ort_backend.cc b/csrcs/fastdeploy/backends/ort/ort_backend.cc index 9fdb3c66b..c17890109 100644 --- a/csrcs/fastdeploy/backends/ort/ort_backend.cc +++ b/csrcs/fastdeploy/backends/ort/ort_backend.cc @@ -26,35 +26,6 @@ namespace fastdeploy { std::vector OrtBackend::custom_operators_ = std::vector(); -ONNXTensorElementDataType GetOrtDtype(FDDataType fd_dtype) { - if (fd_dtype == FDDataType::FP32) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } else if (fd_dtype == FDDataType::FP64) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; - } else if (fd_dtype == FDDataType::INT32) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; - } else if (fd_dtype == FDDataType::INT64) { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - } - FDERROR << "Unrecognized fastdeply data type:" << FDDataTypeStr(fd_dtype) - << "." << std::endl; - return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; -} - -FDDataType GetFdDtype(ONNXTensorElementDataType ort_dtype) { - if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { - return FDDataType::FP32; - } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - return FDDataType::FP64; - } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { - return FDDataType::INT32; - } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { - return FDDataType::INT64; - } - FDERROR << "Unrecognized ort data type:" << ort_dtype << "." << std::endl; - return FDDataType::FP32; -} - void OrtBackend::BuildOption(const OrtBackendOption& option) { option_ = option; if (option.graph_optimization_level >= 0) { @@ -263,6 +234,7 @@ bool OrtBackend::Infer(std::vector& inputs, (*outputs)[i].name = outputs_desc_[i].name; CopyToCpu(ort_outputs[i], &((*outputs)[i])); } + return true; } diff --git a/csrcs/fastdeploy/backends/ort/utils.cc b/csrcs/fastdeploy/backends/ort/utils.cc index bbef1f378..ae3e45b86 100644 --- a/csrcs/fastdeploy/backends/ort/utils.cc +++ b/csrcs/fastdeploy/backends/ort/utils.cc @@ -27,8 +27,8 @@ ONNXTensorElementDataType GetOrtDtype(const FDDataType& fd_dtype) { } else if (fd_dtype == FDDataType::INT64) { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; } - FDERROR << "Unrecognized fastdeply data type:" << FDDataTypeStr(fd_dtype) - << "." << std::endl; + FDERROR << "Unrecognized fastdeply data type:" << Str(fd_dtype) << "." + << std::endl; return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } @@ -64,4 +64,4 @@ Ort::Value CreateOrtValue(FDTensor& tensor, bool is_backend_cuda) { return ort_value; } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/backends/ort/utils.h b/csrcs/fastdeploy/backends/ort/utils.h index b1b29e5ab..e2912ad38 100644 --- a/csrcs/fastdeploy/backends/ort/utils.h +++ b/csrcs/fastdeploy/backends/ort/utils.h @@ -20,7 +20,7 @@ #include #include "fastdeploy/backends/backend.h" -#include "onnxruntime_cxx_api.h" // NOLINT +#include "onnxruntime_cxx_api.h" // NOLINT namespace fastdeploy { @@ -28,7 +28,7 @@ namespace fastdeploy { ONNXTensorElementDataType GetOrtDtype(const FDDataType& fd_dtype); // Convert OrtDataType to FDDataType -FDDataType GetFdDtype(const ONNXTensorElementDataType* ort_dtype); +FDDataType GetFdDtype(const ONNXTensorElementDataType& ort_dtype); // Create Ort::Value // is_backend_cuda specify if the onnxruntime use CUDAExectionProvider @@ -36,4 +36,4 @@ FDDataType GetFdDtype(const ONNXTensorElementDataType* ort_dtype); // Will directly share the cuda data in tensor to OrtValue Ort::Value CreateOrtValue(FDTensor& tensor, bool is_backend_cuda = false); -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/core/fd_tensor.cc b/csrcs/fastdeploy/core/fd_tensor.cc index 97b33dad5..dbefbd9ec 100644 --- a/csrcs/fastdeploy/core/fd_tensor.cc +++ b/csrcs/fastdeploy/core/fd_tensor.cc @@ -119,9 +119,9 @@ void FDTensor::PrintInfo(const std::string& prefix) { for (int i = 0; i < shape.size(); ++i) { std::cout << shape[i] << " "; } - std::cout << ", dtype=" << FDDataTypeStr(dtype) << ", mean=" << mean - << ", max=" << max << ", min=" << min << std::endl; + std::cout << ", dtype=" << Str(dtype) << ", mean=" << mean << ", max=" << max + << ", min=" << min << std::endl; } FDTensor::FDTensor(const std::string& tensor_name) { name = tensor_name; } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/core/fd_type.cc b/csrcs/fastdeploy/core/fd_type.cc index b66cabeb8..8d624cdf2 100644 --- a/csrcs/fastdeploy/core/fd_type.cc +++ b/csrcs/fastdeploy/core/fd_type.cc @@ -17,7 +17,7 @@ namespace fastdeploy { -int FDDataTypeSize(FDDataType data_type) { +int FDDataTypeSize(const FDDataType& data_type) { FDASSERT(data_type != FDDataType::FP16, "Float16 is not supported."); if (data_type == FDDataType::BOOL) { return sizeof(bool); @@ -34,89 +34,63 @@ int FDDataTypeSize(FDDataType data_type) { } else if (data_type == FDDataType::UINT8) { return sizeof(uint8_t); } else { - FDASSERT(false, "Unexpected data type: " + FDDataTypeStr(data_type)); + FDASSERT(false, "Unexpected data type: " + Str(data_type)); } return -1; } -std::string FDDataTypeStr(FDDataType data_type) { - FDASSERT(data_type != FDDataType::FP16, "Float16 is not supported."); - if (data_type == FDDataType::BOOL) { - return "bool"; - } else if (data_type == FDDataType::INT16) { - return "int16"; - } else if (data_type == FDDataType::INT32) { - return "int32"; - } else if (data_type == FDDataType::INT64) { - return "int64"; - } else if (data_type == FDDataType::FP16) { - return "float16"; - } else if (data_type == FDDataType::FP32) { - return "float32"; - } else if (data_type == FDDataType::FP64) { - return "float64"; - } else if (data_type == FDDataType::UINT8) { - return "uint8"; - } else if (data_type == FDDataType::INT8) { - return "int8"; - } else { - FDASSERT(false, "Unexpected data type: " + FDDataTypeStr(data_type)); - } - return "UNKNOWN!"; -} - -std::string Str(Device& d) { +std::string Str(const Device& d) { std::string out; switch (d) { - case Device::DEFAULT: - out = "Device::DEFAULT"; - break; - case Device::CPU: - out = "Device::CPU"; - break; - case Device::GPU: - out = "Device::GPU"; - break; - default: - out = "Device::UNKOWN"; + case Device::DEFAULT: + out = "Device::DEFAULT"; + break; + case Device::CPU: + out = "Device::CPU"; + break; + case Device::GPU: + out = "Device::GPU"; + break; + default: + out = "Device::UNKOWN"; } return out; } -std::string Str(FDDataType& fdt) { +std::string Str(const FDDataType& fdt) { std::string out; switch (fdt) { - case FDDataType::BOOL: - out = "FDDataType::BOOL"; - break; - case FDDataType::INT16: - out = "FDDataType::INT16"; - break; - case FDDataType::INT32: - out = "FDDataType::INT32"; - break; - case FDDataType::INT64: - out = "FDDataType::INT64"; - break; - case FDDataType::FP32: - out = "FDDataType::FP32"; - break; - case FDDataType::FP64: - out = "FDDataType::FP64"; - break; - case FDDataType::FP16: - out = "FDDataType::FP16"; - break; - case FDDataType::UINT8: - out = "FDDataType::UINT8"; - break; - case FDDataType::INT8: - out = "FDDataType::INT8"; - break; - default: - out = "FDDataType::UNKNOWN"; + case FDDataType::BOOL: + out = "FDDataType::BOOL"; + break; + case FDDataType::INT16: + out = "FDDataType::INT16"; + break; + case FDDataType::INT32: + out = "FDDataType::INT32"; + break; + case FDDataType::INT64: + out = "FDDataType::INT64"; + break; + case FDDataType::FP32: + out = "FDDataType::FP32"; + break; + case FDDataType::FP64: + out = "FDDataType::FP64"; + break; + case FDDataType::FP16: + out = "FDDataType::FP16"; + break; + case FDDataType::UINT8: + out = "FDDataType::UINT8"; + break; + case FDDataType::INT8: + out = "FDDataType::INT8"; + break; + default: + out = "FDDataType::UNKNOWN"; } return out; } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/core/fd_type.h b/csrcs/fastdeploy/core/fd_type.h index 768ac1e36..325551dfb 100644 --- a/csrcs/fastdeploy/core/fd_type.h +++ b/csrcs/fastdeploy/core/fd_type.h @@ -24,7 +24,7 @@ namespace fastdeploy { enum FASTDEPLOY_DECL Device { DEFAULT, CPU, GPU }; -FASTDEPLOY_DECL std::string Str(Device& d); +FASTDEPLOY_DECL std::string Str(const Device& d); enum FASTDEPLOY_DECL FDDataType { BOOL, @@ -51,9 +51,7 @@ enum FASTDEPLOY_DECL FDDataType { INT8 }; -FASTDEPLOY_DECL std::string Str(FDDataType& fdt); +FASTDEPLOY_DECL std::string Str(const FDDataType& fdt); -FASTDEPLOY_DECL int32_t FDDataTypeSize(FDDataType data_dtype); - -FASTDEPLOY_DECL std::string FDDataTypeStr(FDDataType data_dtype); +FASTDEPLOY_DECL int32_t FDDataTypeSize(const FDDataType& data_dtype); } // namespace fastdeploy diff --git a/csrcs/fastdeploy/pybind/fastdeploy_runtime.cc b/csrcs/fastdeploy/pybind/fastdeploy_runtime.cc index 3ede38040..412b1ccef 100644 --- a/csrcs/fastdeploy/pybind/fastdeploy_runtime.cc +++ b/csrcs/fastdeploy/pybind/fastdeploy_runtime.cc @@ -79,6 +79,7 @@ void BindRuntime(pybind11::module& m) { memcpy(inputs[index].data.data(), iter->second.mutable_data(), iter->second.nbytes()); inputs[index].name = iter->first; + index += 1; } std::vector outputs(self.NumOutputs()); diff --git a/csrcs/fastdeploy/pybind/main.cc b/csrcs/fastdeploy/pybind/main.cc index 86467215e..e0c00c8a0 100644 --- a/csrcs/fastdeploy/pybind/main.cc +++ b/csrcs/fastdeploy/pybind/main.cc @@ -32,7 +32,7 @@ pybind11::dtype FDDataTypeToNumpyDataType(const FDDataType& fd_dtype) { dt = pybind11::dtype::of(); } else { FDASSERT(false, "The function doesn't support data type of " + - FDDataTypeStr(fd_dtype) + "."); + Str(fd_dtype) + "."); } return dt; } @@ -47,8 +47,9 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype) { } else if (np_dtype.is(pybind11::dtype::of())) { return FDDataType::FP64; } - FDASSERT(false, "NumpyDataTypeToFDDataType() only support " - "int32/int64/float32/float64 now."); + FDASSERT(false, + "NumpyDataTypeToFDDataType() only support " + "int32/int64/float32/float64 now."); return FDDataType::FP32; } @@ -112,4 +113,4 @@ PYBIND11_MODULE(fastdeploy_main, m) { #endif } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc b/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc index 9c698976e..5152db3fa 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc +++ b/csrcs/fastdeploy/vision/ppdet/ppyoloe.cc @@ -23,20 +23,31 @@ PPYOLOE::PPYOLOE(const std::string& model_file, const std::string& params_file, initialized = Initialize(); } -bool PPYOLOE::Initialize() { -#ifdef ENABLE_PADDLE_FRONTEND - // remove multiclass_nms3 now - // this is a trick operation for ppyoloe while inference on trt +void PPYOLOE::GetNmsInfo() { if (runtime_option.model_format == Frontend::PADDLE) { std::string contents; if (!ReadBinaryFromFile(runtime_option.model_file, &contents)) { - return false; + return; } auto reader = paddle2onnx::PaddleReader(contents.c_str(), contents.size()); if (reader.has_nms) { has_nms_ = true; + background_label = reader.nms_params.background_label; + keep_top_k = reader.nms_params.keep_top_k; + nms_eta = reader.nms_params.nms_eta; + nms_threshold = reader.nms_params.nms_threshold; + score_threshold = reader.nms_params.score_threshold; + nms_top_k = reader.nms_params.nms_top_k; + normalized = reader.nms_params.normalized; } } +} + +bool PPYOLOE::Initialize() { +#ifdef ENABLE_PADDLE_FRONTEND + // remove multiclass_nms3 now + // this is a trick operation for ppyoloe while inference on trt + GetNmsInfo(); runtime_option.remove_multiclass_nms_ = true; runtime_option.custom_op_info_["multiclass_nms3"] = "MultiClassNMS"; #endif @@ -52,8 +63,12 @@ bool PPYOLOE::Initialize() { if (has_nms_ && runtime_option.backend == Backend::TRT) { FDINFO << "Detected operator multiclass_nms3 in your model, will replace " - "it with fastdeploy::backend::MultiClassNMS replace it." - << std::endl; + "it with fastdeploy::backend::MultiClassNMS(background_label=" + << background_label << ", keep_top_k=" << keep_top_k + << ", nms_eta=" << nms_eta << ", nms_threshold=" << nms_threshold + << ", score_threshold=" << score_threshold + << ", nms_top_k=" << nms_top_k << ", normalized=" << normalized + << ")." << std::endl; has_nms_ = false; } return true; diff --git a/csrcs/fastdeploy/vision/ppdet/ppyoloe.h b/csrcs/fastdeploy/vision/ppdet/ppyoloe.h index ec22aa2ce..d86508fa1 100644 --- a/csrcs/fastdeploy/vision/ppdet/ppyoloe.h +++ b/csrcs/fastdeploy/vision/ppdet/ppyoloe.h @@ -42,6 +42,10 @@ class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel { int64_t nms_top_k = 10000; bool normalized = true; bool has_nms_ = false; + + // This function will used to check if this model contains multiclass_nms + // and get parameters from the operator + void GetNmsInfo(); }; } // namespace ppdet } // namespace vision diff --git a/external/paddle2onnx.cmake b/external/paddle2onnx.cmake index e226bc6c9..ae6f4acda 100644 --- a/external/paddle2onnx.cmake +++ b/external/paddle2onnx.cmake @@ -43,7 +43,7 @@ else() endif(WIN32) set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/paddle2onnx/libs/") -set(PADDLE2ONNX_VERSION "1.0.0rc2") +set(PADDLE2ONNX_VERSION "1.0.0rc3") if(WIN32) set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip") if(NOT CMAKE_CL_64) diff --git a/fastdeploy/__init__.py b/fastdeploy/__init__.py index f9b9f686e..6a23cd3d2 100644 --- a/fastdeploy/__init__.py +++ b/fastdeploy/__init__.py @@ -16,12 +16,14 @@ import logging import os import sys + def add_dll_search_dir(dir_path): os.environ["path"] = dir_path + ";" + os.environ["path"] sys.path.insert(0, dir_path) if sys.version_info[:2] >= (3, 8): os.add_dll_directory(dir_path) + if os.name == "nt": current_path = os.path.abspath(__file__) dirname = os.path.dirname(current_path) @@ -33,82 +35,19 @@ if os.name == "nt": add_dll_search_dir(os.path.join(dirname, root, d)) from .fastdeploy_main import Frontend, Backend, FDDataType, TensorInfo, Device -from .fastdeploy_runtime import * +from .runtime import Runtime, RuntimeOption +from .model import FastDeployModel from . import fastdeploy_main as C from . import vision from .download import download, download_and_decompress + def TensorInfoStr(tensor_info): message = "TensorInfo(name : '{}', dtype : '{}', shape : '{}')".format( tensor_info.name, tensor_info.dtype, tensor_info.shape) return message -class RuntimeOption: - def __init__(self): - self._option = C.RuntimeOption() - - def set_model_path(self, model_path, params_path="", model_format="paddle"): - return self._option.set_model_path(model_path, params_path, model_format) - - def use_gpu(self, device_id=0): - return self._option.use_gpu(device_id) - - def use_cpu(self): - return self._option.use_cpu() - - def set_cpu_thread_num(self, thread_num=8): - return self._option.set_cpu_thread_num(thread_num) - - def use_paddle_backend(self): - return self._option.use_paddle_backend() - - def use_ort_backend(self): - return self._option.use_ort_backend() - - def use_trt_backend(self): - return self._option.use_trt_backend() - - def enable_paddle_mkldnn(self): - return self._option.enable_paddle_mkldnn() - - def disable_paddle_mkldnn(self): - return self._option.disable_paddle_mkldnn() - - def set_paddle_mkldnn_cache_size(self, cache_size): - return self._option.set_paddle_mkldnn_cache_size(cache_size) - - def set_trt_input_shape(self, tensor_name, min_shape, opt_shape=None, max_shape=None): - if opt_shape is None and max_shape is None: - opt_shape = min_shape - max_shape = min_shape - else: - assert opt_shape is not None and max_shape is not None, "Set min_shape only, or set min_shape, opt_shape, max_shape both." - return self._option.set_trt_input_shape(tensor_name, min_shape, opt_shape, max_shape) - - def set_trt_cache_file(self, cache_file_path): - return self._option.set_trt_cache_file(cache_file_path) - - def enable_trt_fp16(self): - return self._option.enable_trt_fp16() - - def dissable_trt_fp16(self): - return self._option.disable_trt_fp16() - - def __repr__(self): - attrs = dir(self._option) - message = "RuntimeOption(\n" - for attr in attrs: - if attr.startswith("__"): - continue - if hasattr(getattr(self._option, attr), "__call__"): - continue - message += " {} : {}\t\n".format(attr, getattr(self._option, attr)) - message.strip("\n") - message += ")" - return message - - def RuntimeOptionStr(runtime_option): attrs = dir(runtime_option) message = "RuntimeOption(\n" @@ -122,5 +61,6 @@ def RuntimeOptionStr(runtime_option): message += ")" return message + C.TensorInfo.__repr__ = TensorInfoStr -C.RuntimeOption.__repr__ = RuntimeOptionStr \ No newline at end of file +C.RuntimeOption.__repr__ = RuntimeOptionStr diff --git a/fastdeploy/fastdeploy_runtime.py b/fastdeploy/model.py similarity index 61% rename from fastdeploy/fastdeploy_runtime.py rename to fastdeploy/model.py index e07e28993..f0faa1610 100644 --- a/fastdeploy/fastdeploy_runtime.py +++ b/fastdeploy/model.py @@ -54,35 +54,3 @@ class FastDeployModel: if self._model is None: return False return self._model.initialized() - - -class Runtime: - def __init__(self, runtime_option): - self._runtime = C.Runtime() - assert self._runtime.init(runtime_option), "Initialize Runtime Failed!" - - def infer(self, data): - assert isinstance(data, dict), "The input data should be type of dict." - return self._runtime.infer(data) - - def num_inputs(self): - return self._runtime.num_inputs() - - def num_outputs(self): - return self._runtime.num_outputs() - - def get_input_info(self, index): - assert isinstance( - index, int), "The input parameter index should be type of int." - assert index < self.num_inputs( - ), "The input parameter index:{} should less than number of inputs:{}.".format( - index, self.num_inputs) - return self._runtime.get_input_info(index) - - def get_output_info(self, index): - assert isinstance( - index, int), "The input parameter index should be type of int." - assert index < self.num_outputs( - ), "The input parameter index:{} should less than number of outputs:{}.".format( - index, self.num_outputs) - return self._runtime.get_output_info(index) diff --git a/fastdeploy/runtime.py b/fastdeploy/runtime.py new file mode 100644 index 000000000..a560f63a2 --- /dev/null +++ b/fastdeploy/runtime.py @@ -0,0 +1,121 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +import logging +from . import fastdeploy_main as C + + +class Runtime: + def __init__(self, runtime_option): + self._runtime = C.Runtime() + assert self._runtime.init(runtime_option), "Initialize Runtime Failed!" + + def infer(self, data): + assert isinstance(data, dict), "The input data should be type of dict." + return self._runtime.infer(data) + + def num_inputs(self): + return self._runtime.num_inputs() + + def num_outputs(self): + return self._runtime.num_outputs() + + def get_input_info(self, index): + assert isinstance( + index, int), "The input parameter index should be type of int." + assert index < self.num_inputs( + ), "The input parameter index:{} should less than number of inputs:{}.".format( + index, self.num_inputs) + return self._runtime.get_input_info(index) + + def get_output_info(self, index): + assert isinstance( + index, int), "The input parameter index should be type of int." + assert index < self.num_outputs( + ), "The input parameter index:{} should less than number of outputs:{}.".format( + index, self.num_outputs) + return self._runtime.get_output_info(index) + + +class RuntimeOption: + def __init__(self): + self._option = C.RuntimeOption() + + def set_model_path(self, model_path, params_path="", + model_format="paddle"): + return self._option.set_model_path(model_path, params_path, + model_format) + + def use_gpu(self, device_id=0): + return self._option.use_gpu(device_id) + + def use_cpu(self): + return self._option.use_cpu() + + def set_cpu_thread_num(self, thread_num=8): + return self._option.set_cpu_thread_num(thread_num) + + def use_paddle_backend(self): + return self._option.use_paddle_backend() + + def use_ort_backend(self): + return self._option.use_ort_backend() + + def use_trt_backend(self): + return self._option.use_trt_backend() + + def enable_paddle_mkldnn(self): + return self._option.enable_paddle_mkldnn() + + def disable_paddle_mkldnn(self): + return self._option.disable_paddle_mkldnn() + + def set_paddle_mkldnn_cache_size(self, cache_size): + return self._option.set_paddle_mkldnn_cache_size(cache_size) + + def set_trt_input_shape(self, + tensor_name, + min_shape, + opt_shape=None, + max_shape=None): + if opt_shape is None and max_shape is None: + opt_shape = min_shape + max_shape = min_shape + else: + assert opt_shape is not None and max_shape is not None, "Set min_shape only, or set min_shape, opt_shape, max_shape both." + return self._option.set_trt_input_shape(tensor_name, min_shape, + opt_shape, max_shape) + + def set_trt_cache_file(self, cache_file_path): + return self._option.set_trt_cache_file(cache_file_path) + + def enable_trt_fp16(self): + return self._option.enable_trt_fp16() + + def disable_trt_fp16(self): + return self._option.disable_trt_fp16() + + def __repr__(self): + attrs = dir(self._option) + message = "RuntimeOption(\n" + for attr in attrs: + if attr.startswith("__"): + continue + if hasattr(getattr(self._option, attr), "__call__"): + continue + message += " {} : {}\t\n".format(attr, + getattr(self._option, attr)) + message.strip("\n") + message += ")" + return message