[Other] Deprecate some option api and parameters (#1243)

* Optimize Poros backend

* fix error

* Add more pybind

* fix conflicts

* add some deprecate notices

* [Other] Deprecate some apis in RuntimeOption (#1240)

* Deprecate more options

* modify serving

* Update option.h

* fix tensorrt error

* Update option_pybind.cc

* Update option_pybind.cc

* Fix error in serving

* fix word spell error
Author: Jason
Date: 2023-02-07 17:57:46 +08:00
Committed by: GitHub
Commit: 713afe7f1c (parent a18cc0f94c)
15 changed files with 380 additions and 229 deletions

View File

@@ -48,6 +48,8 @@ enum LitePowerMode {
   LITE_POWER_RAND_LOW = 5  ///< Use Lite Backend with rand low power mode
 };

+/*! @brief Option object to configure Paddle Lite backend
+ */
 struct LiteBackendOption {
   /// Paddle Lite power mode for mobile device.
   int power_mode = 3;
@@ -55,12 +57,20 @@ struct LiteBackendOption {
   int cpu_threads = 1;
   /// Enable use half precision
   bool enable_fp16 = false;
-  /// Enable use int8 precision for quantized model
-  bool enable_int8 = false;
+  /// Inference device, Paddle Lite support CPU/KUNLUNXIN/TIMVX/ASCEND
   Device device = Device::CPU;
+  /// Index of inference device
+  int device_id = 0;
+  int kunlunxin_l3_workspace_size = 0xfffc00;
+  bool kunlunxin_locked = false;
+  bool kunlunxin_autotune = true;
+  std::string kunlunxin_autotune_file = "";
+  std::string kunlunxin_precision = "int16";
+  bool kunlunxin_adaptive_seqlen = false;
+  bool kunlunxin_enable_multi_stream = false;
-  // optimized model dir for CxxConfig
+  /// Optimized model dir for CxxConfig
   std::string optimized_model_dir = "";
   std::string nnadapter_subgraph_partition_config_path = "";
   std::string nnadapter_subgraph_partition_config_buffer = "";
@@ -70,13 +80,5 @@ struct LiteBackendOption {
   std::map<std::string, std::vector<std::vector<int64_t>>>
       nnadapter_dynamic_shape_info = {{"", {{0}}}};
   std::vector<std::string> nnadapter_device_names = {};
-  int device_id = 0;
-  int kunlunxin_l3_workspace_size = 0xfffc00;
-  bool kunlunxin_locked = false;
-  bool kunlunxin_autotune = true;
-  std::string kunlunxin_autotune_file = "";
-  std::string kunlunxin_precision = "int16";
-  bool kunlunxin_adaptive_seqlen = false;
-  bool kunlunxin_enable_multi_stream = false;
 };
 }  // namespace fastdeploy
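A minimal usage sketch, not part of this diff: it sets a few of the LiteBackendOption fields shown above through RuntimeOption::paddle_lite_option. The field and method names come from this commit; the surrounding function and values are illustrative only.

    #include "fastdeploy/runtime.h"

    // Sketch only: configure the Paddle Lite backend via its option struct,
    // assuming a FastDeploy build with the Lite backend enabled.
    void ConfigureLite(fastdeploy::RuntimeOption& option) {
      option.UseLiteBackend();
      option.paddle_lite_option.cpu_threads = 4;
      option.paddle_lite_option.enable_fp16 = true;
      // KunlunXin-related knobs now live directly on LiteBackendOption.
      option.paddle_lite_option.kunlunxin_l3_workspace_size = 0xfffc00;
      option.paddle_lite_option.kunlunxin_precision = "int16";
    }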

View File

@@ -23,7 +23,6 @@ void BindLiteOption(pybind11::module& m) {
.def_readwrite("power_mode", &LiteBackendOption::power_mode) .def_readwrite("power_mode", &LiteBackendOption::power_mode)
.def_readwrite("cpu_threads", &LiteBackendOption::cpu_threads) .def_readwrite("cpu_threads", &LiteBackendOption::cpu_threads)
.def_readwrite("enable_fp16", &LiteBackendOption::enable_fp16) .def_readwrite("enable_fp16", &LiteBackendOption::enable_fp16)
.def_readwrite("enable_int8", &LiteBackendOption::enable_int8)
.def_readwrite("device", &LiteBackendOption::device) .def_readwrite("device", &LiteBackendOption::device)
.def_readwrite("optimized_model_dir", .def_readwrite("optimized_model_dir",
&LiteBackendOption::optimized_model_dir) &LiteBackendOption::optimized_model_dir)

View File

@@ -23,9 +23,13 @@
 #include <set>

 namespace fastdeploy {

+/*! @brief Option object to configure OpenVINO backend
+ */
 struct OpenVINOBackendOption {
   std::string device = "CPU";
   int cpu_thread_num = -1;
+  /// Number of streams while use OpenVINO
   int num_streams = 0;
   /**

View File

@@ -22,20 +22,30 @@
 #include <map>

 namespace fastdeploy {

+/*! @brief Option object to configure ONNX Runtime backend
+ */
 struct OrtBackendOption {
-  // -1 means default
-  // 0: ORT_DISABLE_ALL
-  // 1: ORT_ENABLE_BASIC
-  // 2: ORT_ENABLE_EXTENDED
-  // 99: ORT_ENABLE_ALL (enable some custom optimizations e.g bert)
+  /*
+   * @brief Level of graph optimization, -1: mean default(Enable all the optimization strategy)/0: disable all the optimization strategy/1: enable basic strategy/2: enable extend strategy/99: enable all
+   */
   int graph_optimization_level = -1;
+  /*
+   * @brief Number of threads to execute the operator, -1: default
+   */
   int intra_op_num_threads = -1;
+  /*
+   * @brief Number of threads to execute the graph, -1: default. This parameter only will bring effects while the `OrtBackendOption::execution_mode` set to 1.
+   */
   int inter_op_num_threads = -1;
-  // 0: ORT_SEQUENTIAL
-  // 1: ORT_PARALLEL
+  /*
+   * @brief Execution mode for the graph, -1: default(Sequential mode)/0: Sequential mode, execute the operators in graph one by one. /1: Parallel mode, execute the operators in graph parallelly.
+   */
   int execution_mode = -1;
+  /// Inference device, OrtBackend supports CPU/GPU
   Device device = Device::CPU;
+  /// Inference device id
   int device_id = 0;
   void* external_stream_ = nullptr;
 };
 }  // namespace fastdeploy
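A minimal usage sketch, not part of this diff: it assigns the OrtBackendOption members documented above directly, which is the replacement for the deprecated RuntimeOption::SetOrtGraphOptLevel. The helper function and chosen values are assumptions for illustration.

    #include "fastdeploy/runtime.h"

    // Sketch only: set ONNX Runtime options through RuntimeOption::ort_option.
    void ConfigureOrt(fastdeploy::RuntimeOption& option) {
      option.UseOrtBackend();
      option.ort_option.graph_optimization_level = 99;  // ORT_ENABLE_ALL
      option.ort_option.intra_op_num_threads = 8;
      option.ort_option.execution_mode = 1;        // parallel graph execution
      option.ort_option.inter_op_num_threads = 2;  // only effective in mode 1
    }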

View File

@@ -22,6 +22,8 @@
 namespace fastdeploy {

+/*! @brief Option object to configure Poros backend
+ */
 struct PorosBackendOption {
   Device device = Device::CPU;
   int device_id = 0;

View File

@@ -21,23 +21,64 @@
 namespace fastdeploy {

+/*! @brief Option object to configure TensorRT backend
+ */
 struct TrtBackendOption {
-  std::string model_file = "";   // Path of model file
-  std::string params_file = "";  // Path of parameters file, can be empty
-  // format of input model
-  ModelFormat model_format = ModelFormat::AUTOREC;
-  int gpu_id = 0;
-  bool enable_fp16 = false;
-  bool enable_int8 = false;
+  /// `max_batch_size`, it's deprecated in TensorRT 8.x
   size_t max_batch_size = 32;
+  /// `max_workspace_size` for TensorRT
   size_t max_workspace_size = 1 << 30;
+  /*
+   * @brief Enable half precison inference, on some device not support half precision, it will fallback to float32 mode
+   */
+  bool enable_fp16 = false;
+  /** \brief Set shape range of input tensor for the model that contain dynamic input shape while using TensorRT backend
+   *
+   * \param[in] tensor_name The name of input for the model which is dynamic shape
+   * \param[in] min The minimal shape for the input tensor
+   * \param[in] opt The optimized shape for the input tensor, just set the most common shape, if set as default value, it will keep same with min_shape
+   * \param[in] max The maximum shape for the input tensor, if set as default value, it will keep same with min_shape
+   */
+  void SetShape(const std::string& tensor_name,
+                const std::vector<int32_t>& min,
+                const std::vector<int32_t>& opt,
+                const std::vector<int32_t>& max) {
+    min_shape[tensor_name].clear();
+    max_shape[tensor_name].clear();
+    opt_shape[tensor_name].clear();
+    min_shape[tensor_name].assign(min.begin(), min.end());
+    if (opt.size() == 0) {
+      opt_shape[tensor_name].assign(min.begin(), min.end());
+    } else {
+      opt_shape[tensor_name].assign(opt.begin(), opt.end());
+    }
+    if (max.size() == 0) {
+      max_shape[tensor_name].assign(min.begin(), min.end());
+    } else {
+      max_shape[tensor_name].assign(max.begin(), max.end());
+    }
+  }
+  /**
+   * @brief Set cache file path while use TensorRT backend. Loadding a Paddle/ONNX model and initialize TensorRT will take a long time, by this interface it will save the tensorrt engine to `cache_file_path`, and load it directly while execute the code again
+   */
+  std::string serialize_file = "";
+  // The below parameters may be removed in next version, please do not
+  // visit or use them directly
   std::map<std::string, std::vector<int32_t>> max_shape;
   std::map<std::string, std::vector<int32_t>> min_shape;
   std::map<std::string, std::vector<int32_t>> opt_shape;
-  std::string serialize_file = "";
   bool enable_pinned_memory = false;
   void* external_stream_ = nullptr;
+  int gpu_id = 0;
+  std::string model_file = "";   // Path of model file
+  std::string params_file = "";  // Path of parameters file, can be empty
+  // format of input model
+  ModelFormat model_format = ModelFormat::AUTOREC;
 };
 }  // namespace fastdeploy
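A minimal usage sketch, not part of this diff: it exercises the new TrtBackendOption::SetShape and serialize_file members from RuntimeOption::trt_option. The input name "images", the shapes, and the cache path are illustrative assumptions.

    #include "fastdeploy/runtime.h"

    // Sketch only: dynamic-shape ranges and the engine cache go through
    // RuntimeOption::trt_option instead of the deprecated SetTrt* helpers.
    void ConfigureTrt(fastdeploy::RuntimeOption& option) {
      option.UseGpu(0);
      option.UseTrtBackend();
      option.trt_option.enable_fp16 = true;
      option.trt_option.max_workspace_size = 1 << 30;
      // min / opt / max shapes for a dynamic input; empty opt/max fall back to min.
      option.trt_option.SetShape("images", {1, 3, 224, 224}, {4, 3, 224, 224},
                                 {8, 3, 224, 224});
      option.trt_option.serialize_file = "./trt_engine.cache";
    }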

View File

@@ -0,0 +1,31 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
#include "fastdeploy/runtime/backends/tensorrt/option.h"
namespace fastdeploy {
void BindTrtOption(pybind11::module& m) {
pybind11::class_<TrtBackendOption>(m, "TrtBackendOption")
.def(pybind11::init())
.def_readwrite("enable_fp16", &TrtBackendOption::enable_fp16)
.def_readwrite("max_batch_size", &TrtBackendOption::max_batch_size)
.def_readwrite("max_workspace_size",
&TrtBackendOption::max_workspace_size)
.def_readwrite("serialize_file", &TrtBackendOption::serialize_file)
.def("set_shape", &TrtBackendOption::SetShape);
}
} // namespace fastdeploy

View File

@@ -19,12 +19,14 @@ namespace fastdeploy {
 void BindLiteOption(pybind11::module& m);
 void BindOpenVINOOption(pybind11::module& m);
 void BindOrtOption(pybind11::module& m);
+void BindTrtOption(pybind11::module& m);
 void BindPorosOption(pybind11::module& m);

 void BindOption(pybind11::module& m) {
   BindLiteOption(m);
   BindOpenVINOOption(m);
   BindOrtOption(m);
+  BindTrtOption(m);
   BindPorosOption(m);

   pybind11::class_<RuntimeOption>(m, "RuntimeOption")
@@ -40,47 +42,22 @@ void BindOption(pybind11::module& m) {
.def_readwrite("paddle_lite_option", &RuntimeOption::paddle_lite_option) .def_readwrite("paddle_lite_option", &RuntimeOption::paddle_lite_option)
.def_readwrite("openvino_option", &RuntimeOption::openvino_option) .def_readwrite("openvino_option", &RuntimeOption::openvino_option)
.def_readwrite("ort_option", &RuntimeOption::ort_option) .def_readwrite("ort_option", &RuntimeOption::ort_option)
.def_readwrite("trt_option", &RuntimeOption::trt_option)
.def_readwrite("poros_option", &RuntimeOption::poros_option) .def_readwrite("poros_option", &RuntimeOption::poros_option)
.def("set_external_stream", &RuntimeOption::SetExternalStream) .def("set_external_stream", &RuntimeOption::SetExternalStream)
.def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum) .def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum)
.def("use_paddle_backend", &RuntimeOption::UsePaddleBackend) .def("use_paddle_backend", &RuntimeOption::UsePaddleBackend)
.def("use_poros_backend", &RuntimeOption::UsePorosBackend) .def("use_poros_backend", &RuntimeOption::UsePorosBackend)
.def("use_ort_backend", &RuntimeOption::UseOrtBackend) .def("use_ort_backend", &RuntimeOption::UseOrtBackend)
.def("set_ort_graph_opt_level", &RuntimeOption::SetOrtGraphOptLevel)
.def("use_trt_backend", &RuntimeOption::UseTrtBackend) .def("use_trt_backend", &RuntimeOption::UseTrtBackend)
.def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend) .def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend)
.def("use_lite_backend", &RuntimeOption::UseLiteBackend) .def("use_lite_backend", &RuntimeOption::UseLiteBackend)
.def("set_lite_device_names", &RuntimeOption::SetLiteDeviceNames)
.def("set_lite_context_properties",
&RuntimeOption::SetLiteContextProperties)
.def("set_lite_model_cache_dir", &RuntimeOption::SetLiteModelCacheDir)
.def("set_lite_dynamic_shape_info",
&RuntimeOption::SetLiteDynamicShapeInfo)
.def("set_lite_subgraph_partition_path",
&RuntimeOption::SetLiteSubgraphPartitionPath)
.def("set_lite_mixed_precision_quantization_config_path",
&RuntimeOption::SetLiteMixedPrecisionQuantizationConfigPath)
.def("set_lite_subgraph_partition_config_buffer",
&RuntimeOption::SetLiteSubgraphPartitionConfigBuffer)
.def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN) .def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN)
.def("set_openvino_device", &RuntimeOption::SetOpenVINODevice)
.def("set_openvino_shape_info", &RuntimeOption::SetOpenVINOShapeInfo)
.def("set_openvino_cpu_operators",
&RuntimeOption::SetOpenVINOCpuOperators)
.def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo) .def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo)
.def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo) .def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo)
.def("set_paddle_mkldnn_cache_size", .def("set_paddle_mkldnn_cache_size",
&RuntimeOption::SetPaddleMKLDNNCacheSize) &RuntimeOption::SetPaddleMKLDNNCacheSize)
.def("enable_lite_fp16", &RuntimeOption::EnableLiteFP16)
.def("disable_lite_fp16", &RuntimeOption::DisableLiteFP16)
.def("set_lite_power_mode", &RuntimeOption::SetLitePowerMode)
.def("set_trt_input_shape", &RuntimeOption::SetTrtInputShape)
.def("set_trt_max_workspace_size", &RuntimeOption::SetTrtMaxWorkspaceSize)
.def("set_trt_max_batch_size", &RuntimeOption::SetTrtMaxBatchSize)
.def("enable_paddle_to_trt", &RuntimeOption::EnablePaddleToTrt) .def("enable_paddle_to_trt", &RuntimeOption::EnablePaddleToTrt)
.def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16)
.def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16)
.def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile)
.def("enable_pinned_memory", &RuntimeOption::EnablePinnedMemory) .def("enable_pinned_memory", &RuntimeOption::EnablePinnedMemory)
.def("disable_pinned_memory", &RuntimeOption::DisablePinnedMemory) .def("disable_pinned_memory", &RuntimeOption::DisablePinnedMemory)
.def("enable_paddle_trt_collect_shape", .def("enable_paddle_trt_collect_shape",
@@ -103,15 +80,6 @@ void BindOption(pybind11::module& m) {
.def_readwrite("cpu_thread_num", &RuntimeOption::cpu_thread_num) .def_readwrite("cpu_thread_num", &RuntimeOption::cpu_thread_num)
.def_readwrite("device_id", &RuntimeOption::device_id) .def_readwrite("device_id", &RuntimeOption::device_id)
.def_readwrite("device", &RuntimeOption::device) .def_readwrite("device", &RuntimeOption::device)
.def_readwrite("trt_max_shape", &RuntimeOption::trt_max_shape)
.def_readwrite("trt_opt_shape", &RuntimeOption::trt_opt_shape)
.def_readwrite("trt_min_shape", &RuntimeOption::trt_min_shape)
.def_readwrite("trt_serialize_file", &RuntimeOption::trt_serialize_file)
.def_readwrite("trt_enable_fp16", &RuntimeOption::trt_enable_fp16)
.def_readwrite("trt_enable_int8", &RuntimeOption::trt_enable_int8)
.def_readwrite("trt_max_batch_size", &RuntimeOption::trt_max_batch_size)
.def_readwrite("trt_max_workspace_size",
&RuntimeOption::trt_max_workspace_size)
.def_readwrite("ipu_device_num", &RuntimeOption::ipu_device_num) .def_readwrite("ipu_device_num", &RuntimeOption::ipu_device_num)
.def_readwrite("ipu_micro_batch_size", .def_readwrite("ipu_micro_batch_size",
&RuntimeOption::ipu_micro_batch_size) &RuntimeOption::ipu_micro_batch_size)

View File

@@ -244,17 +244,9 @@ void Runtime::CreatePaddleBackend() {
   if (pd_option.use_gpu && option.pd_enable_trt) {
     pd_option.enable_trt = true;
     pd_option.collect_shape = option.pd_collect_shape;
-    auto trt_option = TrtBackendOption();
-    trt_option.gpu_id = option.device_id;
-    trt_option.enable_fp16 = option.trt_enable_fp16;
-    trt_option.max_batch_size = option.trt_max_batch_size;
-    trt_option.max_workspace_size = option.trt_max_workspace_size;
-    trt_option.max_shape = option.trt_max_shape;
-    trt_option.min_shape = option.trt_min_shape;
-    trt_option.opt_shape = option.trt_opt_shape;
-    trt_option.serialize_file = option.trt_serialize_file;
-    trt_option.enable_pinned_memory = option.enable_pinned_memory;
-    pd_option.trt_option = trt_option;
+    pd_option.trt_option = option.trt_option;
+    pd_option.trt_option.gpu_id = option.device_id;
+    pd_option.trt_option.enable_pinned_memory = option.enable_pinned_memory;
     pd_option.trt_disabled_ops_ = option.trt_disabled_ops_;
   }
 #endif
@@ -339,41 +331,33 @@ void Runtime::CreateTrtBackend() {
"TrtBackend only support model format of ModelFormat::PADDLE / " "TrtBackend only support model format of ModelFormat::PADDLE / "
"ModelFormat::ONNX."); "ModelFormat::ONNX.");
#ifdef ENABLE_TRT_BACKEND #ifdef ENABLE_TRT_BACKEND
auto trt_option = TrtBackendOption(); option.trt_option.model_file = option.model_file;
trt_option.model_file = option.model_file; option.trt_option.params_file = option.params_file;
trt_option.params_file = option.params_file; option.trt_option.model_format = option.model_format;
trt_option.model_format = option.model_format; option.trt_option.gpu_id = option.device_id;
trt_option.gpu_id = option.device_id; option.trt_option.enable_pinned_memory = option.enable_pinned_memory;
trt_option.enable_fp16 = option.trt_enable_fp16; option.trt_option.external_stream_ = option.external_stream_;
trt_option.enable_int8 = option.trt_enable_int8;
trt_option.max_batch_size = option.trt_max_batch_size;
trt_option.max_workspace_size = option.trt_max_workspace_size;
trt_option.max_shape = option.trt_max_shape;
trt_option.min_shape = option.trt_min_shape;
trt_option.opt_shape = option.trt_opt_shape;
trt_option.serialize_file = option.trt_serialize_file;
trt_option.enable_pinned_memory = option.enable_pinned_memory;
trt_option.external_stream_ = option.external_stream_;
backend_ = utils::make_unique<TrtBackend>(); backend_ = utils::make_unique<TrtBackend>();
auto casted_backend = dynamic_cast<TrtBackend*>(backend_.get()); auto casted_backend = dynamic_cast<TrtBackend*>(backend_.get());
casted_backend->benchmark_option_ = option.benchmark_option; casted_backend->benchmark_option_ = option.benchmark_option;
if (option.model_format == ModelFormat::ONNX) { if (option.model_format == ModelFormat::ONNX) {
if (option.model_from_memory_) { if (option.model_from_memory_) {
FDASSERT(casted_backend->InitFromOnnx(option.model_file, trt_option), FDASSERT(
casted_backend->InitFromOnnx(option.model_file, option.trt_option),
"Load model from ONNX failed while initliazing TrtBackend."); "Load model from ONNX failed while initliazing TrtBackend.");
ReleaseModelMemoryBuffer(); ReleaseModelMemoryBuffer();
} else { } else {
std::string model_buffer = ""; std::string model_buffer = "";
FDASSERT(ReadBinaryFromFile(option.model_file, &model_buffer), FDASSERT(ReadBinaryFromFile(option.model_file, &model_buffer),
"Fail to read binary from model file"); "Fail to read binary from model file");
FDASSERT(casted_backend->InitFromOnnx(model_buffer, trt_option), FDASSERT(casted_backend->InitFromOnnx(model_buffer, option.trt_option),
"Load model from ONNX failed while initliazing TrtBackend."); "Load model from ONNX failed while initliazing TrtBackend.");
} }
} else { } else {
if (option.model_from_memory_) { if (option.model_from_memory_) {
FDASSERT(casted_backend->InitFromPaddle(option.model_file, FDASSERT(casted_backend->InitFromPaddle(
option.params_file, trt_option), option.model_file, option.params_file, option.trt_option),
"Load model from Paddle failed while initliazing TrtBackend."); "Load model from Paddle failed while initliazing TrtBackend.");
ReleaseModelMemoryBuffer(); ReleaseModelMemoryBuffer();
} else { } else {
@@ -384,7 +368,7 @@ void Runtime::CreateTrtBackend() {
       FDASSERT(ReadBinaryFromFile(option.params_file, &params_buffer),
                "Fail to read binary from parameter file");
       FDASSERT(casted_backend->InitFromPaddle(model_buffer, params_buffer,
-                                              trt_option),
+                                              option.trt_option),
                "Load model from Paddle failed while initliazing TrtBackend.");
     }
   }
@@ -505,9 +489,10 @@ bool Runtime::Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors,
   }
   option.poros_option.device = option.device;
   option.poros_option.device_id = option.device_id;
-  option.poros_option.enable_fp16 = option.trt_enable_fp16;
-  option.poros_option.max_batch_size = option.trt_max_batch_size;
-  option.poros_option.max_workspace_size = option.trt_max_workspace_size;
+  option.poros_option.enable_fp16 = option.trt_option.enable_fp16;
+  option.poros_option.max_batch_size = option.trt_option.max_batch_size;
+  option.poros_option.max_workspace_size =
+      option.trt_option.max_workspace_size;
   backend_ = utils::make_unique<PorosBackend>();
   auto casted_backend = dynamic_cast<PorosBackend*>(backend_.get());
   FDASSERT(

View File

@@ -102,6 +102,10 @@ void RuntimeOption::SetCpuThreadNum(int thread_num) {
 }

 void RuntimeOption::SetOrtGraphOptLevel(int level) {
+  FDWARNING << "`RuntimeOption::SetOrtGraphOptLevel` will be removed in "
+               "v1.2.0, please modify its member variables directly, e.g "
+               "`runtime_option.ort_option.graph_optimization_level = 99`."
+            << std::endl;
   std::vector<int> supported_level{-1, 0, 1, 2};
   auto valid_level = std::find(supported_level.begin(), supported_level.end(),
                                level) != supported_level.end();
@@ -203,67 +207,127 @@ void RuntimeOption::SetPaddleMKLDNNCacheSize(int size) {
 }

 void RuntimeOption::SetOpenVINODevice(const std::string& name) {
-  openvino_option.device = name;
+  FDWARNING << "`RuntimeOption::SetOpenVINODevice` will be removed in v1.2.0, "
+               "please use `RuntimeOption.openvino_option.SetDeivce(const "
+               "std::string&)` instead."
+            << std::endl;
+  openvino_option.SetDevice(name);
 }

-void RuntimeOption::EnableLiteFP16() { paddle_lite_option.enable_fp16 = true; }
+void RuntimeOption::EnableLiteFP16() {
+  FDWARNING << "`RuntimeOption::EnableLiteFP16` will be removed in v1.2.0, "
+               "please modify its member variables directly, e.g "
+               "`runtime_option.paddle_lite_option.enable_fp16 = true`"
+            << std::endl;
+  paddle_lite_option.enable_fp16 = true;
+}

 void RuntimeOption::DisableLiteFP16() {
+  FDWARNING << "`RuntimeOption::EnableLiteFP16` will be removed in v1.2.0, "
+               "please modify its member variables directly, e.g "
+               "`runtime_option.paddle_lite_option.enable_fp16 = false`"
+            << std::endl;
   paddle_lite_option.enable_fp16 = false;
 }

-void RuntimeOption::EnableLiteInt8() { paddle_lite_option.enable_int8 = true; }
+void RuntimeOption::EnableLiteInt8() {
+  FDWARNING << "RuntimeOption::EnableLiteInt8 is a useless api, this calling "
+               "will not bring any effects, and will be removed in v1.2.0. if "
+               "you load a quantized model, it will automatically run with "
+               "int8 mode; otherwise it will run with float mode."
+            << std::endl;
+}

 void RuntimeOption::DisableLiteInt8() {
-  paddle_lite_option.enable_int8 = false;
+  FDWARNING << "RuntimeOption::DisableLiteInt8 is a useless api, this calling "
+               "will not bring any effects, and will be removed in v1.2.0. if "
+               "you load a quantized model, it will automatically run with "
+               "int8 mode; otherwise it will run with float mode."
+            << std::endl;
 }

 void RuntimeOption::SetLitePowerMode(LitePowerMode mode) {
+  FDWARNING << "`RuntimeOption::SetLitePowerMode` will be removed in v1.2.0, "
+               "please modify its member variable directly, e.g "
+               "`runtime_option.paddle_lite_option.power_mode = 3;`"
+            << std::endl;
   paddle_lite_option.power_mode = mode;
 }

 void RuntimeOption::SetLiteOptimizedModelDir(
     const std::string& optimized_model_dir) {
+  FDWARNING
+      << "`RuntimeOption::SetLiteOptimizedModelDir` will be removed in v1.2.0, "
+         "please modify its member variable directly, e.g "
+         "`runtime_option.paddle_lite_option.optimized_model_dir = \"...\"`"
+      << std::endl;
   paddle_lite_option.optimized_model_dir = optimized_model_dir;
 }

 void RuntimeOption::SetLiteSubgraphPartitionPath(
     const std::string& nnadapter_subgraph_partition_config_path) {
+  FDWARNING << "`RuntimeOption::SetLiteSubgraphPartitionPath` will be removed "
+               "in v1.2.0, please modify its member variable directly, e.g "
+               "`runtime_option.paddle_lite_option.nnadapter_subgraph_"
+               "partition_config_path = \"...\";` "
+            << std::endl;
   paddle_lite_option.nnadapter_subgraph_partition_config_path =
       nnadapter_subgraph_partition_config_path;
 }

 void RuntimeOption::SetLiteSubgraphPartitionConfigBuffer(
     const std::string& nnadapter_subgraph_partition_config_buffer) {
+  FDWARNING
+      << "`RuntimeOption::SetLiteSubgraphPartitionConfigBuffer` will be "
+         "removed in v1.2.0, please modify its member variable directly, e.g "
+         "`runtime_option.paddle_lite_option.nnadapter_subgraph_partition_"
+         "config_buffer = ...`"
+      << std::endl;
   paddle_lite_option.nnadapter_subgraph_partition_config_buffer =
       nnadapter_subgraph_partition_config_buffer;
 }

-void RuntimeOption::SetLiteDeviceNames(
-    const std::vector<std::string>& nnadapter_device_names) {
-  paddle_lite_option.nnadapter_device_names = nnadapter_device_names;
-}

 void RuntimeOption::SetLiteContextProperties(
     const std::string& nnadapter_context_properties) {
+  FDWARNING << "`RuntimeOption::SetLiteContextProperties` will be removed in "
+               "v1.2.0, please modify its member variable directly, e.g "
+               "`runtime_option.paddle_lite_option.nnadapter_context_"
+               "properties = ...`"
+            << std::endl;
   paddle_lite_option.nnadapter_context_properties =
       nnadapter_context_properties;
 }

 void RuntimeOption::SetLiteModelCacheDir(
     const std::string& nnadapter_model_cache_dir) {
+  FDWARNING
+      << "`RuntimeOption::SetLiteModelCacheDir` will be removed in v1.2.0, "
+         "please modify its member variable directly, e.g "
+         "`runtime_option.paddle_lite_option.nnadapter_model_cache_dir = ...`"
+      << std::endl;
   paddle_lite_option.nnadapter_model_cache_dir = nnadapter_model_cache_dir;
 }

 void RuntimeOption::SetLiteDynamicShapeInfo(
     const std::map<std::string, std::vector<std::vector<int64_t>>>&
         nnadapter_dynamic_shape_info) {
+  FDWARNING << "`RuntimeOption::SetLiteDynamicShapeInfo` will be removed in "
+               "v1.2.0, please modify its member variable directly, e.g "
+               "`runtime_option.paddle_lite_option.paddle_lite_option."
+               "nnadapter_dynamic_shape_info = ...`"
+            << std::endl;
   paddle_lite_option.nnadapter_dynamic_shape_info =
       nnadapter_dynamic_shape_info;
 }

 void RuntimeOption::SetLiteMixedPrecisionQuantizationConfigPath(
     const std::string& nnadapter_mixed_precision_quantization_config_path) {
+  FDWARNING
+      << "`RuntimeOption::SetLiteMixedPrecisionQuantizationConfigPath` will be "
+         "removed in v1.2.0, please modify its member variable directly, e.g "
+         "`runtime_option.paddle_lite_option.paddle_lite_option.nnadapter_"
+         "mixed_precision_quantization_config_path = ...`"
+      << std::endl;
   paddle_lite_option.nnadapter_mixed_precision_quantization_config_path =
       nnadapter_mixed_precision_quantization_config_path;
 }
@@ -272,42 +336,60 @@ void RuntimeOption::SetTrtInputShape(const std::string& input_name,
                                      const std::vector<int32_t>& min_shape,
                                      const std::vector<int32_t>& opt_shape,
                                      const std::vector<int32_t>& max_shape) {
-  trt_min_shape[input_name].clear();
-  trt_max_shape[input_name].clear();
-  trt_opt_shape[input_name].clear();
-  trt_min_shape[input_name].assign(min_shape.begin(), min_shape.end());
-  if (opt_shape.size() == 0) {
-    trt_opt_shape[input_name].assign(min_shape.begin(), min_shape.end());
-  } else {
-    trt_opt_shape[input_name].assign(opt_shape.begin(), opt_shape.end());
-  }
-  if (max_shape.size() == 0) {
-    trt_max_shape[input_name].assign(min_shape.begin(), min_shape.end());
-  } else {
-    trt_max_shape[input_name].assign(max_shape.begin(), max_shape.end());
-  }
+  FDWARNING << "`RuntimeOption::SetTrtInputShape` will be removed in v1.2.0, "
+               "please use `RuntimeOption.trt_option.SetShape()` instead."
+            << std::endl;
+  trt_option.SetShape(input_name, min_shape, opt_shape, max_shape);
 }

 void RuntimeOption::SetTrtMaxWorkspaceSize(size_t max_workspace_size) {
-  trt_max_workspace_size = max_workspace_size;
+  FDWARNING << "`RuntimeOption::SetTrtMaxWorkspaceSize` will be removed in "
+               "v1.2.0, please modify its member variable directly, e.g "
+               "`RuntimeOption.trt_option.max_workspace_size = "
+            << max_workspace_size << "`." << std::endl;
+  trt_option.max_workspace_size = max_workspace_size;
 }

 void RuntimeOption::SetTrtMaxBatchSize(size_t max_batch_size) {
-  trt_max_batch_size = max_batch_size;
+  FDWARNING << "`RuntimeOption::SetTrtMaxBatchSize` will be removed in v1.2.0, "
+               "please modify its member variable directly, e.g "
+               "`RuntimeOption.trt_option.max_batch_size = "
+            << max_batch_size << "`." << std::endl;
+  trt_option.max_batch_size = max_batch_size;
 }

-void RuntimeOption::EnableTrtFP16() { trt_enable_fp16 = true; }
+void RuntimeOption::EnableTrtFP16() {
+  FDWARNING << "`RuntimeOption::EnableTrtFP16` will be removed in v1.2.0, "
+               "please modify its member variable directly, e.g "
+               "`runtime_option.trt_option.enable_fp16 = true;`"
+            << std::endl;
+  trt_option.enable_fp16 = true;
+}

-void RuntimeOption::DisableTrtFP16() { trt_enable_fp16 = false; }
+void RuntimeOption::DisableTrtFP16() {
+  FDWARNING << "`RuntimeOption::DisableTrtFP16` will be removed in v1.2.0, "
+               "please modify its member variable directly, e.g "
+               "`runtime_option.trt_option.enable_fp16 = false;`"
+            << std::endl;
+  trt_option.enable_fp16 = false;
+}

 void RuntimeOption::EnablePinnedMemory() { enable_pinned_memory = true; }

 void RuntimeOption::DisablePinnedMemory() { enable_pinned_memory = false; }

 void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) {
-  trt_serialize_file = cache_file_path;
+  FDWARNING << "`RuntimeOption::SetTrtCacheFile` will be removed in v1.2.0, "
+               "please modify its member variable directly, e.g "
+               "`runtime_option.trt_option.serialize_file = \""
+            << cache_file_path << "\"." << std::endl;
+  trt_option.serialize_file = cache_file_path;
 }

 void RuntimeOption::SetOpenVINOStreams(int num_streams) {
+  FDWARNING << "`RuntimeOption::SetOpenVINOStreams` will be removed in v1.2.0, "
+               "please modify its member variable directly, e.g "
+               "`runtime_option.openvino_option.num_streams = "
+            << num_streams << "`." << std::endl;
   openvino_option.num_streams = num_streams;
 }
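A minimal migration sketch, not part of this diff: each deprecated setter above still works but now prints an FDWARNING; the direct member assignment shown next to it is the replacement this commit points users to. The function wrapper is illustrative.

    #include "fastdeploy/runtime.h"

    // Sketch only: deprecated setters and their direct-assignment replacements.
    void Migrate(fastdeploy::RuntimeOption& option) {
      option.EnableTrtFP16();                // deprecated, removal planned for v1.2.0
      option.trt_option.enable_fp16 = true;  // replacement

      option.SetTrtCacheFile("model.trt");             // deprecated
      option.trt_option.serialize_file = "model.trt";  // replacement

      option.SetOpenVINOStreams(4);            // deprecated
      option.openvino_option.num_streams = 4;  // replacement
    }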

View File

@@ -211,12 +211,6 @@ struct FASTDEPLOY_DECL RuntimeOption {
   void SetLiteSubgraphPartitionConfigBuffer(
       const std::string& nnadapter_subgraph_partition_config_buffer);

-  /**
-   * @brief Set device name for Paddle Lite backend.
-   */
-  void
-  SetLiteDeviceNames(const std::vector<std::string>& nnadapter_device_names);

   /**
    * @brief Set context properties for Paddle Lite backend.
    */
@@ -381,6 +375,7 @@ struct FASTDEPLOY_DECL RuntimeOption {
   bool enable_pinned_memory = false;

+  /// Option to configure ONNX Runtime backend
   OrtBackendOption ort_option;

   // ======Only for Paddle Backend=====
@@ -401,20 +396,16 @@ struct FASTDEPLOY_DECL RuntimeOption {
   float ipu_available_memory_proportion = 1.0;
   bool ipu_enable_half_partial = false;

-  // ======Only for Trt Backend=======
-  std::map<std::string, std::vector<int32_t>> trt_max_shape;
-  std::map<std::string, std::vector<int32_t>> trt_min_shape;
-  std::map<std::string, std::vector<int32_t>> trt_opt_shape;
-  std::string trt_serialize_file = "";
-  bool trt_enable_fp16 = false;
-  bool trt_enable_int8 = false;
-  size_t trt_max_batch_size = 1;
-  size_t trt_max_workspace_size = 1 << 30;
+  /// Option to configure TensorRT backend
+  TrtBackendOption trt_option;

   // ======Only for PaddleTrt Backend=======
   std::vector<std::string> trt_disabled_ops_{};

+  /// Option to configure Poros backend
   PorosBackendOption poros_option;

+  /// Option to configure OpenVINO backend
   OpenVINOBackendOption openvino_option;

   // ======Only for RKNPU2 Backend=======
@@ -433,10 +424,10 @@ struct FASTDEPLOY_DECL RuntimeOption {
std::string model_file = ""; std::string model_file = "";
std::string params_file = ""; std::string params_file = "";
bool model_from_memory_ = false; bool model_from_memory_ = false;
// format of input model /// format of input model
ModelFormat model_format = ModelFormat::PADDLE; ModelFormat model_format = ModelFormat::PADDLE;
// Benchmark option /// Benchmark option
benchmark::BenchmarkOption benchmark_option; benchmark::BenchmarkOption benchmark_option;
}; };

View File

@@ -154,6 +154,8 @@ class RuntimeOption:
"""Options for FastDeploy Runtime. """Options for FastDeploy Runtime.
""" """
__slots__ = ["_option"]
def __init__(self): def __init__(self):
"""Initialize a FastDeploy RuntimeOption object. """Initialize a FastDeploy RuntimeOption object.
""" """
@@ -266,7 +268,7 @@ class RuntimeOption:
         logging.warning(
             "`RuntimeOption.set_ort_graph_opt_level` will be deprecated in v1.2.0, please use `RuntimeOption.graph_optimize_level = 99` instead."
         )
-        return self._option.set_ort_graph_opt_level(level)
+        self._option.ort_option.graph_optimize_level = level

     def use_paddle_backend(self):
         """Use Paddle Inference backend, support inference Paddle model on CPU/Nvidia GPU.
@@ -314,7 +316,7 @@ class RuntimeOption:
         logging.warning(
             "`RuntimeOption.set_lite_context_properties` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_context_properties = ...` instead."
         )
-        return self._option.set_lite_context_properties(context_properties)
+        self._option.paddle_lite_option.nnadapter_context_properties = context_properties

     def set_lite_model_cache_dir(self, model_cache_dir):
         """Set nnadapter model cache dir for Paddle Lite backend.
@@ -322,7 +324,8 @@ class RuntimeOption:
         logging.warning(
             "`RuntimeOption.set_lite_model_cache_dir` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_model_cache_dir = ...` instead."
         )
-        return self._option.set_lite_model_cache_dir(model_cache_dir)
+        self._option.paddle_lite_option.nnadapter_model_cache_dir = model_cache_dir

     def set_lite_dynamic_shape_info(self, dynamic_shape_info):
         """ Set nnadapter dynamic shape info for Paddle Lite backend.
@@ -330,7 +333,7 @@ class RuntimeOption:
         logging.warning(
             "`RuntimeOption.set_lite_dynamic_shape_info` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_dynamic_shape_info = ...` instead."
         )
-        return self._option.set_lite_dynamic_shape_info(dynamic_shape_info)
+        self._option.paddle_lite_option.nnadapter_dynamic_shape_info = dynamic_shape_info

     def set_lite_subgraph_partition_path(self, subgraph_partition_path):
         """ Set nnadapter subgraph partition path for Paddle Lite backend.
@@ -338,8 +341,7 @@ class RuntimeOption:
         logging.warning(
             "`RuntimeOption.set_lite_subgraph_partition_path` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_subgraph_partition_config_path = ...` instead."
         )
-        return self._option.set_lite_subgraph_partition_path(
-            subgraph_partition_path)
+        self._option.paddle_lite_option.nnadapter_subgraph_partition_config_path = subgraph_partition_path

     def set_lite_subgraph_partition_config_buffer(self,
                                                   subgraph_partition_buffer):
@@ -348,8 +350,7 @@ class RuntimeOption:
         logging.warning(
             "`RuntimeOption.set_lite_subgraph_partition_buffer` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_subgraph_partition_config_buffer = ...` instead."
         )
-        return self._option.set_lite_subgraph_partition_config_buffer(
-            subgraph_partition_buffer)
+        self._option.paddle_lite_option.nnadapter_subgraph_partition_config_buffer = subgraph_partition_buffer

     def set_lite_mixed_precision_quantization_config_path(
             self, mixed_precision_quantization_config_path):
@@ -358,8 +359,7 @@ class RuntimeOption:
         logging.warning(
             "`RuntimeOption.set_lite_mixed_precision_quantization_config_path` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_mixed_precision_quantization_config_path = ...` instead."
         )
-        return self._option.set_lite_mixed_precision_quantization_config_path(
-            mixed_precision_quantization_config_path)
+        self._option.paddle_lite_option.nnadapter_mixed_precision_quantization_config_path = mixed_precision_quantization_config_path

     def set_paddle_mkldnn(self, use_mkldnn=True):
         """Enable/Disable MKLDNN while using Paddle Inference backend, mkldnn is enabled by default.
@@ -373,7 +373,7 @@ class RuntimeOption:
         logging.warning(
             "`RuntimeOption.set_openvino_device` will be deprecated in v1.2.0, please use `RuntimeOption.openvino_option.set_device` instead."
         )
-        return self._option.set_openvino_device(name)
+        self._option.openvino_option.set_device(name)

     def set_openvino_shape_info(self, shape_info):
         """Set shape information of the models' inputs, used for GPU to fix the shape
@@ -384,7 +384,7 @@ class RuntimeOption:
         logging.warning(
             "`RuntimeOption.set_openvino_shape_info` will be deprecated in v1.2.0, please use `RuntimeOption.openvino_option.set_shape_info` instead."
         )
-        return self._option.set_openvino_shape_info(shape_info)
+        self._option.openvino_option.set_shape_info(shape_info)

     def set_openvino_cpu_operators(self, operators):
         """While using OpenVINO backend and intel GPU, this interface specifies unsupported operators to run on CPU
@@ -395,7 +395,7 @@ class RuntimeOption:
         logging.warning(
             "`RuntimeOption.set_openvino_cpu_operators` will be deprecated in v1.2.0, please use `RuntimeOption.openvino_option.set_cpu_operators` instead."
         )
-        return self._option.set_openvino_cpu_operators(operators)
+        self._option.openvino_option.set_cpu_operators(operators)

     def enable_paddle_log_info(self):
         """Enable print out the debug log information while using Paddle Inference backend, the log information is disabled by default.
@@ -415,17 +415,26 @@ class RuntimeOption:
     def enable_lite_fp16(self):
         """Enable half precision inference while using Paddle Lite backend on ARM CPU, fp16 is disabled by default.
         """
-        return self._option.enable_lite_fp16()
+        logging.warning(
+            "`RuntimeOption.enable_lite_fp16` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.enable_fp16 = True` instead."
+        )
+        self._option.paddle_lite_option.enable_fp16 = True

     def disable_lite_fp16(self):
         """Disable half precision inference while using Paddle Lite backend on ARM CPU, fp16 is disabled by default.
         """
-        return self._option.disable_lite_fp16()
+        logging.warning(
+            "`RuntimeOption.disable_lite_fp16` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.enable_fp16 = False` instead."
+        )
+        self._option.paddle_lite_option.enable_fp16 = False

     def set_lite_power_mode(self, mode):
         """Set POWER mode while using Paddle Lite backend on ARM CPU.
         """
-        return self._option.set_lite_power_mode(mode)
+        logging.warning(
+            "`RuntimeOption.set_lite_powermode` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.power_mode = {}` instead.".
+            format(mode))
+        self._option.paddle_lite_option.power_mode = mode

     def set_trt_input_shape(self,
                             tensor_name,
@@ -439,12 +448,15 @@ class RuntimeOption:
         :param opt_shape: (list of int)Optimize shape of the input, this offten set as the most common input shape, if set to None, it will keep same with min_shape
         :param max_shape: (list of int)Maximum shape of the input, e.g [8, 3, 224, 224], if set to None, it will keep same with the min_shape
         """
+        logging.warning(
+            "`RuntimeOption.set_trt_input_shape` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.set_shape()` instead."
+        )
         if opt_shape is None and max_shape is None:
             opt_shape = min_shape
             max_shape = min_shape
         else:
             assert opt_shape is not None and max_shape is not None, "Set min_shape only, or set min_shape, opt_shape, max_shape both."
-        return self._option.set_trt_input_shape(tensor_name, min_shape,
-                                                opt_shape, max_shape)
+        return self._option.trt_option.set_shape(tensor_name, min_shape,
+                                                 opt_shape, max_shape)

     def set_trt_cache_file(self, cache_file_path):
@@ -452,17 +464,26 @@ class RuntimeOption:
         :param cache_file_path: (str)Path of tensorrt cache file
         """
-        return self._option.set_trt_cache_file(cache_file_path)
+        logging.warning(
+            "`RuntimeOption.set_trt_cache_file` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.serialize_file = {}` instead.".
+            format(cache_file_path))
+        self._option.trt_option.serialize_file = cache_file_path

     def enable_trt_fp16(self):
         """Enable half precision inference while using TensorRT backend, notice that not all the Nvidia GPU support FP16, in those cases, will fallback to FP32 inference.
         """
-        return self._option.enable_trt_fp16()
+        logging.warning(
+            "`RuntimeOption.enable_trt_fp16` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.enable_fp16 = True` instead."
+        )
+        self._option.trt_option.enable_fp16 = True

     def disable_trt_fp16(self):
         """Disable half precision inference while suing TensorRT backend.
         """
-        return self._option.disable_trt_fp16()
+        logging.warning(
+            "`RuntimeOption.disable_trt_fp16` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.enable_fp16 = False` instead."
+        )
+        self._option.trt_option.enable_fp16 = False

     def enable_pinned_memory(self):
         """Enable pinned memory. Pinned memory can be utilized to speedup the data transfer between CPU and GPU. Currently it's only suppurted in TRT backend and Paddle Inference backend.
@@ -482,12 +503,18 @@ class RuntimeOption:
     def set_trt_max_workspace_size(self, trt_max_workspace_size):
         """Set max workspace size while using TensorRT backend.
         """
-        return self._option.set_trt_max_workspace_size(trt_max_workspace_size)
+        logging.warning(
+            "`RuntimeOption.set_trt_max_workspace_size` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.max_workspace_size = {}` instead.".
+            format(trt_max_workspace_size))
+        self._option.trt_option.max_workspace_size = trt_max_workspace_size

     def set_trt_max_batch_size(self, trt_max_batch_size):
         """Set max batch size while using TensorRT backend.
         """
-        return self._option.set_trt_max_batch_size(trt_max_batch_size)
+        logging.warning(
+            "`RuntimeOption.set_trt_max_batch_size` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.max_batch_size = {}` instead.".
+            format(trt_max_batch_size))
+        self._option.trt_option.max_batch_size = trt_max_batch_size

     def enable_paddle_trt_collect_shape(self):
         """Enable collect subgraph shape information while using Paddle Inference with TensorRT
@@ -558,6 +585,14 @@ class RuntimeOption:
""" """
return self._option.ort_option return self._option.ort_option
@property
def trt_option(self):
"""Get TrtBackendOption object to configure TensorRT backend
:return TrtBackendOption
"""
return self._option.trt_option
def enable_profiling(self, inclue_h2d_d2h=False, repeat=100, warmup=50): def enable_profiling(self, inclue_h2d_d2h=False, repeat=100, warmup=50):
"""Set the profile mode as 'true'. """Set the profile mode as 'true'.
:param inclue_h2d_d2h Whether to include time of H2D_D2H for time of runtime. :param inclue_h2d_d2h Whether to include time of H2D_D2H for time of runtime.

View File

@@ -162,7 +162,8 @@ optimization {
     gpu_execution_accelerator : [
       {
         name : "tensorrt"
-        # Use FP16 inference in TensorRT. You can also choose: trt_fp32, trt_int8
+        # Use FP16 inference in TensorRT. You can also choose: trt_fp32
+        # If the loaded model is a quantized model, this precision will be int8 automatically
         parameters { key: "precision" value: "trt_fp16" }
       }
     ]

View File

@@ -162,7 +162,8 @@ optimization {
     gpu_execution_accelerator : [
       {
         name : "tensorrt"
-        # Use FP16 inference in TensorRT. Other options: trt_fp32, trt_int8
+        # Use FP16 inference in TensorRT. Other option: trt_fp32
+        # If a quantized model is loaded, this precision setting has no effect and int8 inference is used by default
         parameters { key: "precision" value: "trt_fp16" }
       }
     ]

View File

@@ -168,7 +168,10 @@ TRITONSERVER_Error* ModelState::Create(TRITONBACKEND_Model* triton_model,
 }

 ModelState::ModelState(TRITONBACKEND_Model* triton_model)
-    : BackendModel(triton_model), model_load_(false), main_runtime_(nullptr), is_clone_(true) {
+    : BackendModel(triton_model),
+      model_load_(false),
+      main_runtime_(nullptr),
+      is_clone_(true) {
   // Create runtime options that will be cloned and used for each
   // instance when creating that instance's runtime.
   runtime_options_.reset(new fastdeploy::RuntimeOption());
@@ -232,7 +235,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
           int num_streams;
           THROW_IF_BACKEND_MODEL_ERROR(
               ParseIntValue(value_string, &num_streams));
-          runtime_options_->SetOpenVINOStreams(num_streams);
+          runtime_options_->openvino_option.num_streams = num_streams;
         } else if (param_key == "is_clone") {
           THROW_IF_BACKEND_MODEL_ERROR(
               ParseBoolValue(value_string, &is_clone_));
@@ -271,11 +274,11 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
           std::vector<int32_t> shape;
           FDParseShape(params, input_name, &shape);
           if (name == "min_shape") {
-            runtime_options_->trt_min_shape[input_name] = shape;
+            runtime_options_->trt_option.min_shape[input_name] = shape;
           } else if (name == "max_shape") {
-            runtime_options_->trt_max_shape[input_name] = shape;
+            runtime_options_->trt_option.max_shape[input_name] = shape;
           } else {
-            runtime_options_->trt_opt_shape[input_name] = shape;
+            runtime_options_->trt_option.opt_shape[input_name] = shape;
           }
         }
       }
@@ -292,12 +295,10 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
           std::transform(value_string.begin(), value_string.end(),
                          value_string.begin(), ::tolower);
           if (value_string == "trt_fp16") {
-            runtime_options_->EnableTrtFP16();
-          } else if (value_string == "trt_int8") {
-            // TODO(liqi): use EnableTrtINT8
-            runtime_options_->trt_enable_int8 = true;
+            runtime_options_->trt_option.enable_fp16 = true;
           } else if (value_string == "pd_fp16") {
-            // TODO(liqi): paddle inference don't currently have interface for fp16.
+            // TODO(liqi): paddle inference don't currently have interface
+            // for fp16.
           }
           // } else if( param_key == "max_batch_size") {
           //   THROW_IF_BACKEND_MODEL_ERROR(ParseUnsignedLongLongValue(
@@ -307,7 +308,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
           //       value_string,
           //       &runtime_options_->trt_max_workspace_size));
         } else if (param_key == "cache_file") {
-          runtime_options_->SetTrtCacheFile(value_string);
+          runtime_options_->trt_option.serialize_file = value_string;
         } else if (param_key == "use_paddle") {
           runtime_options_->EnablePaddleToTrt();
         } else if (param_key == "use_paddle_log") {
@@ -330,16 +331,16 @@ TRITONSERVER_Error* ModelState::LoadModel(
     const int32_t instance_group_device_id, std::string* model_path,
     std::string* params_path, fastdeploy::Runtime** runtime,
     cudaStream_t stream) {
   // FastDeploy Runtime creation is not thread-safe, so multiple creations
   // are serialized with a global lock.
   // The Clone interface can be invoked only when the main_runtime_ is created.
   static std::mutex global_context_mu;
   std::lock_guard<std::mutex> glock(global_context_mu);
-  if(model_load_ && is_clone_) {
-    if(main_runtime_ == nullptr) {
-      return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_NOT_FOUND,
+  if (model_load_ && is_clone_) {
+    if (main_runtime_ == nullptr) {
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_NOT_FOUND,
           std::string("main_runtime is nullptr").c_str());
     }
     *runtime = main_runtime_->Clone((void*)stream, instance_group_device_id);
@@ -367,21 +368,21 @@ TRITONSERVER_Error* ModelState::LoadModel(
       if (not exists) {
         return TRITONSERVER_ErrorNew(
             TRITONSERVER_ERROR_NOT_FOUND,
-            std::string("Paddle params should be named as 'model.pdiparams' or "
+            std::string(
+                "Paddle params should be named as 'model.pdiparams' or "
                 "not provided.'")
                 .c_str());
       }
-      runtime_options_->model_format = fastdeploy::ModelFormat::PADDLE;
-      runtime_options_->model_file = *model_path;
-      runtime_options_->params_file = *params_path;
+      runtime_options_->SetModelPath(*model_path, *params_path,
+                                     fastdeploy::ModelFormat::PADDLE);
     } else {
-      runtime_options_->model_format = fastdeploy::ModelFormat::ONNX;
-      runtime_options_->model_file = *model_path;
+      runtime_options_->SetModelPath(*model_path, "",
+                                     fastdeploy::ModelFormat::ONNX);
     }
   }

   // GPU
 #ifdef TRITON_ENABLE_GPU
   if ((instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
       (instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_AUTO)) {
     runtime_options_->UseGpu(instance_group_device_id);
@@ -389,12 +390,12 @@ TRITONSERVER_Error* ModelState::LoadModel(
   } else if (runtime_options_->device != fastdeploy::Device::IPU) {
     runtime_options_->UseCpu();
   }
 #else
   if (runtime_options_->device != fastdeploy::Device::IPU) {
     // If Device is set to IPU, just skip CPU setting.
     runtime_options_->UseCpu();
   }
 #endif  // TRITON_ENABLE_GPU

   *runtime = main_runtime_ = new fastdeploy::Runtime();
   if (!(*runtime)->Init(*runtime_options_)) {
@@ -1074,8 +1075,7 @@ TRITONSERVER_Error* ModelInstanceState::SetInputTensors(
           {TRITONSERVER_MEMORY_CPU, 0}};
     }

-    RETURN_IF_ERROR(
-        collector->ProcessTensor(
+    RETURN_IF_ERROR(collector->ProcessTensor(
         input_name, nullptr, 0, allowed_input_types, &input_buffer,
         &batchn_byte_size, &memory_type, &memory_type_id));
@@ -1089,9 +1089,9 @@ TRITONSERVER_Error* ModelInstanceState::SetInputTensors(
     }

     fastdeploy::FDTensor fdtensor(in_name);
-    fdtensor.SetExternalData(
-        batchn_shape, ConvertDataTypeToFD(input_datatype),
-        const_cast<char*>(input_buffer), device, device_id);
+    fdtensor.SetExternalData(batchn_shape, ConvertDataTypeToFD(input_datatype),
+                             const_cast<char*>(input_buffer), device,
+                             device_id);
     runtime_->BindInputTensor(in_name, fdtensor);
   }
@@ -1130,23 +1130,22 @@ TRITONSERVER_Error* ModelInstanceState::ReadOutputTensors(
   for (auto& output_name : output_names_) {
     auto* output_tensor = runtime_->GetOutputTensor(output_name);
     if (output_tensor == nullptr) {
-      RETURN_IF_ERROR(
-          TRITONSERVER_ErrorNew(
+      RETURN_IF_ERROR(TRITONSERVER_ErrorNew(
           TRITONSERVER_ERROR_INTERNAL,
           (std::string("output tensor '") + output_name + "' is not found")
               .c_str()));
     }

     TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU;
     int64_t memory_type_id = 0;
-    if(output_tensor->device == fastdeploy::Device::GPU) {
+    if (output_tensor->device == fastdeploy::Device::GPU) {
       memory_type = TRITONSERVER_MEMORY_GPU;
       memory_type_id = DeviceId();
     }
     responder.ProcessTensor(
         output_tensor->name, ConvertFDType(output_tensor->dtype),
         output_tensor->shape,
-        reinterpret_cast<char*>(output_tensor->MutableData()),
-        memory_type, memory_type_id);
+        reinterpret_cast<char*>(output_tensor->MutableData()), memory_type,
+        memory_type_id);
   }

   // Finalize and wait for any pending buffer copies.