Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 08:37:06 +08:00
[Other] Deprecate some option api and parameters (#1243)
* Optimize Poros backend
* Fix error
* Add more pybind
* Fix conflicts
* Add some deprecation notices
* [Other] Deprecate some APIs in RuntimeOption (#1240)
* Deprecate more options
* Modify serving
* Update option.h
* Fix TensorRT error
* Update option_pybind.cc
* Update option_pybind.cc
* Fix error in serving
* Fix word spelling error
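For users, the upshot of the diffs below is that backend-specific settings move off `RuntimeOption` setter methods and onto per-backend option members (`ort_option`, `trt_option`, `paddle_lite_option`, `openvino_option`, `poros_option`). A minimal migration sketch in Python; values are illustrative:

import fastdeploy as fd

option = fd.RuntimeOption()
option.use_trt_backend()

# Old style, deprecated by this commit; it still works but logs a warning
# and is scheduled for removal in v1.2.0:
# option.enable_trt_fp16()

# New style: set the backend option member directly.
option.trt_option.enable_fp16 = True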
@@ -48,6 +48,8 @@ enum LitePowerMode {
  LITE_POWER_RAND_LOW = 5  ///< Use Lite Backend with rand low power mode
};

/*! @brief Option object to configure Paddle Lite backend
 */
struct LiteBackendOption {
  /// Paddle Lite power mode for mobile device.
  int power_mode = 3;
@@ -55,12 +57,20 @@ struct LiteBackendOption {
  int cpu_threads = 1;
  /// Enable use half precision
  bool enable_fp16 = false;
  /// Enable use int8 precision for quantized model
  bool enable_int8 = false;

  /// Inference device, Paddle Lite support CPU/KUNLUNXIN/TIMVX/ASCEND
  Device device = Device::CPU;
  /// Index of inference device
  int device_id = 0;

  // optimized model dir for CxxConfig
  int kunlunxin_l3_workspace_size = 0xfffc00;
  bool kunlunxin_locked = false;
  bool kunlunxin_autotune = true;
  std::string kunlunxin_autotune_file = "";
  std::string kunlunxin_precision = "int16";
  bool kunlunxin_adaptive_seqlen = false;
  bool kunlunxin_enable_multi_stream = false;

  /// Optimized model dir for CxxConfig
  std::string optimized_model_dir = "";
  std::string nnadapter_subgraph_partition_config_path = "";
  std::string nnadapter_subgraph_partition_config_buffer = "";
@@ -70,13 +80,5 @@ struct LiteBackendOption {
  std::map<std::string, std::vector<std::vector<int64_t>>>
      nnadapter_dynamic_shape_info = {{"", {{0}}}};
  std::vector<std::string> nnadapter_device_names = {};
  int device_id = 0;
  int kunlunxin_l3_workspace_size = 0xfffc00;
  bool kunlunxin_locked = false;
  bool kunlunxin_autotune = true;
  std::string kunlunxin_autotune_file = "";
  std::string kunlunxin_precision = "int16";
  bool kunlunxin_adaptive_seqlen = false;
  bool kunlunxin_enable_multi_stream = false;
};
}  // namespace fastdeploy
@@ -23,7 +23,6 @@ void BindLiteOption(pybind11::module& m) {
      .def_readwrite("power_mode", &LiteBackendOption::power_mode)
      .def_readwrite("cpu_threads", &LiteBackendOption::cpu_threads)
      .def_readwrite("enable_fp16", &LiteBackendOption::enable_fp16)
      .def_readwrite("enable_int8", &LiteBackendOption::enable_int8)
      .def_readwrite("device", &LiteBackendOption::device)
      .def_readwrite("optimized_model_dir",
                     &LiteBackendOption::optimized_model_dir)
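Since `BindLiteOption` exposes these members directly, the Paddle Lite backend can be tuned from Python without the deprecated setters. A sketch, assuming `RuntimeOption` exposes `paddle_lite_option` as a property the way this commit exposes `ort_option` and `trt_option`:

import fastdeploy as fd

option = fd.RuntimeOption()
option.use_lite_backend()

option.paddle_lite_option.power_mode = 3      # the default mode; replaces set_lite_power_mode
option.paddle_lite_option.cpu_threads = 4
option.paddle_lite_option.enable_fp16 = True  # replaces enable_lite_fp16(); falls back to fp32 where unsupported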
@@ -23,9 +23,13 @@
#include <set>
namespace fastdeploy {

/*! @brief Option object to configure OpenVINO backend
 */
struct OpenVINOBackendOption {
  std::string device = "CPU";
  int cpu_thread_num = -1;

  /// Number of streams while use OpenVINO
  int num_streams = 0;

  /**
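The OpenVINO settings follow the same pattern on `openvino_option`; a sketch, where `set_device` and `num_streams` are the replacements named by the deprecation warnings later in this commit:

import fastdeploy as fd

option = fd.RuntimeOption()
option.use_openvino_backend()

option.openvino_option.num_streams = 4    # replaces set_openvino_streams(4)
option.openvino_option.set_device("CPU")  # replaces set_openvino_device("CPU")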
@@ -22,20 +22,30 @@
#include <map>
namespace fastdeploy {

/*! @brief Option object to configure ONNX Runtime backend
 */
struct OrtBackendOption {
  // -1 means default
  // 0: ORT_DISABLE_ALL
  // 1: ORT_ENABLE_BASIC
  // 2: ORT_ENABLE_EXTENDED
  // 99: ORT_ENABLE_ALL (enable some custom optimizations, e.g. BERT)
  /*
   * @brief Level of graph optimization, -1: means default (enable all the optimization strategies) / 0: disable all the optimization strategies / 1: enable basic strategy / 2: enable extended strategy / 99: enable all
   */
  int graph_optimization_level = -1;
  /*
   * @brief Number of threads to execute the operator, -1: default
   */
  int intra_op_num_threads = -1;
  /*
   * @brief Number of threads to execute the graph, -1: default. This parameter only takes effect while `OrtBackendOption::execution_mode` is set to 1.
   */
  int inter_op_num_threads = -1;
  // 0: ORT_SEQUENTIAL
  // 1: ORT_PARALLEL
  /*
   * @brief Execution mode for the graph, -1: default (sequential mode) / 0: sequential mode, execute the operators in the graph one by one / 1: parallel mode, execute the operators in the graph in parallel
   */
  int execution_mode = -1;
  /// Inference device, OrtBackend supports CPU/GPU
  Device device = Device::CPU;
  /// Inference device id
  int device_id = 0;

  void* external_stream_ = nullptr;
};
}  // namespace fastdeploy
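A sketch of configuring these ONNX Runtime fields from Python, assuming `BindOrtOption` (not shown in this diff) exposes the struct members under their C++ names:

import fastdeploy as fd

option = fd.RuntimeOption()
option.use_ort_backend()

option.ort_option.graph_optimization_level = 99  # ORT_ENABLE_ALL
option.ort_option.intra_op_num_threads = 8       # -1 keeps the ORT default
option.ort_option.execution_mode = 1             # ORT_PARALLEL
option.ort_option.inter_op_num_threads = 2       # only takes effect when execution_mode == 1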
@@ -22,6 +22,8 @@

namespace fastdeploy {

/*! @brief Option object to configure Poros backend
 */
struct PorosBackendOption {
  Device device = Device::CPU;
  int device_id = 0;
@@ -21,23 +21,64 @@

namespace fastdeploy {

/*! @brief Option object to configure TensorRT backend
 */
struct TrtBackendOption {
  std::string model_file = "";   // Path of model file
  std::string params_file = "";  // Path of parameters file, can be empty

  // format of input model
  ModelFormat model_format = ModelFormat::AUTOREC;

  int gpu_id = 0;
  bool enable_fp16 = false;
  bool enable_int8 = false;
  /// `max_batch_size`, it's deprecated in TensorRT 8.x
  size_t max_batch_size = 32;

  /// `max_workspace_size` for TensorRT
  size_t max_workspace_size = 1 << 30;

  /*
   * @brief Enable half precision inference; on devices that do not support half precision, it will fall back to float32 mode
   */
  bool enable_fp16 = false;

  /** \brief Set shape range of input tensor for the model that contains dynamic input shape while using TensorRT backend
   *
   * \param[in] tensor_name The name of input for the model which is dynamic shape
   * \param[in] min The minimal shape for the input tensor
   * \param[in] opt The optimized shape for the input tensor, just set the most common shape; if set as default value, it will keep same with min_shape
   * \param[in] max The maximum shape for the input tensor; if set as default value, it will keep same with min_shape
   */
  void SetShape(const std::string& tensor_name,
                const std::vector<int32_t>& min,
                const std::vector<int32_t>& opt,
                const std::vector<int32_t>& max) {
    min_shape[tensor_name].clear();
    max_shape[tensor_name].clear();
    opt_shape[tensor_name].clear();
    min_shape[tensor_name].assign(min.begin(), min.end());
    if (opt.size() == 0) {
      opt_shape[tensor_name].assign(min.begin(), min.end());
    } else {
      opt_shape[tensor_name].assign(opt.begin(), opt.end());
    }
    if (max.size() == 0) {
      max_shape[tensor_name].assign(min.begin(), min.end());
    } else {
      max_shape[tensor_name].assign(max.begin(), max.end());
    }
  }
  /**
   * @brief Set cache file path while using TensorRT backend. Loading a Paddle/ONNX model and initializing TensorRT takes a long time; through this interface the TensorRT engine is saved to `cache_file_path` and loaded directly the next time the code runs
   */
  std::string serialize_file = "";

  // The below parameters may be removed in next version, please do not
  // visit or use them directly
  std::map<std::string, std::vector<int32_t>> max_shape;
  std::map<std::string, std::vector<int32_t>> min_shape;
  std::map<std::string, std::vector<int32_t>> opt_shape;
  std::string serialize_file = "";
  bool enable_pinned_memory = false;
  void* external_stream_ = nullptr;
  int gpu_id = 0;
  std::string model_file = "";   // Path of model file
  std::string params_file = "";  // Path of parameters file, can be empty
  // format of input model
  ModelFormat model_format = ModelFormat::AUTOREC;
};


}  // namespace fastdeploy
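Note the fallback in `SetShape`: empty `opt`/`max` vectors collapse the range to `min`, which is handy for fixed-shape inputs. Through the Python binding added in the next file, with hypothetical input names "x" and "y":

import fastdeploy as fd

option = fd.RuntimeOption()
option.use_trt_backend()

# Dynamic range for input "x"; the three lists are min/opt/max shapes.
option.trt_option.set_shape("x", [1, 3, 224, 224],
                            [4, 3, 224, 224], [8, 3, 224, 224])

# Fixed shape: empty opt/max reuse the min shape.
option.trt_option.set_shape("y", [1, 3, 320, 320], [], [])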
fastdeploy/runtime/backends/tensorrt/option_pybind.cc (new file, 31 lines)
@@ -0,0 +1,31 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/pybind/main.h"
#include "fastdeploy/runtime/backends/tensorrt/option.h"

namespace fastdeploy {

void BindTrtOption(pybind11::module& m) {
  pybind11::class_<TrtBackendOption>(m, "TrtBackendOption")
      .def(pybind11::init())
      .def_readwrite("enable_fp16", &TrtBackendOption::enable_fp16)
      .def_readwrite("max_batch_size", &TrtBackendOption::max_batch_size)
      .def_readwrite("max_workspace_size",
                     &TrtBackendOption::max_workspace_size)
      .def_readwrite("serialize_file", &TrtBackendOption::serialize_file)
      .def("set_shape", &TrtBackendOption::SetShape);
}

}  // namespace fastdeploy
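With these bindings, the remaining TensorRT knobs are plain attributes in Python; a short sketch (the cache path is illustrative):

import fastdeploy as fd

option = fd.RuntimeOption()
option.use_trt_backend()

option.trt_option.enable_fp16 = True            # replaces enable_trt_fp16()
option.trt_option.max_workspace_size = 1 << 30  # replaces set_trt_max_workspace_size()
option.trt_option.max_batch_size = 32           # replaces set_trt_max_batch_size()
option.trt_option.serialize_file = "model.trt"  # replaces set_trt_cache_file(); caches the built engine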
@@ -19,12 +19,14 @@ namespace fastdeploy {
void BindLiteOption(pybind11::module& m);
void BindOpenVINOOption(pybind11::module& m);
void BindOrtOption(pybind11::module& m);
void BindTrtOption(pybind11::module& m);
void BindPorosOption(pybind11::module& m);

void BindOption(pybind11::module& m) {
  BindLiteOption(m);
  BindOpenVINOOption(m);
  BindOrtOption(m);
  BindTrtOption(m);
  BindPorosOption(m);

  pybind11::class_<RuntimeOption>(m, "RuntimeOption")
@@ -40,47 +42,22 @@ void BindOption(pybind11::module& m) {
      .def_readwrite("paddle_lite_option", &RuntimeOption::paddle_lite_option)
      .def_readwrite("openvino_option", &RuntimeOption::openvino_option)
      .def_readwrite("ort_option", &RuntimeOption::ort_option)
      .def_readwrite("trt_option", &RuntimeOption::trt_option)
      .def_readwrite("poros_option", &RuntimeOption::poros_option)
      .def("set_external_stream", &RuntimeOption::SetExternalStream)
      .def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum)
      .def("use_paddle_backend", &RuntimeOption::UsePaddleBackend)
      .def("use_poros_backend", &RuntimeOption::UsePorosBackend)
      .def("use_ort_backend", &RuntimeOption::UseOrtBackend)
      .def("set_ort_graph_opt_level", &RuntimeOption::SetOrtGraphOptLevel)
      .def("use_trt_backend", &RuntimeOption::UseTrtBackend)
      .def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend)
      .def("use_lite_backend", &RuntimeOption::UseLiteBackend)
      .def("set_lite_device_names", &RuntimeOption::SetLiteDeviceNames)
      .def("set_lite_context_properties",
           &RuntimeOption::SetLiteContextProperties)
      .def("set_lite_model_cache_dir", &RuntimeOption::SetLiteModelCacheDir)
      .def("set_lite_dynamic_shape_info",
           &RuntimeOption::SetLiteDynamicShapeInfo)
      .def("set_lite_subgraph_partition_path",
           &RuntimeOption::SetLiteSubgraphPartitionPath)
      .def("set_lite_mixed_precision_quantization_config_path",
           &RuntimeOption::SetLiteMixedPrecisionQuantizationConfigPath)
      .def("set_lite_subgraph_partition_config_buffer",
           &RuntimeOption::SetLiteSubgraphPartitionConfigBuffer)
      .def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN)
      .def("set_openvino_device", &RuntimeOption::SetOpenVINODevice)
      .def("set_openvino_shape_info", &RuntimeOption::SetOpenVINOShapeInfo)
      .def("set_openvino_cpu_operators",
           &RuntimeOption::SetOpenVINOCpuOperators)
      .def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo)
      .def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo)
      .def("set_paddle_mkldnn_cache_size",
           &RuntimeOption::SetPaddleMKLDNNCacheSize)
      .def("enable_lite_fp16", &RuntimeOption::EnableLiteFP16)
      .def("disable_lite_fp16", &RuntimeOption::DisableLiteFP16)
      .def("set_lite_power_mode", &RuntimeOption::SetLitePowerMode)
      .def("set_trt_input_shape", &RuntimeOption::SetTrtInputShape)
      .def("set_trt_max_workspace_size", &RuntimeOption::SetTrtMaxWorkspaceSize)
      .def("set_trt_max_batch_size", &RuntimeOption::SetTrtMaxBatchSize)
      .def("enable_paddle_to_trt", &RuntimeOption::EnablePaddleToTrt)
      .def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16)
      .def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16)
      .def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile)
      .def("enable_pinned_memory", &RuntimeOption::EnablePinnedMemory)
      .def("disable_pinned_memory", &RuntimeOption::DisablePinnedMemory)
      .def("enable_paddle_trt_collect_shape",
@@ -103,15 +80,6 @@ void BindOption(pybind11::module& m) {
      .def_readwrite("cpu_thread_num", &RuntimeOption::cpu_thread_num)
      .def_readwrite("device_id", &RuntimeOption::device_id)
      .def_readwrite("device", &RuntimeOption::device)
      .def_readwrite("trt_max_shape", &RuntimeOption::trt_max_shape)
      .def_readwrite("trt_opt_shape", &RuntimeOption::trt_opt_shape)
      .def_readwrite("trt_min_shape", &RuntimeOption::trt_min_shape)
      .def_readwrite("trt_serialize_file", &RuntimeOption::trt_serialize_file)
      .def_readwrite("trt_enable_fp16", &RuntimeOption::trt_enable_fp16)
      .def_readwrite("trt_enable_int8", &RuntimeOption::trt_enable_int8)
      .def_readwrite("trt_max_batch_size", &RuntimeOption::trt_max_batch_size)
      .def_readwrite("trt_max_workspace_size",
                     &RuntimeOption::trt_max_workspace_size)
      .def_readwrite("ipu_device_num", &RuntimeOption::ipu_device_num)
      .def_readwrite("ipu_micro_batch_size",
                     &RuntimeOption::ipu_micro_batch_size)
@@ -244,17 +244,9 @@ void Runtime::CreatePaddleBackend() {
  if (pd_option.use_gpu && option.pd_enable_trt) {
    pd_option.enable_trt = true;
    pd_option.collect_shape = option.pd_collect_shape;
    auto trt_option = TrtBackendOption();
    trt_option.gpu_id = option.device_id;
    trt_option.enable_fp16 = option.trt_enable_fp16;
    trt_option.max_batch_size = option.trt_max_batch_size;
    trt_option.max_workspace_size = option.trt_max_workspace_size;
    trt_option.max_shape = option.trt_max_shape;
    trt_option.min_shape = option.trt_min_shape;
    trt_option.opt_shape = option.trt_opt_shape;
    trt_option.serialize_file = option.trt_serialize_file;
    trt_option.enable_pinned_memory = option.enable_pinned_memory;
    pd_option.trt_option = trt_option;
    pd_option.trt_option = option.trt_option;
    pd_option.trt_option.gpu_id = option.device_id;
    pd_option.trt_option.enable_pinned_memory = option.enable_pinned_memory;
    pd_option.trt_disabled_ops_ = option.trt_disabled_ops_;
  }
#endif
@@ -339,41 +331,33 @@ void Runtime::CreateTrtBackend() {
           "TrtBackend only support model format of ModelFormat::PADDLE / "
           "ModelFormat::ONNX.");
#ifdef ENABLE_TRT_BACKEND
  auto trt_option = TrtBackendOption();
  trt_option.model_file = option.model_file;
  trt_option.params_file = option.params_file;
  trt_option.model_format = option.model_format;
  trt_option.gpu_id = option.device_id;
  trt_option.enable_fp16 = option.trt_enable_fp16;
  trt_option.enable_int8 = option.trt_enable_int8;
  trt_option.max_batch_size = option.trt_max_batch_size;
  trt_option.max_workspace_size = option.trt_max_workspace_size;
  trt_option.max_shape = option.trt_max_shape;
  trt_option.min_shape = option.trt_min_shape;
  trt_option.opt_shape = option.trt_opt_shape;
  trt_option.serialize_file = option.trt_serialize_file;
  trt_option.enable_pinned_memory = option.enable_pinned_memory;
  trt_option.external_stream_ = option.external_stream_;
  option.trt_option.model_file = option.model_file;
  option.trt_option.params_file = option.params_file;
  option.trt_option.model_format = option.model_format;
  option.trt_option.gpu_id = option.device_id;
  option.trt_option.enable_pinned_memory = option.enable_pinned_memory;
  option.trt_option.external_stream_ = option.external_stream_;
  backend_ = utils::make_unique<TrtBackend>();
  auto casted_backend = dynamic_cast<TrtBackend*>(backend_.get());
  casted_backend->benchmark_option_ = option.benchmark_option;

  if (option.model_format == ModelFormat::ONNX) {
    if (option.model_from_memory_) {
      FDASSERT(casted_backend->InitFromOnnx(option.model_file, trt_option),
      FDASSERT(
          casted_backend->InitFromOnnx(option.model_file, option.trt_option),
          "Load model from ONNX failed while initializing TrtBackend.");
      ReleaseModelMemoryBuffer();
    } else {
      std::string model_buffer = "";
      FDASSERT(ReadBinaryFromFile(option.model_file, &model_buffer),
               "Fail to read binary from model file");
      FDASSERT(casted_backend->InitFromOnnx(model_buffer, trt_option),
      FDASSERT(casted_backend->InitFromOnnx(model_buffer, option.trt_option),
               "Load model from ONNX failed while initializing TrtBackend.");
    }
  } else {
    if (option.model_from_memory_) {
      FDASSERT(casted_backend->InitFromPaddle(option.model_file,
                                              option.params_file, trt_option),
      FDASSERT(casted_backend->InitFromPaddle(
                   option.model_file, option.params_file, option.trt_option),
               "Load model from Paddle failed while initializing TrtBackend.");
      ReleaseModelMemoryBuffer();
    } else {
@@ -384,7 +368,7 @@ void Runtime::CreateTrtBackend() {
      FDASSERT(ReadBinaryFromFile(option.params_file, &params_buffer),
               "Fail to read binary from parameter file");
      FDASSERT(casted_backend->InitFromPaddle(model_buffer, params_buffer,
                                              trt_option),
                                              option.trt_option),
               "Load model from Paddle failed while initializing TrtBackend.");
    }
  }
@@ -505,9 +489,10 @@ bool Runtime::Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors,
  }
  option.poros_option.device = option.device;
  option.poros_option.device_id = option.device_id;
  option.poros_option.enable_fp16 = option.trt_enable_fp16;
  option.poros_option.max_batch_size = option.trt_max_batch_size;
  option.poros_option.max_workspace_size = option.trt_max_workspace_size;
  option.poros_option.enable_fp16 = option.trt_option.enable_fp16;
  option.poros_option.max_batch_size = option.trt_option.max_batch_size;
  option.poros_option.max_workspace_size = option.trt_option.max_workspace_size;

  backend_ = utils::make_unique<PorosBackend>();
  auto casted_backend = dynamic_cast<PorosBackend*>(backend_.get());
  FDASSERT(
@@ -102,6 +102,10 @@ void RuntimeOption::SetCpuThreadNum(int thread_num) {
}

void RuntimeOption::SetOrtGraphOptLevel(int level) {
  FDWARNING << "`RuntimeOption::SetOrtGraphOptLevel` will be removed in "
               "v1.2.0, please modify its member variables directly, e.g "
               "`runtime_option.ort_option.graph_optimization_level = 99`."
            << std::endl;
  std::vector<int> supported_level{-1, 0, 1, 2};
  auto valid_level = std::find(supported_level.begin(), supported_level.end(),
                               level) != supported_level.end();
@@ -203,67 +207,127 @@ void RuntimeOption::SetPaddleMKLDNNCacheSize(int size) {
}

void RuntimeOption::SetOpenVINODevice(const std::string& name) {
  openvino_option.device = name;
  FDWARNING << "`RuntimeOption::SetOpenVINODevice` will be removed in v1.2.0, "
               "please use `RuntimeOption.openvino_option.SetDevice(const "
               "std::string&)` instead."
            << std::endl;
  openvino_option.SetDevice(name);
}

void RuntimeOption::EnableLiteFP16() { paddle_lite_option.enable_fp16 = true; }
void RuntimeOption::EnableLiteFP16() {
  FDWARNING << "`RuntimeOption::EnableLiteFP16` will be removed in v1.2.0, "
               "please modify its member variables directly, e.g "
               "`runtime_option.paddle_lite_option.enable_fp16 = true`"
            << std::endl;
  paddle_lite_option.enable_fp16 = true;
}

void RuntimeOption::DisableLiteFP16() {
  FDWARNING << "`RuntimeOption::DisableLiteFP16` will be removed in v1.2.0, "
               "please modify its member variables directly, e.g "
               "`runtime_option.paddle_lite_option.enable_fp16 = false`"
            << std::endl;
  paddle_lite_option.enable_fp16 = false;
}

void RuntimeOption::EnableLiteInt8() { paddle_lite_option.enable_int8 = true; }
void RuntimeOption::EnableLiteInt8() {
  FDWARNING << "RuntimeOption::EnableLiteInt8 is a useless api, this calling "
               "will not bring any effects, and will be removed in v1.2.0. if "
               "you load a quantized model, it will automatically run with "
               "int8 mode; otherwise it will run with float mode."
            << std::endl;
}

void RuntimeOption::DisableLiteInt8() {
  paddle_lite_option.enable_int8 = false;
  FDWARNING << "RuntimeOption::DisableLiteInt8 is a useless api, this calling "
               "will not bring any effects, and will be removed in v1.2.0. if "
               "you load a quantized model, it will automatically run with "
               "int8 mode; otherwise it will run with float mode."
            << std::endl;
}

void RuntimeOption::SetLitePowerMode(LitePowerMode mode) {
  FDWARNING << "`RuntimeOption::SetLitePowerMode` will be removed in v1.2.0, "
               "please modify its member variable directly, e.g "
               "`runtime_option.paddle_lite_option.power_mode = 3;`"
            << std::endl;
  paddle_lite_option.power_mode = mode;
}

void RuntimeOption::SetLiteOptimizedModelDir(
    const std::string& optimized_model_dir) {
  FDWARNING
      << "`RuntimeOption::SetLiteOptimizedModelDir` will be removed in v1.2.0, "
         "please modify its member variable directly, e.g "
         "`runtime_option.paddle_lite_option.optimized_model_dir = \"...\"`"
      << std::endl;
  paddle_lite_option.optimized_model_dir = optimized_model_dir;
}

void RuntimeOption::SetLiteSubgraphPartitionPath(
    const std::string& nnadapter_subgraph_partition_config_path) {
  FDWARNING << "`RuntimeOption::SetLiteSubgraphPartitionPath` will be removed "
               "in v1.2.0, please modify its member variable directly, e.g "
               "`runtime_option.paddle_lite_option.nnadapter_subgraph_"
               "partition_config_path = \"...\";` "
            << std::endl;
  paddle_lite_option.nnadapter_subgraph_partition_config_path =
      nnadapter_subgraph_partition_config_path;
}

void RuntimeOption::SetLiteSubgraphPartitionConfigBuffer(
    const std::string& nnadapter_subgraph_partition_config_buffer) {
  FDWARNING
      << "`RuntimeOption::SetLiteSubgraphPartitionConfigBuffer` will be "
         "removed in v1.2.0, please modify its member variable directly, e.g "
         "`runtime_option.paddle_lite_option.nnadapter_subgraph_partition_"
         "config_buffer = ...`"
      << std::endl;
  paddle_lite_option.nnadapter_subgraph_partition_config_buffer =
      nnadapter_subgraph_partition_config_buffer;
}

void RuntimeOption::SetLiteDeviceNames(
    const std::vector<std::string>& nnadapter_device_names) {
  paddle_lite_option.nnadapter_device_names = nnadapter_device_names;
}

void RuntimeOption::SetLiteContextProperties(
    const std::string& nnadapter_context_properties) {
  FDWARNING << "`RuntimeOption::SetLiteContextProperties` will be removed in "
               "v1.2.0, please modify its member variable directly, e.g "
               "`runtime_option.paddle_lite_option.nnadapter_context_"
               "properties = ...`"
            << std::endl;
  paddle_lite_option.nnadapter_context_properties =
      nnadapter_context_properties;
}

void RuntimeOption::SetLiteModelCacheDir(
    const std::string& nnadapter_model_cache_dir) {
  FDWARNING
      << "`RuntimeOption::SetLiteModelCacheDir` will be removed in v1.2.0, "
         "please modify its member variable directly, e.g "
         "`runtime_option.paddle_lite_option.nnadapter_model_cache_dir = ...`"
      << std::endl;
  paddle_lite_option.nnadapter_model_cache_dir = nnadapter_model_cache_dir;
}

void RuntimeOption::SetLiteDynamicShapeInfo(
    const std::map<std::string, std::vector<std::vector<int64_t>>>&
        nnadapter_dynamic_shape_info) {
  FDWARNING << "`RuntimeOption::SetLiteDynamicShapeInfo` will be removed in "
               "v1.2.0, please modify its member variable directly, e.g "
               "`runtime_option.paddle_lite_option."
               "nnadapter_dynamic_shape_info = ...`"
            << std::endl;
  paddle_lite_option.nnadapter_dynamic_shape_info =
      nnadapter_dynamic_shape_info;
}

void RuntimeOption::SetLiteMixedPrecisionQuantizationConfigPath(
    const std::string& nnadapter_mixed_precision_quantization_config_path) {
  FDWARNING
      << "`RuntimeOption::SetLiteMixedPrecisionQuantizationConfigPath` will be "
         "removed in v1.2.0, please modify its member variable directly, e.g "
         "`runtime_option.paddle_lite_option.nnadapter_"
         "mixed_precision_quantization_config_path = ...`"
      << std::endl;
  paddle_lite_option.nnadapter_mixed_precision_quantization_config_path =
      nnadapter_mixed_precision_quantization_config_path;
}
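The NNAdapter-related Lite settings deprecated above migrate the same way; a sketch with placeholder paths, mirroring the assignments the Python wrapper makes later in this commit:

import fastdeploy as fd

option = fd.RuntimeOption()
option.use_lite_backend()

# Placeholder paths for illustration only.
option.paddle_lite_option.optimized_model_dir = "./opt_model"
option.paddle_lite_option.nnadapter_model_cache_dir = "./nnadapter_cache"
option.paddle_lite_option.nnadapter_subgraph_partition_config_path = "./partition.cfg"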
@@ -272,42 +336,60 @@ void RuntimeOption::SetTrtInputShape(const std::string& input_name,
                                     const std::vector<int32_t>& min_shape,
                                     const std::vector<int32_t>& opt_shape,
                                     const std::vector<int32_t>& max_shape) {
  trt_min_shape[input_name].clear();
  trt_max_shape[input_name].clear();
  trt_opt_shape[input_name].clear();
  trt_min_shape[input_name].assign(min_shape.begin(), min_shape.end());
  if (opt_shape.size() == 0) {
    trt_opt_shape[input_name].assign(min_shape.begin(), min_shape.end());
  } else {
    trt_opt_shape[input_name].assign(opt_shape.begin(), opt_shape.end());
  }
  if (max_shape.size() == 0) {
    trt_max_shape[input_name].assign(min_shape.begin(), min_shape.end());
  } else {
    trt_max_shape[input_name].assign(max_shape.begin(), max_shape.end());
  }
  FDWARNING << "`RuntimeOption::SetTrtInputShape` will be removed in v1.2.0, "
               "please use `RuntimeOption.trt_option.SetShape()` instead."
            << std::endl;
  trt_option.SetShape(input_name, min_shape, opt_shape, max_shape);
}

void RuntimeOption::SetTrtMaxWorkspaceSize(size_t max_workspace_size) {
  trt_max_workspace_size = max_workspace_size;
  FDWARNING << "`RuntimeOption::SetTrtMaxWorkspaceSize` will be removed in "
               "v1.2.0, please modify its member variable directly, e.g "
               "`RuntimeOption.trt_option.max_workspace_size = "
            << max_workspace_size << "`." << std::endl;
  trt_option.max_workspace_size = max_workspace_size;
}
void RuntimeOption::SetTrtMaxBatchSize(size_t max_batch_size) {
  trt_max_batch_size = max_batch_size;
  FDWARNING << "`RuntimeOption::SetTrtMaxBatchSize` will be removed in v1.2.0, "
               "please modify its member variable directly, e.g "
               "`RuntimeOption.trt_option.max_batch_size = "
            << max_batch_size << "`." << std::endl;
  trt_option.max_batch_size = max_batch_size;
}

void RuntimeOption::EnableTrtFP16() { trt_enable_fp16 = true; }
void RuntimeOption::EnableTrtFP16() {
  FDWARNING << "`RuntimeOption::EnableTrtFP16` will be removed in v1.2.0, "
               "please modify its member variable directly, e.g "
               "`runtime_option.trt_option.enable_fp16 = true;`"
            << std::endl;
  trt_option.enable_fp16 = true;
}

void RuntimeOption::DisableTrtFP16() { trt_enable_fp16 = false; }
void RuntimeOption::DisableTrtFP16() {
  FDWARNING << "`RuntimeOption::DisableTrtFP16` will be removed in v1.2.0, "
               "please modify its member variable directly, e.g "
               "`runtime_option.trt_option.enable_fp16 = false;`"
            << std::endl;
  trt_option.enable_fp16 = false;
}

void RuntimeOption::EnablePinnedMemory() { enable_pinned_memory = true; }

void RuntimeOption::DisablePinnedMemory() { enable_pinned_memory = false; }

void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) {
  trt_serialize_file = cache_file_path;
  FDWARNING << "`RuntimeOption::SetTrtCacheFile` will be removed in v1.2.0, "
               "please modify its member variable directly, e.g "
               "`runtime_option.trt_option.serialize_file = \""
            << cache_file_path << "\"." << std::endl;
  trt_option.serialize_file = cache_file_path;
}

void RuntimeOption::SetOpenVINOStreams(int num_streams) {
  FDWARNING << "`RuntimeOption::SetOpenVINOStreams` will be removed in v1.2.0, "
               "please modify its member variable directly, e.g "
               "`runtime_option.openvino_option.num_streams = "
            << num_streams << "`." << std::endl;
  openvino_option.num_streams = num_streams;
}
@@ -211,12 +211,6 @@ struct FASTDEPLOY_DECL RuntimeOption {
  void SetLiteSubgraphPartitionConfigBuffer(
      const std::string& nnadapter_subgraph_partition_config_buffer);

  /**
   * @brief Set device name for Paddle Lite backend.
   */
  void
  SetLiteDeviceNames(const std::vector<std::string>& nnadapter_device_names);

  /**
   * @brief Set context properties for Paddle Lite backend.
   */
@@ -381,6 +375,7 @@ struct FASTDEPLOY_DECL RuntimeOption {

  bool enable_pinned_memory = false;

  /// Option to configure ONNX Runtime backend
  OrtBackendOption ort_option;

  // ======Only for Paddle Backend=====
@@ -401,20 +396,16 @@ struct FASTDEPLOY_DECL RuntimeOption {
  float ipu_available_memory_proportion = 1.0;
  bool ipu_enable_half_partial = false;

  // ======Only for Trt Backend=======
  std::map<std::string, std::vector<int32_t>> trt_max_shape;
  std::map<std::string, std::vector<int32_t>> trt_min_shape;
  std::map<std::string, std::vector<int32_t>> trt_opt_shape;
  std::string trt_serialize_file = "";
  bool trt_enable_fp16 = false;
  bool trt_enable_int8 = false;
  size_t trt_max_batch_size = 1;
  size_t trt_max_workspace_size = 1 << 30;
  /// Option to configure TensorRT backend
  TrtBackendOption trt_option;

  // ======Only for PaddleTrt Backend=======
  std::vector<std::string> trt_disabled_ops_{};

  /// Option to configure Poros backend
  PorosBackendOption poros_option;

  /// Option to configure OpenVINO backend
  OpenVINOBackendOption openvino_option;

  // ======Only for RKNPU2 Backend=======
@@ -433,10 +424,10 @@ struct FASTDEPLOY_DECL RuntimeOption {
  std::string model_file = "";
  std::string params_file = "";
  bool model_from_memory_ = false;
  // format of input model
  /// format of input model
  ModelFormat model_format = ModelFormat::PADDLE;

  // Benchmark option
  /// Benchmark option
  benchmark::BenchmarkOption benchmark_option;
};
@@ -154,6 +154,8 @@ class RuntimeOption:
    """Options for FastDeploy Runtime.
    """

    __slots__ = ["_option"]

    def __init__(self):
        """Initialize a FastDeploy RuntimeOption object.
        """
@@ -266,7 +268,7 @@ class RuntimeOption:
        logging.warning(
            "`RuntimeOption.set_ort_graph_opt_level` will be deprecated in v1.2.0, please use `RuntimeOption.graph_optimize_level = 99` instead."
        )
        return self._option.set_ort_graph_opt_level(level)
        self._option.ort_option.graph_optimize_level = level

    def use_paddle_backend(self):
        """Use Paddle Inference backend, support inference Paddle model on CPU/Nvidia GPU.
@@ -314,7 +316,7 @@ class RuntimeOption:
        logging.warning(
            "`RuntimeOption.set_lite_context_properties` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_context_properties = ...` instead."
        )
        return self._option.set_lite_context_properties(context_properties)
        self._option.paddle_lite_option.nnadapter_context_properties = context_properties

    def set_lite_model_cache_dir(self, model_cache_dir):
        """Set nnadapter model cache dir for Paddle Lite backend.
@@ -322,7 +324,8 @@ class RuntimeOption:
        logging.warning(
            "`RuntimeOption.set_lite_model_cache_dir` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_model_cache_dir = ...` instead."
        )
        return self._option.set_lite_model_cache_dir(model_cache_dir)

        self._option.paddle_lite_option.nnadapter_model_cache_dir = model_cache_dir

    def set_lite_dynamic_shape_info(self, dynamic_shape_info):
        """ Set nnadapter dynamic shape info for Paddle Lite backend.
@@ -330,7 +333,7 @@ class RuntimeOption:
        logging.warning(
            "`RuntimeOption.set_lite_dynamic_shape_info` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_dynamic_shape_info = ...` instead."
        )
        return self._option.set_lite_dynamic_shape_info(dynamic_shape_info)
        self._option.paddle_lite_option.nnadapter_dynamic_shape_info = dynamic_shape_info

    def set_lite_subgraph_partition_path(self, subgraph_partition_path):
        """ Set nnadapter subgraph partition path for Paddle Lite backend.
@@ -338,8 +341,7 @@ class RuntimeOption:
        logging.warning(
            "`RuntimeOption.set_lite_subgraph_partition_path` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_subgraph_partition_config_path = ...` instead."
        )
        return self._option.set_lite_subgraph_partition_path(
            subgraph_partition_path)
        self._option.paddle_lite_option.nnadapter_subgraph_partition_config_path = subgraph_partition_path

    def set_lite_subgraph_partition_config_buffer(self,
                                                  subgraph_partition_buffer):
@@ -348,8 +350,7 @@ class RuntimeOption:
        logging.warning(
            "`RuntimeOption.set_lite_subgraph_partition_buffer` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_subgraph_partition_config_buffer = ...` instead."
        )
        return self._option.set_lite_subgraph_partition_config_buffer(
            subgraph_partition_buffer)
        self._option.paddle_lite_option.nnadapter_subgraph_partition_config_buffer = subgraph_partition_buffer

    def set_lite_mixed_precision_quantization_config_path(
            self, mixed_precision_quantization_config_path):
@@ -358,8 +359,7 @@ class RuntimeOption:
        logging.warning(
            "`RuntimeOption.set_lite_mixed_precision_quantization_config_path` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_mixed_precision_quantization_config_path = ...` instead."
        )
        return self._option.set_lite_mixed_precision_quantization_config_path(
            mixed_precision_quantization_config_path)
        self._option.paddle_lite_option.nnadapter_mixed_precision_quantization_config_path = mixed_precision_quantization_config_path

    def set_paddle_mkldnn(self, use_mkldnn=True):
        """Enable/Disable MKLDNN while using Paddle Inference backend, mkldnn is enabled by default.
@@ -373,7 +373,7 @@ class RuntimeOption:
        logging.warning(
            "`RuntimeOption.set_openvino_device` will be deprecated in v1.2.0, please use `RuntimeOption.openvino_option.set_device` instead."
        )
        return self._option.set_openvino_device(name)
        self._option.openvino_option.set_device(name)

    def set_openvino_shape_info(self, shape_info):
        """Set shape information of the models' inputs, used for GPU to fix the shape
@@ -384,7 +384,7 @@ class RuntimeOption:
        logging.warning(
            "`RuntimeOption.set_openvino_shape_info` will be deprecated in v1.2.0, please use `RuntimeOption.openvino_option.set_shape_info` instead."
        )
        return self._option.set_openvino_shape_info(shape_info)
        self._option.openvino_option.set_shape_info(shape_info)

    def set_openvino_cpu_operators(self, operators):
        """While using OpenVINO backend and intel GPU, this interface specifies unsupported operators to run on CPU
@@ -395,7 +395,7 @@ class RuntimeOption:
        logging.warning(
            "`RuntimeOption.set_openvino_cpu_operators` will be deprecated in v1.2.0, please use `RuntimeOption.openvino_option.set_cpu_operators` instead."
        )
        return self._option.set_openvino_cpu_operators(operators)
        self._option.openvino_option.set_cpu_operators(operators)

    def enable_paddle_log_info(self):
        """Enable print out the debug log information while using Paddle Inference backend, the log information is disabled by default.
@@ -415,17 +415,26 @@ class RuntimeOption:
    def enable_lite_fp16(self):
        """Enable half precision inference while using Paddle Lite backend on ARM CPU, fp16 is disabled by default.
        """
        return self._option.enable_lite_fp16()
        logging.warning(
            "`RuntimeOption.enable_lite_fp16` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.enable_fp16 = True` instead."
        )
        self._option.paddle_lite_option.enable_fp16 = True

    def disable_lite_fp16(self):
        """Disable half precision inference while using Paddle Lite backend on ARM CPU, fp16 is disabled by default.
        """
        return self._option.disable_lite_fp16()
        logging.warning(
            "`RuntimeOption.disable_lite_fp16` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.enable_fp16 = False` instead."
        )
        self._option.paddle_lite_option.enable_fp16 = False

    def set_lite_power_mode(self, mode):
        """Set POWER mode while using Paddle Lite backend on ARM CPU.
        """
        return self._option.set_lite_power_mode(mode)
        logging.warning(
            "`RuntimeOption.set_lite_powermode` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.power_mode = {}` instead.".
            format(mode))
        self._option.paddle_lite_option.power_mode = mode

    def set_trt_input_shape(self,
                            tensor_name,
@@ -439,12 +448,15 @@ class RuntimeOption:
        :param opt_shape: (list of int)Optimize shape of the input, this is often set as the most common input shape, if set to None, it will keep same with min_shape
        :param max_shape: (list of int)Maximum shape of the input, e.g [8, 3, 224, 224], if set to None, it will keep same with the min_shape
        """
        logging.warning(
            "`RuntimeOption.set_trt_input_shape` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.set_shape()` instead."
        )
        if opt_shape is None and max_shape is None:
            opt_shape = min_shape
            max_shape = min_shape
        else:
            assert opt_shape is not None and max_shape is not None, "Set min_shape only, or set min_shape, opt_shape, max_shape both."
        return self._option.set_trt_input_shape(tensor_name, min_shape,
        return self._option.trt_option.set_shape(tensor_name, min_shape,
                                                 opt_shape, max_shape)

    def set_trt_cache_file(self, cache_file_path):
@@ -452,17 +464,26 @@ class RuntimeOption:

        :param cache_file_path: (str)Path of tensorrt cache file
        """
        return self._option.set_trt_cache_file(cache_file_path)
        logging.warning(
            "`RuntimeOption.set_trt_cache_file` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.serialize_file = {}` instead.".
            format(cache_file_path))
        self._option.trt_option.serialize_file = cache_file_path

    def enable_trt_fp16(self):
        """Enable half precision inference while using TensorRT backend, notice that not all the Nvidia GPU support FP16, in those cases, will fallback to FP32 inference.
        """
        return self._option.enable_trt_fp16()
        logging.warning(
            "`RuntimeOption.enable_trt_fp16` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.enable_fp16 = True` instead."
        )
        self._option.trt_option.enable_fp16 = True

    def disable_trt_fp16(self):
        """Disable half precision inference while using TensorRT backend.
        """
        return self._option.disable_trt_fp16()
        logging.warning(
            "`RuntimeOption.disable_trt_fp16` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.enable_fp16 = False` instead."
        )
        self._option.trt_option.enable_fp16 = False

    def enable_pinned_memory(self):
        """Enable pinned memory. Pinned memory can be utilized to speed up the data transfer between CPU and GPU. Currently it's only supported in TRT backend and Paddle Inference backend.
@@ -482,12 +503,18 @@ class RuntimeOption:
    def set_trt_max_workspace_size(self, trt_max_workspace_size):
        """Set max workspace size while using TensorRT backend.
        """
        return self._option.set_trt_max_workspace_size(trt_max_workspace_size)
        logging.warning(
            "`RuntimeOption.set_trt_max_workspace_size` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.max_workspace_size = {}` instead.".
            format(trt_max_workspace_size))
        self._option.trt_option.max_workspace_size = trt_max_workspace_size

    def set_trt_max_batch_size(self, trt_max_batch_size):
        """Set max batch size while using TensorRT backend.
        """
        return self._option.set_trt_max_batch_size(trt_max_batch_size)
        logging.warning(
            "`RuntimeOption.set_trt_max_batch_size` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.max_batch_size = {}` instead.".
            format(trt_max_batch_size))
        self._option.trt_option.max_batch_size = trt_max_batch_size

    def enable_paddle_trt_collect_shape(self):
        """Enable collect subgraph shape information while using Paddle Inference with TensorRT
@@ -558,6 +585,14 @@ class RuntimeOption:
        """
        return self._option.ort_option

    @property
    def trt_option(self):
        """Get TrtBackendOption object to configure TensorRT backend

        :return TrtBackendOption
        """
        return self._option.trt_option

    def enable_profiling(self, inclue_h2d_d2h=False, repeat=100, warmup=50):
        """Set the profile mode as 'true'.
        :param inclue_h2d_d2h Whether to include time of H2D_D2H for time of runtime.
@@ -162,7 +162,8 @@ optimization {
  gpu_execution_accelerator : [
    {
      name : "tensorrt"
      # Use FP16 inference in TensorRT. You can also choose: trt_fp32, trt_int8
      # Use FP16 inference in TensorRT. You can also choose: trt_fp32
      # If the loaded model is a quantized model, this precision will be int8 automatically
      parameters { key: "precision" value: "trt_fp16" }
    }
  ]
@@ -162,7 +162,8 @@ optimization {
  gpu_execution_accelerator : [
    {
      name : "tensorrt"
      # Use TensorRT FP16 inference; the other options are trt_fp32 and trt_int8
      # Use TensorRT FP16 inference; the other option is trt_fp32
      # If a quantized model is loaded, this precision setting is ignored and int8 inference is used by default
      parameters { key: "precision" value: "trt_fp16" }
    }
  ]
@@ -168,7 +168,10 @@ TRITONSERVER_Error* ModelState::Create(TRITONBACKEND_Model* triton_model,
}

ModelState::ModelState(TRITONBACKEND_Model* triton_model)
    : BackendModel(triton_model), model_load_(false), main_runtime_(nullptr), is_clone_(true) {
    : BackendModel(triton_model),
      model_load_(false),
      main_runtime_(nullptr),
      is_clone_(true) {
  // Create runtime options that will be cloned and used for each
  // instance when creating that instance's runtime.
  runtime_options_.reset(new fastdeploy::RuntimeOption());
@@ -232,7 +235,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
          int num_streams;
          THROW_IF_BACKEND_MODEL_ERROR(
              ParseIntValue(value_string, &num_streams));
          runtime_options_->SetOpenVINOStreams(num_streams);
          runtime_options_->openvino_option.num_streams = num_streams;
        } else if (param_key == "is_clone") {
          THROW_IF_BACKEND_MODEL_ERROR(
              ParseBoolValue(value_string, &is_clone_));
@@ -271,11 +274,11 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
            std::vector<int32_t> shape;
            FDParseShape(params, input_name, &shape);
            if (name == "min_shape") {
              runtime_options_->trt_min_shape[input_name] = shape;
              runtime_options_->trt_option.min_shape[input_name] = shape;
            } else if (name == "max_shape") {
              runtime_options_->trt_max_shape[input_name] = shape;
              runtime_options_->trt_option.max_shape[input_name] = shape;
            } else {
              runtime_options_->trt_opt_shape[input_name] = shape;
              runtime_options_->trt_option.opt_shape[input_name] = shape;
            }
          }
        }
@@ -292,12 +295,10 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
          std::transform(value_string.begin(), value_string.end(),
                         value_string.begin(), ::tolower);
          if (value_string == "trt_fp16") {
            runtime_options_->EnableTrtFP16();
          } else if (value_string == "trt_int8") {
            // TODO(liqi): use EnableTrtINT8
            runtime_options_->trt_enable_int8 = true;
            runtime_options_->trt_option.enable_fp16 = true;
          } else if (value_string == "pd_fp16") {
            // TODO(liqi): paddle inference don't currently have interface for fp16.
            // TODO(liqi): paddle inference don't currently have interface
            // for fp16.
          }
          // } else if( param_key == "max_batch_size") {
          //   THROW_IF_BACKEND_MODEL_ERROR(ParseUnsignedLongLongValue(
@@ -307,7 +308,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
          //       value_string,
          //       &runtime_options_->trt_max_workspace_size));
        } else if (param_key == "cache_file") {
          runtime_options_->SetTrtCacheFile(value_string);
          runtime_options_->trt_option.serialize_file = value_string;
        } else if (param_key == "use_paddle") {
          runtime_options_->EnablePaddleToTrt();
        } else if (param_key == "use_paddle_log") {
@@ -330,7 +331,6 @@ TRITONSERVER_Error* ModelState::LoadModel(
    const int32_t instance_group_device_id, std::string* model_path,
    std::string* params_path, fastdeploy::Runtime** runtime,
    cudaStream_t stream) {

  // FastDeploy Runtime creation is not thread-safe, so multiple creations
  // are serialized with a global lock.
  // The Clone interface can be invoked only when the main_runtime_ is created.
@@ -339,7 +339,8 @@ TRITONSERVER_Error* ModelState::LoadModel(

  if (model_load_ && is_clone_) {
    if (main_runtime_ == nullptr) {
      return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_NOT_FOUND,
      return TRITONSERVER_ErrorNew(
          TRITONSERVER_ERROR_NOT_FOUND,
          std::string("main_runtime is nullptr").c_str());
    }
    *runtime = main_runtime_->Clone((void*)stream, instance_group_device_id);
@@ -367,16 +368,16 @@ TRITONSERVER_Error* ModelState::LoadModel(
      if (not exists) {
        return TRITONSERVER_ErrorNew(
            TRITONSERVER_ERROR_NOT_FOUND,
            std::string("Paddle params should be named as 'model.pdiparams' or "
            std::string(
                "Paddle params should be named as 'model.pdiparams' or "
                "not provided.'")
                .c_str());
      }
      runtime_options_->model_format = fastdeploy::ModelFormat::PADDLE;
      runtime_options_->model_file = *model_path;
      runtime_options_->params_file = *params_path;
      runtime_options_->SetModelPath(*model_path, *params_path,
                                     fastdeploy::ModelFormat::PADDLE);
    } else {
      runtime_options_->model_format = fastdeploy::ModelFormat::ONNX;
      runtime_options_->model_file = *model_path;
      runtime_options_->SetModelPath(*model_path, "",
                                     fastdeploy::ModelFormat::ONNX);
    }
  }
@@ -1074,8 +1075,7 @@ TRITONSERVER_Error* ModelInstanceState::SetInputTensors(
                           {TRITONSERVER_MEMORY_CPU, 0}};
    }

    RETURN_IF_ERROR(
        collector->ProcessTensor(
    RETURN_IF_ERROR(collector->ProcessTensor(
        input_name, nullptr, 0, allowed_input_types, &input_buffer,
        &batchn_byte_size, &memory_type, &memory_type_id));

@@ -1089,9 +1089,9 @@ TRITONSERVER_Error* ModelInstanceState::SetInputTensors(
    }

    fastdeploy::FDTensor fdtensor(in_name);
    fdtensor.SetExternalData(
        batchn_shape, ConvertDataTypeToFD(input_datatype),
        const_cast<char*>(input_buffer), device, device_id);
    fdtensor.SetExternalData(batchn_shape, ConvertDataTypeToFD(input_datatype),
                             const_cast<char*>(input_buffer), device,
                             device_id);
    runtime_->BindInputTensor(in_name, fdtensor);
  }

@@ -1130,8 +1130,7 @@ TRITONSERVER_Error* ModelInstanceState::ReadOutputTensors(
  for (auto& output_name : output_names_) {
    auto* output_tensor = runtime_->GetOutputTensor(output_name);
    if (output_tensor == nullptr) {
      RETURN_IF_ERROR(
          TRITONSERVER_ErrorNew(
      RETURN_IF_ERROR(TRITONSERVER_ErrorNew(
          TRITONSERVER_ERROR_INTERNAL,
          (std::string("output tensor '") + output_name + "' is not found")
              .c_str()));
@@ -1145,8 +1144,8 @@ TRITONSERVER_Error* ModelInstanceState::ReadOutputTensors(
    responder.ProcessTensor(
        output_tensor->name, ConvertFDType(output_tensor->dtype),
        output_tensor->shape,
        reinterpret_cast<char*>(output_tensor->MutableData()),
        memory_type, memory_type_id);
        reinterpret_cast<char*>(output_tensor->MutableData()), memory_type,
        memory_type_id);
  }

  // Finalize and wait for any pending buffer copies.