[Backend] Support Intel GPU with heterogeneous mode (#701)
* Add some comments for python api
* support openvino gpu
* Add cpu operators
* add interface to specify hetero operators
* remove useless dir
* format code
* remove debug code
* Support GPU for ONNX
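Before the file-by-file diff, a quick orientation: the commit exposes three OpenVINO options to users, a device string that may name a heterogeneous target, fixed shapes for the model inputs, and a list of operators forced onto the CPU. A minimal sketch of how they might be combined through the Python API is shown below; the model paths are placeholders, and use_openvino_backend()/set_model_path() are assumed to already exist outside this diff (only the set_openvino_* setters and the example shape values come from this commit).

import fastdeploy as fd

option = fd.RuntimeOption()
# Assumed pre-existing helpers; the model paths are placeholders.
option.set_model_path("model.pdmodel", "model.pdiparams")
option.use_openvino_backend()

# Run on the Intel GPU, falling back to CPU for operators the GPU plugin
# cannot handle (heterogeneous mode, detected via the "HETERO" prefix).
option.set_openvino_device("HETERO:GPU,CPU")

# Fix the input shapes up front (example values from the docstring added below).
option.set_openvino_shape_info({"image": [1, 3, 640, 640],
                                "scale_factor": [1, 2]})

# Operators that should always stay on the CPU.
option.set_openvino_cpu_operators(["MulticlassNms"])

runtime = fd.Runtime(option)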
@@ -1,8 +0,0 @@
-{
-    "image": "mcr.microsoft.com/devcontainers/universal:2",
-    "features": {
-        "ghcr.io/devcontainers/features/nvidia-cuda:1": {
-            "version": "latest"
-        }
-    }
-}
@@ -32,6 +32,14 @@ std::vector<int64_t> PartialShapeToVec(const ov::PartialShape& shape) {
   return res;
 }
 
+ov::PartialShape VecToPartialShape(const std::vector<int64_t>& shape) {
+  std::vector<ov::Dimension> dims;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    dims.emplace_back(ov::Dimension(shape[i]));
+  }
+  return ov::PartialShape(dims);
+}
+
 FDDataType OpenVINODataTypeToFD(const ov::element::Type& type) {
   if (type == ov::element::f32) {
     return FDDataType::FP32;
@@ -100,6 +108,26 @@ bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
   option_ = option;
 
   std::shared_ptr<ov::Model> model = core_.read_model(model_file, params_file);
+  if (option_.shape_infos.size() > 0) {
+    std::map<std::string, ov::PartialShape> shape_infos;
+    for (const auto& item : option_.shape_infos) {
+      shape_infos[item.first] = VecToPartialShape(item.second);
+    }
+    model->reshape(shape_infos);
+  }
+
+  if (option_.device.find("HETERO") != std::string::npos) {
+    auto supported_ops = core_.query_model(model, option_.device);
+    for (auto&& op : model->get_ops()) {
+      auto& affinity = supported_ops[op->get_friendly_name()];
+      if (option_.cpu_operators.find(op->description()) !=
+          option_.cpu_operators.end()) {
+        op->get_rt_info()["affinity"] = "CPU";
+      } else {
+        op->get_rt_info()["affinity"] = affinity;
+      }
+    }
+  }
+
   // Get inputs/outputs information from loaded model
   const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
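The hunk above is the heart of the heterogeneous mode: OpenVINO is queried for the device it would assign to each operation, and any operation whose type appears in cpu_operators has its affinity overridden to CPU before the model is compiled. For illustration, a rough standalone sketch of the same mechanism with OpenVINO's Python API (openvino.runtime) follows; the model path, device string and operator set are placeholders, and get_type_name() is used as an assumed Python counterpart of op->description().

from openvino.runtime import Core

core = Core()
model = core.read_model("model.onnx")  # placeholder model
device = "HETERO:GPU,CPU"              # heterogeneous target

# Ask OpenVINO which device each operation would be placed on.
supported_ops = core.query_model(model, device)

cpu_op_types = {"MulticlassNms"}  # operator types forced onto the CPU
for op in model.get_ops():
    affinity = supported_ops[op.get_friendly_name()]
    if op.get_type_name() in cpu_op_types:
        affinity = "CPU"
    # Record the chosen device in the node's runtime info.
    op.get_rt_info()["affinity"] = affinity

compiled_model = core.compile_model(model, device)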
@@ -151,14 +179,25 @@ bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
   if (option_.cpu_thread_num > 0) {
     properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
   }
-  if (option_.ov_num_streams == -1) {
-    properties["NUM_STREAMS"] = ov::streams::AUTO;
-  } else if (option_.ov_num_streams == -2) {
-    properties["NUM_STREAMS"] = ov::streams::NUMA;
-  } else if (option_.ov_num_streams > 0) {
-    properties["NUM_STREAMS"] = option_.ov_num_streams;
-  }
-  FDINFO << "Compile OpenVINO model on device_name:" << option.device << "." << std::endl;
+  if (option_.device == "CPU") {
+    if (option_.num_streams == -1) {
+      properties["NUM_STREAMS"] = ov::streams::AUTO;
+    } else if (option_.num_streams == -2) {
+      properties["NUM_STREAMS"] = ov::streams::NUMA;
+    } else if (option_.num_streams > 0) {
+      properties["NUM_STREAMS"] = option_.num_streams;
+    }
+  } else {
+    if (option_.num_streams != 0) {
+      FDWARNING << "NUM_STREAMS only available on device CPU, currently the "
+                   "device is set as "
+                << option_.device << ", the NUM_STREAMS will be ignored."
+                << std::endl;
+    }
+  }
+
+  FDINFO << "Compile OpenVINO model on device_name:" << option.device << "."
+         << std::endl;
   compiled_model_ = core_.compile_model(model, option.device, properties);
 
   request_ = compiled_model_.create_infer_request();
@@ -199,6 +238,27 @@ bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,
 
   std::shared_ptr<ov::Model> model = core_.read_model(model_file);
 
+  if (option_.shape_infos.size() > 0) {
+    std::map<std::string, ov::PartialShape> shape_infos;
+    for (const auto& item : option_.shape_infos) {
+      shape_infos[item.first] = VecToPartialShape(item.second);
+    }
+    model->reshape(shape_infos);
+  }
+
+  if (option_.device.find("HETERO") != std::string::npos) {
+    auto supported_ops = core_.query_model(model, option_.device);
+    for (auto&& op : model->get_ops()) {
+      auto& affinity = supported_ops[op->get_friendly_name()];
+      if (option_.cpu_operators.find(op->description()) !=
+          option_.cpu_operators.end()) {
+        op->get_rt_info()["affinity"] = "CPU";
+      } else {
+        op->get_rt_info()["affinity"] = affinity;
+      }
+    }
+  }
+
   // Get inputs/outputs information from loaded model
   const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
   std::map<std::string, TensorInfo> input_infos;
@@ -249,14 +309,25 @@ bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,
   if (option_.cpu_thread_num > 0) {
     properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
   }
-  if (option_.ov_num_streams == -1) {
-    properties["NUM_STREAMS"] = ov::streams::AUTO;
-  } else if (option_.ov_num_streams == -2) {
-    properties["NUM_STREAMS"] = ov::streams::NUMA;
-  } else if (option_.ov_num_streams > 0) {
-    properties["NUM_STREAMS"] = option_.ov_num_streams;
-  }
-  FDINFO << "Compile OpenVINO model on device_name:" << option.device << "." << std::endl;
+  if (option_.device == "CPU") {
+    if (option_.num_streams == -1) {
+      properties["NUM_STREAMS"] = ov::streams::AUTO;
+    } else if (option_.num_streams == -2) {
+      properties["NUM_STREAMS"] = ov::streams::NUMA;
+    } else if (option_.num_streams > 0) {
+      properties["NUM_STREAMS"] = option_.num_streams;
+    }
+  } else {
+    if (option_.num_streams != 0) {
+      FDWARNING << "NUM_STREAMS only available on device CPU, currently the "
+                   "device is set as "
+                << option_.device << ", the NUM_STREAMS will be ignored."
+                << std::endl;
+    }
+  }
+
+  FDINFO << "Compile OpenVINO model on device_name:" << option.device << "."
+         << std::endl;
   compiled_model_ = core_.compile_model(model, option.device, properties);
 
   request_ = compiled_model_.create_infer_request();
@@ -302,13 +373,16 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
   return true;
 }
 
-std::unique_ptr<BaseBackend> OpenVINOBackend::Clone(void *stream, int device_id) {
-  std::unique_ptr<BaseBackend> new_backend = utils::make_unique<OpenVINOBackend>();
+std::unique_ptr<BaseBackend> OpenVINOBackend::Clone(void* stream,
+                                                    int device_id) {
+  std::unique_ptr<BaseBackend> new_backend =
+      utils::make_unique<OpenVINOBackend>();
   auto casted_backend = dynamic_cast<OpenVINOBackend*>(new_backend.get());
   casted_backend->option_ = option_;
   casted_backend->request_ = compiled_model_.create_infer_request();
   casted_backend->input_infos_.assign(input_infos_.begin(), input_infos_.end());
-  casted_backend->output_infos_.assign(output_infos_.begin(), output_infos_.end());
+  casted_backend->output_infos_.assign(output_infos_.begin(),
+                                       output_infos_.end());
   return new_backend;
 }
 
@@ -28,8 +28,9 @@ namespace fastdeploy {
 struct OpenVINOBackendOption {
   std::string device = "CPU";
   int cpu_thread_num = -1;
-  int ov_num_streams = 1;
+  int num_streams = 0;
   std::map<std::string, std::vector<int64_t>> shape_infos;
+  std::set<std::string> cpu_operators{"MulticlassNms"};
 };
 
 class OpenVINOBackend : public BaseBackend {
@@ -38,12 +39,12 @@ class OpenVINOBackend : public BaseBackend {
   OpenVINOBackend() {}
   virtual ~OpenVINOBackend() = default;
 
-  bool InitFromPaddle(
-      const std::string& model_file, const std::string& params_file,
-      const OpenVINOBackendOption& option = OpenVINOBackendOption());
+  bool
+  InitFromPaddle(const std::string& model_file, const std::string& params_file,
+                 const OpenVINOBackendOption& option = OpenVINOBackendOption());
 
-  bool InitFromOnnx(
-      const std::string& model_file,
-      const OpenVINOBackendOption& option = OpenVINOBackendOption());
+  bool
+  InitFromOnnx(const std::string& model_file,
+               const OpenVINOBackendOption& option = OpenVINOBackendOption());
 
   bool Infer(std::vector<FDTensor>& inputs,
@@ -58,7 +59,7 @@ class OpenVINOBackend : public BaseBackend {
   std::vector<TensorInfo> GetInputInfos() override;
   std::vector<TensorInfo> GetOutputInfos() override;
 
-  std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
+  std::unique_ptr<BaseBackend> Clone(void* stream = nullptr,
                                      int device_id = -1) override;
 
  private:
@@ -71,4 +72,5 @@ class OpenVINOBackend : public BaseBackend {
   std::vector<TensorInfo> input_infos_;
   std::vector<TensorInfo> output_infos_;
 };
+
 } // namespace fastdeploy
@@ -34,6 +34,8 @@ void BindRuntime(pybind11::module& m) {
       .def("use_lite_backend", &RuntimeOption::UseLiteBackend)
       .def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN)
       .def("set_openvino_device", &RuntimeOption::SetOpenVINODevice)
+      .def("set_openvino_shape_info", &RuntimeOption::SetOpenVINOShapeInfo)
+      .def("set_openvino_cpu_operators", &RuntimeOption::SetOpenVINOCpuOperators)
       .def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo)
       .def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo)
       .def("set_paddle_mkldnn_cache_size",
@@ -646,7 +646,11 @@ void Runtime::CreateOpenVINOBackend() {
   auto ov_option = OpenVINOBackendOption();
   ov_option.cpu_thread_num = option.cpu_thread_num;
   ov_option.device = option.openvino_device;
-  ov_option.ov_num_streams = option.ov_num_streams;
+  ov_option.shape_infos = option.ov_shape_infos;
+  ov_option.num_streams = option.ov_num_streams;
+  for (const auto& op : option.ov_cpu_operators) {
+    ov_option.cpu_operators.insert(op);
+  }
   FDASSERT(option.model_format == ModelFormat::PADDLE ||
                option.model_format == ModelFormat::ONNX,
            "OpenVINOBackend only support model format of ModelFormat::PADDLE / "
@@ -173,6 +173,21 @@ struct FASTDEPLOY_DECL RuntimeOption {
    */
   void SetOpenVINODevice(const std::string& name = "CPU");
 
+  /**
+   * @brief Set shape info for OpenVINO
+   */
+  void SetOpenVINOShapeInfo(
+      const std::map<std::string, std::vector<int64_t>>& shape_info) {
+    ov_shape_infos = shape_info;
+  }
+
+  /**
+   * @brief While use OpenVINO backend with intel GPU, use this interface to specify operators run on CPU
+   */
+  void SetOpenVINOCpuOperators(const std::vector<std::string>& operators) {
+    ov_cpu_operators = operators;
+  }
+
   /**
    * @brief Set optimzed model dir for Paddle Lite backend.
    */
@@ -349,9 +364,6 @@ struct FASTDEPLOY_DECL RuntimeOption {
   size_t trt_max_batch_size = 32;
   size_t trt_max_workspace_size = 1 << 30;
 
-  // ======Only for OpenVINO Backend======
-  std::string openvino_device = "CPU";
-
   // ======Only for Poros Backend=======
   bool is_dynamic = false;
   bool long_to_int = true;
@@ -360,7 +372,10 @@ struct FASTDEPLOY_DECL RuntimeOption {
   std::string poros_file = "";
 
   // ======Only for OpenVINO Backend=======
-  int ov_num_streams = 1;
+  int ov_num_streams = 0;
+  std::string openvino_device = "CPU";
+  std::map<std::string, std::vector<int64_t>> ov_shape_infos;
+  std::vector<std::string> ov_cpu_operators;
 
   // ======Only for RKNPU2 Backend=======
   fastdeploy::rknpu2::CpuName rknpu2_cpu_name_
@@ -35,7 +35,7 @@ class Runtime:
             self.runtime_option._option), "Initialize Runtime Failed!"
 
     def forward(self, *inputs):
-        """Inference with input data for poros
+        """[Only for Poros backend] Inference with input data for poros
 
         :param data: (list[str : numpy.ndarray])The input data list
         :return list of numpy.ndarray
@@ -60,7 +60,7 @@ class Runtime:
         return self._runtime.infer(data)
 
     def compile(self, warm_datas):
-        """compile with prewarm data for poros
+        """[Only for Poros backend] compile with prewarm data for poros
 
         :param data: (list[str : numpy.ndarray])The prewarm data list
         :return TorchScript Model
@@ -122,6 +122,9 @@ class RuntimeOption:
     """
 
     def __init__(self):
+        """Initialize a FastDeploy RuntimeOption object.
+        """
+
        self._option = C.RuntimeOption()
 
     @property
@@ -210,8 +213,6 @@ class RuntimeOption:
     def use_rknpu2(self,
                    rknpu2_name=rknpu2.CpuName.RK3588,
                    rknpu2_core=rknpu2.CoreMask.RKNN_NPU_CORE_0):
-        """Inference with CPU
-        """
         return self._option.use_rknpu2(rknpu2_name, rknpu2_core)
 
     def set_cpu_thread_num(self, thread_num=-1):
@@ -222,6 +223,10 @@ class RuntimeOption:
         return self._option.set_cpu_thread_num(thread_num)
 
     def set_ort_graph_opt_level(self, level=-1):
+        """Set graph optimization level for ONNX Runtime backend
+
+        :param level: (int)Optimization level, -1 means the default setting
+        """
         return self._option.set_ort_graph_opt_level(level)
 
     def use_paddle_backend(self):
@@ -274,6 +279,20 @@ class RuntimeOption:
         """
         return self._option.set_openvino_device(name)
 
+    def set_openvino_shape_info(self, shape_info):
+        """Set shape information of the models' inputs, used for GPU to fix the shape
+
+        :param shape_info: (dict{str, list of int})Shape information of model's inputs, e.g {"image": [1, 3, 640, 640], "scale_factor": [1, 2]}
+        """
+        return self._option.set_openvino_shape_info(shape_info)
+
+    def set_openvino_cpu_operators(self, operators):
+        """While using OpenVINO backend and intel GPU, this interface specifies unsupported operators to run on CPU
+
+        :param operators: (list of string)list of operators' name, e.g ["MulticlasNms"]
+        """
+        return self._option.set_openvino_cpu_operators(operators)
+
     def enable_paddle_log_info(self):
         """Enable print out the debug log information while using Paddle Inference backend, the log information is disabled by default.
         """
@@ -367,9 +386,13 @@ class RuntimeOption:
         return self._option.set_trt_max_batch_size(trt_max_batch_size)
 
     def enable_paddle_trt_collect_shape(self):
+        """Enable collect subgraph shape information while using Paddle Inference with TensorRT
+        """
         return self._option.enable_paddle_trt_collect_shape()
 
     def disable_paddle_trt_collect_shape(self):
+        """Disable collect subgraph shape information while using Paddle Inference with TensorRT
+        """
         return self._option.disable_paddle_trt_collect_shape()
 
     def use_ipu(self,