From 713afe7f1cc91fe86607499e66ae09bf1a34f452 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 7 Feb 2023 17:57:46 +0800 Subject: [PATCH] [Other] Deprecate some option api and parameters (#1243) * Optimize Poros backend * fix error * Add more pybind * fix conflicts * add some deprecate notices * [Other] Deprecate some apis in RuntimeOption (#1240) * Deprecate more options * modify serving * Update option.h * fix tensorrt error * Update option_pybind.cc * Update option_pybind.cc * Fix error in serving * fix word spell error --- fastdeploy/runtime/backends/lite/option.h | 28 ++-- .../runtime/backends/lite/option_pybind.cc | 1 - fastdeploy/runtime/backends/openvino/option.h | 4 + fastdeploy/runtime/backends/ort/option.h | 24 ++- fastdeploy/runtime/backends/poros/option.h | 2 + fastdeploy/runtime/backends/tensorrt/option.h | 61 ++++++-- .../backends/tensorrt/option_pybind.cc | 31 ++++ fastdeploy/runtime/option_pybind.cc | 38 +---- fastdeploy/runtime/runtime.cc | 55 +++---- fastdeploy/runtime/runtime_option.cc | 138 ++++++++++++++---- fastdeploy/runtime/runtime_option.h | 33 ++--- python/fastdeploy/runtime.py | 81 +++++++--- serving/docs/EN/model_configuration-en.md | 5 +- serving/docs/zh_CN/model_configuration.md | 3 +- serving/src/fastdeploy_runtime.cc | 105 +++++++------ 15 files changed, 380 insertions(+), 229 deletions(-) create mode 100644 fastdeploy/runtime/backends/tensorrt/option_pybind.cc diff --git a/fastdeploy/runtime/backends/lite/option.h b/fastdeploy/runtime/backends/lite/option.h index 1ffd01385..70781d80f 100755 --- a/fastdeploy/runtime/backends/lite/option.h +++ b/fastdeploy/runtime/backends/lite/option.h @@ -48,6 +48,8 @@ enum LitePowerMode { LITE_POWER_RAND_LOW = 5 ///< Use Lite Backend with rand low power mode }; +/*! @brief Option object to configure Paddle Lite backend + */ struct LiteBackendOption { /// Paddle Lite power mode for mobile device. 
int power_mode = 3; @@ -55,12 +57,20 @@ struct LiteBackendOption { int cpu_threads = 1; /// Enable use half precision bool enable_fp16 = false; - /// Enable use int8 precision for quantized model - bool enable_int8 = false; - + /// Inference device, Paddle Lite support CPU/KUNLUNXIN/TIMVX/ASCEND Device device = Device::CPU; + /// Index of inference device + int device_id = 0; - // optimized model dir for CxxConfig + int kunlunxin_l3_workspace_size = 0xfffc00; + bool kunlunxin_locked = false; + bool kunlunxin_autotune = true; + std::string kunlunxin_autotune_file = ""; + std::string kunlunxin_precision = "int16"; + bool kunlunxin_adaptive_seqlen = false; + bool kunlunxin_enable_multi_stream = false; + + /// Optimized model dir for CxxConfig std::string optimized_model_dir = ""; std::string nnadapter_subgraph_partition_config_path = ""; std::string nnadapter_subgraph_partition_config_buffer = ""; @@ -70,13 +80,5 @@ struct LiteBackendOption { std::map>> nnadapter_dynamic_shape_info = {{"", {{0}}}}; std::vector nnadapter_device_names = {}; - int device_id = 0; - int kunlunxin_l3_workspace_size = 0xfffc00; - bool kunlunxin_locked = false; - bool kunlunxin_autotune = true; - std::string kunlunxin_autotune_file = ""; - std::string kunlunxin_precision = "int16"; - bool kunlunxin_adaptive_seqlen = false; - bool kunlunxin_enable_multi_stream = false; }; -} // namespace fastdeploy \ No newline at end of file +} // namespace fastdeploy diff --git a/fastdeploy/runtime/backends/lite/option_pybind.cc b/fastdeploy/runtime/backends/lite/option_pybind.cc index 543255aaf..0a01854ad 100644 --- a/fastdeploy/runtime/backends/lite/option_pybind.cc +++ b/fastdeploy/runtime/backends/lite/option_pybind.cc @@ -23,7 +23,6 @@ void BindLiteOption(pybind11::module& m) { .def_readwrite("power_mode", &LiteBackendOption::power_mode) .def_readwrite("cpu_threads", &LiteBackendOption::cpu_threads) .def_readwrite("enable_fp16", &LiteBackendOption::enable_fp16) - .def_readwrite("enable_int8", &LiteBackendOption::enable_int8) .def_readwrite("device", &LiteBackendOption::device) .def_readwrite("optimized_model_dir", &LiteBackendOption::optimized_model_dir) diff --git a/fastdeploy/runtime/backends/openvino/option.h b/fastdeploy/runtime/backends/openvino/option.h index a7ad2cea0..1200bd9c7 100644 --- a/fastdeploy/runtime/backends/openvino/option.h +++ b/fastdeploy/runtime/backends/openvino/option.h @@ -23,9 +23,13 @@ #include namespace fastdeploy { +/*! @brief Option object to configure OpenVINO backend + */ struct OpenVINOBackendOption { std::string device = "CPU"; int cpu_thread_num = -1; + + /// Number of streams while use OpenVINO int num_streams = 0; /** diff --git a/fastdeploy/runtime/backends/ort/option.h b/fastdeploy/runtime/backends/ort/option.h index ca4d3254c..9487e5da9 100644 --- a/fastdeploy/runtime/backends/ort/option.h +++ b/fastdeploy/runtime/backends/ort/option.h @@ -22,20 +22,30 @@ #include namespace fastdeploy { +/*! 
@brief Option object to configure ONNX Runtime backend
+ */
 struct OrtBackendOption {
-  // -1 means default
-  // 0: ORT_DISABLE_ALL
-  // 1: ORT_ENABLE_BASIC
-  // 2: ORT_ENABLE_EXTENDED
-  // 99: ORT_ENABLE_ALL (enable some custom optimizations e.g bert)
+  /*
+   * @brief Level of graph optimization. -1: default (enable all optimizations) / 0: disable all optimizations / 1: enable basic optimizations / 2: enable extended optimizations / 99: enable all optimizations
+   */
   int graph_optimization_level = -1;
+  /*
+   * @brief Number of threads used to execute operators, -1: default
+   */
   int intra_op_num_threads = -1;
+  /*
+   * @brief Number of threads used to execute the graph, -1: default. This parameter only takes effect when `OrtBackendOption::execution_mode` is set to 1.
+   */
   int inter_op_num_threads = -1;
-  // 0: ORT_SEQUENTIAL
-  // 1: ORT_PARALLEL
+  /*
+   * @brief Execution mode for the graph. -1: default (sequential mode) / 0: sequential mode, execute the operators in the graph one by one / 1: parallel mode, execute the operators in the graph in parallel
+   */
   int execution_mode = -1;
+  /// Inference device, OrtBackend supports CPU/GPU
   Device device = Device::CPU;
+  /// Inference device id
   int device_id = 0;
+  void* external_stream_ = nullptr;
 };
 }  // namespace fastdeploy
diff --git a/fastdeploy/runtime/backends/poros/option.h b/fastdeploy/runtime/backends/poros/option.h
index 22f0d371b..ebaffec09 100755
--- a/fastdeploy/runtime/backends/poros/option.h
+++ b/fastdeploy/runtime/backends/poros/option.h
@@ -22,6 +22,8 @@
 
 namespace fastdeploy {
 
+/*! @brief Option object to configure Poros backend
+ */
 struct PorosBackendOption {
   Device device = Device::CPU;
   int device_id = 0;
diff --git a/fastdeploy/runtime/backends/tensorrt/option.h b/fastdeploy/runtime/backends/tensorrt/option.h
index 8d4ad4aaf..5cee0a7e3 100755
--- a/fastdeploy/runtime/backends/tensorrt/option.h
+++ b/fastdeploy/runtime/backends/tensorrt/option.h
@@ -21,23 +21,64 @@
 
 namespace fastdeploy {
 
+/*! @brief Option object to configure TensorRT backend
+ */
 struct TrtBackendOption {
-  std::string model_file = "";   // Path of model file
-  std::string params_file = "";  // Path of parameters file, can be empty
-
-  // format of input model
-  ModelFormat model_format = ModelFormat::AUTOREC;
-
-  int gpu_id = 0;
-  bool enable_fp16 = false;
-  bool enable_int8 = false;
+  /// `max_batch_size`, it is deprecated in TensorRT 8.x
   size_t max_batch_size = 32;
+
+  /// `max_workspace_size` for TensorRT
   size_t max_workspace_size = 1 << 30;
+
+  /*
+   * @brief Enable half precision inference; on devices that do not support half precision, it will fall back to float32 mode
+   */
+  bool enable_fp16 = false;
+
+  /** \brief Set the shape range of an input tensor for models with dynamic input shapes while using the TensorRT backend
+   *
+   * \param[in] tensor_name The name of the input tensor whose shape is dynamic
+   * \param[in] min The minimum shape of the input tensor
+   * \param[in] opt The optimal shape of the input tensor, usually the most common shape; if left as the default value, it is kept the same as `min`
+   * \param[in] max The maximum shape of the input tensor; if left as the default value, it is kept the same as `min`
+   */
+  void SetShape(const std::string& tensor_name,
+                const std::vector<int32_t>& min,
+                const std::vector<int32_t>& opt,
+                const std::vector<int32_t>& max) {
+    min_shape[tensor_name].clear();
+    max_shape[tensor_name].clear();
+    opt_shape[tensor_name].clear();
+    min_shape[tensor_name].assign(min.begin(), min.end());
+    if (opt.size() == 0) {
+      opt_shape[tensor_name].assign(min.begin(), min.end());
+    } else {
+      opt_shape[tensor_name].assign(opt.begin(), opt.end());
+    }
+    if (max.size() == 0) {
+      max_shape[tensor_name].assign(min.begin(), min.end());
+    } else {
+      max_shape[tensor_name].assign(max.begin(), max.end());
+    }
+  }
+  /**
+   * @brief Set the cache file path when using the TensorRT backend. Loading a Paddle/ONNX model and building the TensorRT engine takes a long time; with this option the built engine is serialized to `serialize_file` and loaded directly the next time the code runs
+   */
+  std::string serialize_file = "";
+
+  // The parameters below may be removed in the next version, please do not
+  // access or use them directly
   std::map<std::string, std::vector<int32_t>> max_shape;
   std::map<std::string, std::vector<int32_t>> min_shape;
   std::map<std::string, std::vector<int32_t>> opt_shape;
-  std::string serialize_file = "";
   bool enable_pinned_memory = false;
   void* external_stream_ = nullptr;
+  int gpu_id = 0;
+  std::string model_file = "";   // Path of model file
+  std::string params_file = "";  // Path of parameters file, can be empty
+  // format of input model
+  ModelFormat model_format = ModelFormat::AUTOREC;
 };
+
+
 }  // namespace fastdeploy
diff --git a/fastdeploy/runtime/backends/tensorrt/option_pybind.cc b/fastdeploy/runtime/backends/tensorrt/option_pybind.cc
new file mode 100644
index 000000000..d781256a5
--- /dev/null
+++ b/fastdeploy/runtime/backends/tensorrt/option_pybind.cc
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "fastdeploy/pybind/main.h" +#include "fastdeploy/runtime/backends/tensorrt/option.h" + +namespace fastdeploy { + +void BindTrtOption(pybind11::module& m) { + pybind11::class_(m, "TrtBackendOption") + .def(pybind11::init()) + .def_readwrite("enable_fp16", &TrtBackendOption::enable_fp16) + .def_readwrite("max_batch_size", &TrtBackendOption::max_batch_size) + .def_readwrite("max_workspace_size", + &TrtBackendOption::max_workspace_size) + .def_readwrite("serialize_file", &TrtBackendOption::serialize_file) + .def("set_shape", &TrtBackendOption::SetShape); +} + +} // namespace fastdeploy diff --git a/fastdeploy/runtime/option_pybind.cc b/fastdeploy/runtime/option_pybind.cc index 982d18053..1dcc9acbc 100644 --- a/fastdeploy/runtime/option_pybind.cc +++ b/fastdeploy/runtime/option_pybind.cc @@ -19,12 +19,14 @@ namespace fastdeploy { void BindLiteOption(pybind11::module& m); void BindOpenVINOOption(pybind11::module& m); void BindOrtOption(pybind11::module& m); +void BindTrtOption(pybind11::module& m); void BindPorosOption(pybind11::module& m); void BindOption(pybind11::module& m) { BindLiteOption(m); BindOpenVINOOption(m); BindOrtOption(m); + BindTrtOption(m); BindPorosOption(m); pybind11::class_(m, "RuntimeOption") @@ -40,47 +42,22 @@ void BindOption(pybind11::module& m) { .def_readwrite("paddle_lite_option", &RuntimeOption::paddle_lite_option) .def_readwrite("openvino_option", &RuntimeOption::openvino_option) .def_readwrite("ort_option", &RuntimeOption::ort_option) + .def_readwrite("trt_option", &RuntimeOption::trt_option) .def_readwrite("poros_option", &RuntimeOption::poros_option) .def("set_external_stream", &RuntimeOption::SetExternalStream) .def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum) .def("use_paddle_backend", &RuntimeOption::UsePaddleBackend) .def("use_poros_backend", &RuntimeOption::UsePorosBackend) .def("use_ort_backend", &RuntimeOption::UseOrtBackend) - .def("set_ort_graph_opt_level", &RuntimeOption::SetOrtGraphOptLevel) .def("use_trt_backend", &RuntimeOption::UseTrtBackend) .def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend) .def("use_lite_backend", &RuntimeOption::UseLiteBackend) - .def("set_lite_device_names", &RuntimeOption::SetLiteDeviceNames) - .def("set_lite_context_properties", - &RuntimeOption::SetLiteContextProperties) - .def("set_lite_model_cache_dir", &RuntimeOption::SetLiteModelCacheDir) - .def("set_lite_dynamic_shape_info", - &RuntimeOption::SetLiteDynamicShapeInfo) - .def("set_lite_subgraph_partition_path", - &RuntimeOption::SetLiteSubgraphPartitionPath) - .def("set_lite_mixed_precision_quantization_config_path", - &RuntimeOption::SetLiteMixedPrecisionQuantizationConfigPath) - .def("set_lite_subgraph_partition_config_buffer", - &RuntimeOption::SetLiteSubgraphPartitionConfigBuffer) .def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN) - .def("set_openvino_device", &RuntimeOption::SetOpenVINODevice) - .def("set_openvino_shape_info", &RuntimeOption::SetOpenVINOShapeInfo) - .def("set_openvino_cpu_operators", - &RuntimeOption::SetOpenVINOCpuOperators) .def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo) .def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo) .def("set_paddle_mkldnn_cache_size", &RuntimeOption::SetPaddleMKLDNNCacheSize) - .def("enable_lite_fp16", &RuntimeOption::EnableLiteFP16) - .def("disable_lite_fp16", &RuntimeOption::DisableLiteFP16) - .def("set_lite_power_mode", &RuntimeOption::SetLitePowerMode) - .def("set_trt_input_shape", &RuntimeOption::SetTrtInputShape) - 
.def("set_trt_max_workspace_size", &RuntimeOption::SetTrtMaxWorkspaceSize) - .def("set_trt_max_batch_size", &RuntimeOption::SetTrtMaxBatchSize) .def("enable_paddle_to_trt", &RuntimeOption::EnablePaddleToTrt) - .def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16) - .def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16) - .def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile) .def("enable_pinned_memory", &RuntimeOption::EnablePinnedMemory) .def("disable_pinned_memory", &RuntimeOption::DisablePinnedMemory) .def("enable_paddle_trt_collect_shape", @@ -103,15 +80,6 @@ void BindOption(pybind11::module& m) { .def_readwrite("cpu_thread_num", &RuntimeOption::cpu_thread_num) .def_readwrite("device_id", &RuntimeOption::device_id) .def_readwrite("device", &RuntimeOption::device) - .def_readwrite("trt_max_shape", &RuntimeOption::trt_max_shape) - .def_readwrite("trt_opt_shape", &RuntimeOption::trt_opt_shape) - .def_readwrite("trt_min_shape", &RuntimeOption::trt_min_shape) - .def_readwrite("trt_serialize_file", &RuntimeOption::trt_serialize_file) - .def_readwrite("trt_enable_fp16", &RuntimeOption::trt_enable_fp16) - .def_readwrite("trt_enable_int8", &RuntimeOption::trt_enable_int8) - .def_readwrite("trt_max_batch_size", &RuntimeOption::trt_max_batch_size) - .def_readwrite("trt_max_workspace_size", - &RuntimeOption::trt_max_workspace_size) .def_readwrite("ipu_device_num", &RuntimeOption::ipu_device_num) .def_readwrite("ipu_micro_batch_size", &RuntimeOption::ipu_micro_batch_size) diff --git a/fastdeploy/runtime/runtime.cc b/fastdeploy/runtime/runtime.cc index e7db79127..e6bd14456 100644 --- a/fastdeploy/runtime/runtime.cc +++ b/fastdeploy/runtime/runtime.cc @@ -244,17 +244,9 @@ void Runtime::CreatePaddleBackend() { if (pd_option.use_gpu && option.pd_enable_trt) { pd_option.enable_trt = true; pd_option.collect_shape = option.pd_collect_shape; - auto trt_option = TrtBackendOption(); - trt_option.gpu_id = option.device_id; - trt_option.enable_fp16 = option.trt_enable_fp16; - trt_option.max_batch_size = option.trt_max_batch_size; - trt_option.max_workspace_size = option.trt_max_workspace_size; - trt_option.max_shape = option.trt_max_shape; - trt_option.min_shape = option.trt_min_shape; - trt_option.opt_shape = option.trt_opt_shape; - trt_option.serialize_file = option.trt_serialize_file; - trt_option.enable_pinned_memory = option.enable_pinned_memory; - pd_option.trt_option = trt_option; + pd_option.trt_option = option.trt_option; + pd_option.trt_option.gpu_id = option.device_id; + pd_option.trt_option.enable_pinned_memory = option.enable_pinned_memory; pd_option.trt_disabled_ops_ = option.trt_disabled_ops_; } #endif @@ -339,41 +331,33 @@ void Runtime::CreateTrtBackend() { "TrtBackend only support model format of ModelFormat::PADDLE / " "ModelFormat::ONNX."); #ifdef ENABLE_TRT_BACKEND - auto trt_option = TrtBackendOption(); - trt_option.model_file = option.model_file; - trt_option.params_file = option.params_file; - trt_option.model_format = option.model_format; - trt_option.gpu_id = option.device_id; - trt_option.enable_fp16 = option.trt_enable_fp16; - trt_option.enable_int8 = option.trt_enable_int8; - trt_option.max_batch_size = option.trt_max_batch_size; - trt_option.max_workspace_size = option.trt_max_workspace_size; - trt_option.max_shape = option.trt_max_shape; - trt_option.min_shape = option.trt_min_shape; - trt_option.opt_shape = option.trt_opt_shape; - trt_option.serialize_file = option.trt_serialize_file; - trt_option.enable_pinned_memory = option.enable_pinned_memory; - 
trt_option.external_stream_ = option.external_stream_; + option.trt_option.model_file = option.model_file; + option.trt_option.params_file = option.params_file; + option.trt_option.model_format = option.model_format; + option.trt_option.gpu_id = option.device_id; + option.trt_option.enable_pinned_memory = option.enable_pinned_memory; + option.trt_option.external_stream_ = option.external_stream_; backend_ = utils::make_unique(); auto casted_backend = dynamic_cast(backend_.get()); casted_backend->benchmark_option_ = option.benchmark_option; if (option.model_format == ModelFormat::ONNX) { if (option.model_from_memory_) { - FDASSERT(casted_backend->InitFromOnnx(option.model_file, trt_option), - "Load model from ONNX failed while initliazing TrtBackend."); + FDASSERT( + casted_backend->InitFromOnnx(option.model_file, option.trt_option), + "Load model from ONNX failed while initliazing TrtBackend."); ReleaseModelMemoryBuffer(); } else { std::string model_buffer = ""; FDASSERT(ReadBinaryFromFile(option.model_file, &model_buffer), "Fail to read binary from model file"); - FDASSERT(casted_backend->InitFromOnnx(model_buffer, trt_option), + FDASSERT(casted_backend->InitFromOnnx(model_buffer, option.trt_option), "Load model from ONNX failed while initliazing TrtBackend."); } } else { if (option.model_from_memory_) { - FDASSERT(casted_backend->InitFromPaddle(option.model_file, - option.params_file, trt_option), + FDASSERT(casted_backend->InitFromPaddle( + option.model_file, option.params_file, option.trt_option), "Load model from Paddle failed while initliazing TrtBackend."); ReleaseModelMemoryBuffer(); } else { @@ -384,7 +368,7 @@ void Runtime::CreateTrtBackend() { FDASSERT(ReadBinaryFromFile(option.params_file, ¶ms_buffer), "Fail to read binary from parameter file"); FDASSERT(casted_backend->InitFromPaddle(model_buffer, params_buffer, - trt_option), + option.trt_option), "Load model from Paddle failed while initliazing TrtBackend."); } } @@ -505,9 +489,10 @@ bool Runtime::Compile(std::vector>& prewarm_tensors, } option.poros_option.device = option.device; option.poros_option.device_id = option.device_id; - option.poros_option.enable_fp16 = option.trt_enable_fp16; - option.poros_option.max_batch_size = option.trt_max_batch_size; - option.poros_option.max_workspace_size = option.trt_max_workspace_size; + option.poros_option.enable_fp16 = option.trt_option.enable_fp16; + option.poros_option.max_batch_size = option.trt_option.max_batch_size; + option.poros_option.max_workspace_size = option.trt_option.max_workspace_size; + backend_ = utils::make_unique(); auto casted_backend = dynamic_cast(backend_.get()); FDASSERT( diff --git a/fastdeploy/runtime/runtime_option.cc b/fastdeploy/runtime/runtime_option.cc index c262a211c..c9ab487a1 100644 --- a/fastdeploy/runtime/runtime_option.cc +++ b/fastdeploy/runtime/runtime_option.cc @@ -102,6 +102,10 @@ void RuntimeOption::SetCpuThreadNum(int thread_num) { } void RuntimeOption::SetOrtGraphOptLevel(int level) { + FDWARNING << "`RuntimeOption::SetOrtGraphOptLevel` will be removed in " + "v1.2.0, please modify its member variables directly, e.g " + "`runtime_option.ort_option.graph_optimization_level = 99`." 
+            << std::endl;
   std::vector<int> supported_level{-1, 0, 1, 2};
   auto valid_level = std::find(supported_level.begin(), supported_level.end(),
                                level) != supported_level.end();
@@ -203,67 +207,127 @@ void RuntimeOption::SetPaddleMKLDNNCacheSize(int size) {
 }
 
 void RuntimeOption::SetOpenVINODevice(const std::string& name) {
-  openvino_option.device = name;
+  FDWARNING << "`RuntimeOption::SetOpenVINODevice` will be removed in v1.2.0, "
+               "please use `RuntimeOption.openvino_option.SetDevice(const "
+               "std::string&)` instead."
+            << std::endl;
+  openvino_option.SetDevice(name);
 }
 
-void RuntimeOption::EnableLiteFP16() { paddle_lite_option.enable_fp16 = true; }
+void RuntimeOption::EnableLiteFP16() {
+  FDWARNING << "`RuntimeOption::EnableLiteFP16` will be removed in v1.2.0, "
+               "please modify its member variable directly, e.g "
+               "`runtime_option.paddle_lite_option.enable_fp16 = true`"
+            << std::endl;
+  paddle_lite_option.enable_fp16 = true;
+}
 
 void RuntimeOption::DisableLiteFP16() {
+  FDWARNING << "`RuntimeOption::DisableLiteFP16` will be removed in v1.2.0, "
+               "please modify its member variable directly, e.g "
+               "`runtime_option.paddle_lite_option.enable_fp16 = false`"
+            << std::endl;
   paddle_lite_option.enable_fp16 = false;
 }
 
-void RuntimeOption::EnableLiteInt8() { paddle_lite_option.enable_int8 = true; }
+void RuntimeOption::EnableLiteInt8() {
+  FDWARNING << "RuntimeOption::EnableLiteInt8 is a no-op API and will be "
+               "removed in v1.2.0. If you load a quantized model, it will "
+               "automatically run in int8 mode; otherwise it will run in "
+               "float mode."
+            << std::endl;
+}
 
 void RuntimeOption::DisableLiteInt8() {
-  paddle_lite_option.enable_int8 = false;
+  FDWARNING << "RuntimeOption::DisableLiteInt8 is a no-op API and will be "
+               "removed in v1.2.0. If you load a quantized model, it will "
+               "automatically run in int8 mode; otherwise it will run in "
+               "float mode."
+            << std::endl;
 }
 
 void RuntimeOption::SetLitePowerMode(LitePowerMode mode) {
+  FDWARNING << "`RuntimeOption::SetLitePowerMode` will be removed in v1.2.0, "
+               "please modify its member variable directly, e.g "
+               "`runtime_option.paddle_lite_option.power_mode = 3;`"
+            << std::endl;
   paddle_lite_option.power_mode = mode;
 }
 
 void RuntimeOption::SetLiteOptimizedModelDir(
     const std::string& optimized_model_dir) {
+  FDWARNING
+      << "`RuntimeOption::SetLiteOptimizedModelDir` will be removed in v1.2.0, "
+         "please modify its member variable directly, e.g "
+         "`runtime_option.paddle_lite_option.optimized_model_dir = \"...\"`"
+      << std::endl;
   paddle_lite_option.optimized_model_dir = optimized_model_dir;
 }
 
 void RuntimeOption::SetLiteSubgraphPartitionPath(
     const std::string& nnadapter_subgraph_partition_config_path) {
+  FDWARNING << "`RuntimeOption::SetLiteSubgraphPartitionPath` will be removed "
+               "in v1.2.0, please modify its member variable directly, e.g "
+               "`runtime_option.paddle_lite_option.nnadapter_subgraph_"
+               "partition_config_path = \"...\";`"
+            << std::endl;
   paddle_lite_option.nnadapter_subgraph_partition_config_path =
       nnadapter_subgraph_partition_config_path;
 }
 
 void RuntimeOption::SetLiteSubgraphPartitionConfigBuffer(
     const std::string& nnadapter_subgraph_partition_config_buffer) {
+  FDWARNING
+      << "`RuntimeOption::SetLiteSubgraphPartitionConfigBuffer` will be "
+         "removed in v1.2.0, please modify its member variable directly, e.g "
+         "`runtime_option.paddle_lite_option.nnadapter_subgraph_partition_"
+         "config_buffer = ...`"
+      << std::endl;
   paddle_lite_option.nnadapter_subgraph_partition_config_buffer =
       nnadapter_subgraph_partition_config_buffer;
 }
 
-void RuntimeOption::SetLiteDeviceNames(
-    const std::vector<std::string>& nnadapter_device_names) {
-  paddle_lite_option.nnadapter_device_names = nnadapter_device_names;
-}
-
 void RuntimeOption::SetLiteContextProperties(
     const std::string& nnadapter_context_properties) {
+  FDWARNING << "`RuntimeOption::SetLiteContextProperties` will be removed in "
+               "v1.2.0, please modify its member variable directly, e.g "
+               "`runtime_option.paddle_lite_option.nnadapter_context_"
+               "properties = ...`"
+            << std::endl;
   paddle_lite_option.nnadapter_context_properties =
       nnadapter_context_properties;
 }
 
 void RuntimeOption::SetLiteModelCacheDir(
    const std::string& nnadapter_model_cache_dir) {
+  FDWARNING
+      << "`RuntimeOption::SetLiteModelCacheDir` will be removed in v1.2.0, "
+         "please modify its member variable directly, e.g "
+         "`runtime_option.paddle_lite_option.nnadapter_model_cache_dir = ...`"
+      << std::endl;
   paddle_lite_option.nnadapter_model_cache_dir = nnadapter_model_cache_dir;
 }
 
 void RuntimeOption::SetLiteDynamicShapeInfo(
     const std::map<std::string, std::vector<std::vector<int64_t>>>&
         nnadapter_dynamic_shape_info) {
+  FDWARNING << "`RuntimeOption::SetLiteDynamicShapeInfo` will be removed in "
+               "v1.2.0, please modify its member variable directly, e.g "
+               "`runtime_option.paddle_lite_option."
+               "nnadapter_dynamic_shape_info = ...`"
+            << std::endl;
   paddle_lite_option.nnadapter_dynamic_shape_info =
       nnadapter_dynamic_shape_info;
 }
 
 void RuntimeOption::SetLiteMixedPrecisionQuantizationConfigPath(
     const std::string& nnadapter_mixed_precision_quantization_config_path) {
+  FDWARNING
+      << "`RuntimeOption::SetLiteMixedPrecisionQuantizationConfigPath` will be "
+         "removed in v1.2.0, please modify its member variable directly, e.g "
+         "`runtime_option.paddle_lite_option.nnadapter_"
+         "mixed_precision_quantization_config_path = ...`"
+      << std::endl;
   paddle_lite_option.nnadapter_mixed_precision_quantization_config_path =
       nnadapter_mixed_precision_quantization_config_path;
 }
 
@@ -272,42 +336,60 @@ void RuntimeOption::SetTrtInputShape(const std::string& input_name,
                                      const std::vector<int32_t>& min_shape,
                                      const std::vector<int32_t>& opt_shape,
                                      const std::vector<int32_t>& max_shape) {
-  trt_min_shape[input_name].clear();
-  trt_max_shape[input_name].clear();
-  trt_opt_shape[input_name].clear();
-  trt_min_shape[input_name].assign(min_shape.begin(), min_shape.end());
-  if (opt_shape.size() == 0) {
-    trt_opt_shape[input_name].assign(min_shape.begin(), min_shape.end());
-  } else {
-    trt_opt_shape[input_name].assign(opt_shape.begin(), opt_shape.end());
-  }
-  if (max_shape.size() == 0) {
-    trt_max_shape[input_name].assign(min_shape.begin(), min_shape.end());
-  } else {
-    trt_max_shape[input_name].assign(max_shape.begin(), max_shape.end());
-  }
+  FDWARNING << "`RuntimeOption::SetTrtInputShape` will be removed in v1.2.0, "
+               "please use `RuntimeOption.trt_option.SetShape()` instead."
+            << std::endl;
+  trt_option.SetShape(input_name, min_shape, opt_shape, max_shape);
 }
 
 void RuntimeOption::SetTrtMaxWorkspaceSize(size_t max_workspace_size) {
-  trt_max_workspace_size = max_workspace_size;
+  FDWARNING << "`RuntimeOption::SetTrtMaxWorkspaceSize` will be removed in "
+               "v1.2.0, please modify its member variable directly, e.g "
+               "`RuntimeOption.trt_option.max_workspace_size = "
+            << max_workspace_size << "`." << std::endl;
+  trt_option.max_workspace_size = max_workspace_size;
 }
 
 void RuntimeOption::SetTrtMaxBatchSize(size_t max_batch_size) {
-  trt_max_batch_size = max_batch_size;
+  FDWARNING << "`RuntimeOption::SetTrtMaxBatchSize` will be removed in v1.2.0, "
+               "please modify its member variable directly, e.g "
+               "`RuntimeOption.trt_option.max_batch_size = "
+            << max_batch_size << "`."
<< std::endl; + trt_option.max_batch_size = max_batch_size; } -void RuntimeOption::EnableTrtFP16() { trt_enable_fp16 = true; } +void RuntimeOption::EnableTrtFP16() { + FDWARNING << "`RuntimeOption::EnableTrtFP16` will be removed in v1.2.0, " + "please modify its member variable directly, e.g " + "`runtime_option.trt_option.enable_fp16 = true;`" + << std::endl; + trt_option.enable_fp16 = true; +} -void RuntimeOption::DisableTrtFP16() { trt_enable_fp16 = false; } +void RuntimeOption::DisableTrtFP16() { + FDWARNING << "`RuntimeOption::DisableTrtFP16` will be removed in v1.2.0, " + "please modify its member variable directly, e.g " + "`runtime_option.trt_option.enable_fp16 = false;`" + << std::endl; + trt_option.enable_fp16 = false; +} void RuntimeOption::EnablePinnedMemory() { enable_pinned_memory = true; } void RuntimeOption::DisablePinnedMemory() { enable_pinned_memory = false; } void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) { - trt_serialize_file = cache_file_path; + FDWARNING << "`RuntimeOption::SetTrtCacheFile` will be removed in v1.2.0, " + "please modify its member variable directly, e.g " + "`runtime_option.trt_option.serialize_file = \"" + << cache_file_path << "\"." << std::endl; + trt_option.serialize_file = cache_file_path; } void RuntimeOption::SetOpenVINOStreams(int num_streams) { + FDWARNING << "`RuntimeOption::SetOpenVINOStreams` will be removed in v1.2.0, " + "please modify its member variable directly, e.g " + "`runtime_option.openvino_option.num_streams = " + << num_streams << "`." << std::endl; openvino_option.num_streams = num_streams; } diff --git a/fastdeploy/runtime/runtime_option.h b/fastdeploy/runtime/runtime_option.h index 44b81e50a..6fb7e78e7 100644 --- a/fastdeploy/runtime/runtime_option.h +++ b/fastdeploy/runtime/runtime_option.h @@ -211,12 +211,6 @@ struct FASTDEPLOY_DECL RuntimeOption { void SetLiteSubgraphPartitionConfigBuffer( const std::string& nnadapter_subgraph_partition_config_buffer); - /** - * @brief Set device name for Paddle Lite backend. - */ - void - SetLiteDeviceNames(const std::vector& nnadapter_device_names); - /** * @brief Set context properties for Paddle Lite backend. */ @@ -347,21 +341,21 @@ struct FASTDEPLOY_DECL RuntimeOption { void SetIpuConfig(bool enable_fp16 = false, int replica_num = 1, float available_memory_proportion = 1.0, bool enable_half_partial = false); - + /** \brief Set the profile mode as 'true'. * * \param[in] inclue_h2d_d2h Whether to include time of H2D_D2H for time of runtime. * \param[in] repeat Repeat times for runtime inference. * \param[in] warmup Warmup times for runtime inference. */ - void EnableProfiling(bool inclue_h2d_d2h = false, + void EnableProfiling(bool inclue_h2d_d2h = false, int repeat = 100, int warmup = 50) { benchmark_option.enable_profile = true; benchmark_option.warmup = warmup; benchmark_option.repeats = repeat; benchmark_option.include_h2d_d2h = inclue_h2d_d2h; } - + /** \brief Set the profile mode as 'false'. 
*/ void DisableProfiling() { @@ -381,6 +375,7 @@ struct FASTDEPLOY_DECL RuntimeOption { bool enable_pinned_memory = false; + /// Option to configure ONNX Runtime backend OrtBackendOption ort_option; // ======Only for Paddle Backend===== @@ -401,20 +396,16 @@ struct FASTDEPLOY_DECL RuntimeOption { float ipu_available_memory_proportion = 1.0; bool ipu_enable_half_partial = false; - // ======Only for Trt Backend======= - std::map> trt_max_shape; - std::map> trt_min_shape; - std::map> trt_opt_shape; - std::string trt_serialize_file = ""; - bool trt_enable_fp16 = false; - bool trt_enable_int8 = false; - size_t trt_max_batch_size = 1; - size_t trt_max_workspace_size = 1 << 30; + /// Option to configure TensorRT backend + TrtBackendOption trt_option; + // ======Only for PaddleTrt Backend======= std::vector trt_disabled_ops_{}; + /// Option to configure Poros backend PorosBackendOption poros_option; + /// Option to configure OpenVINO backend OpenVINOBackendOption openvino_option; // ======Only for RKNPU2 Backend======= @@ -433,11 +424,11 @@ struct FASTDEPLOY_DECL RuntimeOption { std::string model_file = ""; std::string params_file = ""; bool model_from_memory_ = false; - // format of input model + /// format of input model ModelFormat model_format = ModelFormat::PADDLE; - // Benchmark option - benchmark::BenchmarkOption benchmark_option; + /// Benchmark option + benchmark::BenchmarkOption benchmark_option; }; } // namespace fastdeploy diff --git a/python/fastdeploy/runtime.py b/python/fastdeploy/runtime.py index 6c8c53cb8..2ae70202d 100644 --- a/python/fastdeploy/runtime.py +++ b/python/fastdeploy/runtime.py @@ -154,6 +154,8 @@ class RuntimeOption: """Options for FastDeploy Runtime. """ + __slots__ = ["_option"] + def __init__(self): """Initialize a FastDeploy RuntimeOption object. """ @@ -266,7 +268,7 @@ class RuntimeOption: logging.warning( "`RuntimeOption.set_ort_graph_opt_level` will be deprecated in v1.2.0, please use `RuntimeOption.graph_optimize_level = 99` instead." ) - return self._option.set_ort_graph_opt_level(level) + self._option.ort_option.graph_optimize_level = level def use_paddle_backend(self): """Use Paddle Inference backend, support inference Paddle model on CPU/Nvidia GPU. @@ -314,7 +316,7 @@ class RuntimeOption: logging.warning( "`RuntimeOption.set_lite_context_properties` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_context_properties = ...` instead." ) - return self._option.set_lite_context_properties(context_properties) + self._option.paddle_lite_option.nnadapter_context_properties = context_properties def set_lite_model_cache_dir(self, model_cache_dir): """Set nnadapter model cache dir for Paddle Lite backend. @@ -322,7 +324,8 @@ class RuntimeOption: logging.warning( "`RuntimeOption.set_lite_model_cache_dir` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_model_cache_dir = ...` instead." ) - return self._option.set_lite_model_cache_dir(model_cache_dir) + + self._option.paddle_lite_option.nnadapter_model_cache_dir = model_cache_dir def set_lite_dynamic_shape_info(self, dynamic_shape_info): """ Set nnadapter dynamic shape info for Paddle Lite backend. @@ -330,7 +333,7 @@ class RuntimeOption: logging.warning( "`RuntimeOption.set_lite_dynamic_shape_info` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_dynamic_shape_info = ...` instead." 
) - return self._option.set_lite_dynamic_shape_info(dynamic_shape_info) + self._option.paddle_lite_option.nnadapter_dynamic_shape_info = dynamic_shape_info def set_lite_subgraph_partition_path(self, subgraph_partition_path): """ Set nnadapter subgraph partition path for Paddle Lite backend. @@ -338,8 +341,7 @@ class RuntimeOption: logging.warning( "`RuntimeOption.set_lite_subgraph_partition_path` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_subgraph_partition_config_path = ...` instead." ) - return self._option.set_lite_subgraph_partition_path( - subgraph_partition_path) + self._option.paddle_lite_option.nnadapter_subgraph_partition_config_path = subgraph_partition_path def set_lite_subgraph_partition_config_buffer(self, subgraph_partition_buffer): @@ -348,8 +350,7 @@ class RuntimeOption: logging.warning( "`RuntimeOption.set_lite_subgraph_partition_buffer` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_subgraph_partition_config_buffer = ...` instead." ) - return self._option.set_lite_subgraph_partition_config_buffer( - subgraph_partition_buffer) + self._option.paddle_lite_option.nnadapter_subgraph_partition_config_buffer = subgraph_partition_buffer def set_lite_mixed_precision_quantization_config_path( self, mixed_precision_quantization_config_path): @@ -358,8 +359,7 @@ class RuntimeOption: logging.warning( "`RuntimeOption.set_lite_mixed_precision_quantization_config_path` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.nnadapter_mixed_precision_quantization_config_path = ...` instead." ) - return self._option.set_lite_mixed_precision_quantization_config_path( - mixed_precision_quantization_config_path) + self._option.paddle_lite_option.nnadapter_mixed_precision_quantization_config_path = mixed_precision_quantization_config_path def set_paddle_mkldnn(self, use_mkldnn=True): """Enable/Disable MKLDNN while using Paddle Inference backend, mkldnn is enabled by default. @@ -373,7 +373,7 @@ class RuntimeOption: logging.warning( "`RuntimeOption.set_openvino_device` will be deprecated in v1.2.0, please use `RuntimeOption.openvino_option.set_device` instead." ) - return self._option.set_openvino_device(name) + self._option.openvino_option.set_device(name) def set_openvino_shape_info(self, shape_info): """Set shape information of the models' inputs, used for GPU to fix the shape @@ -384,7 +384,7 @@ class RuntimeOption: logging.warning( "`RuntimeOption.set_openvino_shape_info` will be deprecated in v1.2.0, please use `RuntimeOption.openvino_option.set_shape_info` instead." ) - return self._option.set_openvino_shape_info(shape_info) + self._option.openvino_option.set_shape_info(shape_info) def set_openvino_cpu_operators(self, operators): """While using OpenVINO backend and intel GPU, this interface specifies unsupported operators to run on CPU @@ -395,7 +395,7 @@ class RuntimeOption: logging.warning( "`RuntimeOption.set_openvino_cpu_operators` will be deprecated in v1.2.0, please use `RuntimeOption.openvino_option.set_cpu_operators` instead." ) - return self._option.set_openvino_cpu_operators(operators) + self._option.openvino_option.set_cpu_operators(operators) def enable_paddle_log_info(self): """Enable print out the debug log information while using Paddle Inference backend, the log information is disabled by default. @@ -415,17 +415,26 @@ class RuntimeOption: def enable_lite_fp16(self): """Enable half precision inference while using Paddle Lite backend on ARM CPU, fp16 is disabled by default. 
""" - return self._option.enable_lite_fp16() + logging.warning( + "`RuntimeOption.enable_lite_fp16` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.enable_fp16 = True` instead." + ) + self._option.paddle_lite_option.enable_fp16 = True def disable_lite_fp16(self): """Disable half precision inference while using Paddle Lite backend on ARM CPU, fp16 is disabled by default. """ - return self._option.disable_lite_fp16() + logging.warning( + "`RuntimeOption.disable_lite_fp16` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.enable_fp16 = False` instead." + ) + self._option.paddle_lite_option.enable_fp16 = False def set_lite_power_mode(self, mode): """Set POWER mode while using Paddle Lite backend on ARM CPU. """ - return self._option.set_lite_power_mode(mode) + logging.warning( + "`RuntimeOption.set_lite_powermode` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_lite_option.power_mode = {}` instead.". + format(mode)) + self._option.paddle_lite_option.power_mode = mode def set_trt_input_shape(self, tensor_name, @@ -439,30 +448,42 @@ class RuntimeOption: :param opt_shape: (list of int)Optimize shape of the input, this offten set as the most common input shape, if set to None, it will keep same with min_shape :param max_shape: (list of int)Maximum shape of the input, e.g [8, 3, 224, 224], if set to None, it will keep same with the min_shape """ + logging.warning( + "`RuntimeOption.set_trt_input_shape` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.set_shape()` instead." + ) if opt_shape is None and max_shape is None: opt_shape = min_shape max_shape = min_shape else: assert opt_shape is not None and max_shape is not None, "Set min_shape only, or set min_shape, opt_shape, max_shape both." - return self._option.set_trt_input_shape(tensor_name, min_shape, - opt_shape, max_shape) + return self._option.trt_option.set_shape(tensor_name, min_shape, + opt_shape, max_shape) def set_trt_cache_file(self, cache_file_path): """Set a cache file path while using TensorRT backend. While loading a Paddle/ONNX model with set_trt_cache_file("./tensorrt_cache/model.trt"), if file `./tensorrt_cache/model.trt` exists, it will skip building tensorrt engine and load the cache file directly; if file `./tensorrt_cache/model.trt` doesn't exist, it will building tensorrt engine and save the engine as binary string to the cache file. :param cache_file_path: (str)Path of tensorrt cache file """ - return self._option.set_trt_cache_file(cache_file_path) + logging.warning( + "`RuntimeOption.set_trt_cache_file` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.serialize_file = {}` instead.". + format(cache_file_path)) + self._option.trt_option.serialize_file = cache_file_path def enable_trt_fp16(self): """Enable half precision inference while using TensorRT backend, notice that not all the Nvidia GPU support FP16, in those cases, will fallback to FP32 inference. """ - return self._option.enable_trt_fp16() + logging.warning( + "`RuntimeOption.enable_trt_fp16` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.enable_fp16 = True` instead." + ) + self._option.trt_option.enable_fp16 = True def disable_trt_fp16(self): """Disable half precision inference while suing TensorRT backend. """ - return self._option.disable_trt_fp16() + logging.warning( + "`RuntimeOption.disable_trt_fp16` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.enable_fp16 = False` instead." 
+ ) + self._option.trt_option.enable_fp16 = False def enable_pinned_memory(self): """Enable pinned memory. Pinned memory can be utilized to speedup the data transfer between CPU and GPU. Currently it's only suppurted in TRT backend and Paddle Inference backend. @@ -482,12 +503,18 @@ class RuntimeOption: def set_trt_max_workspace_size(self, trt_max_workspace_size): """Set max workspace size while using TensorRT backend. """ - return self._option.set_trt_max_workspace_size(trt_max_workspace_size) + logging.warning( + "`RuntimeOption.set_trt_max_workspace_size` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.max_workspace_size = {}` instead.". + format(trt_max_workspace_size)) + self._option.trt_option.max_workspace_size = trt_max_workspace_size def set_trt_max_batch_size(self, trt_max_batch_size): """Set max batch size while using TensorRT backend. """ - return self._option.set_trt_max_batch_size(trt_max_batch_size) + logging.warning( + "`RuntimeOption.set_trt_max_batch_size` will be deprecated in v1.2.0, please use `RuntimeOption.trt_option.max_batch_size = {}` instead.". + format(trt_max_batch_size)) + self._option.trt_option.max_batch_size = trt_max_batch_size def enable_paddle_trt_collect_shape(self): """Enable collect subgraph shape information while using Paddle Inference with TensorRT @@ -558,6 +585,14 @@ class RuntimeOption: """ return self._option.ort_option + @property + def trt_option(self): + """Get TrtBackendOption object to configure TensorRT backend + + :return TrtBackendOption + """ + return self._option.trt_option + def enable_profiling(self, inclue_h2d_d2h=False, repeat=100, warmup=50): """Set the profile mode as 'true'. :param inclue_h2d_d2h Whether to include time of H2D_D2H for time of runtime. diff --git a/serving/docs/EN/model_configuration-en.md b/serving/docs/EN/model_configuration-en.md index 2f9ee14ca..88f72e3b9 100644 --- a/serving/docs/EN/model_configuration-en.md +++ b/serving/docs/EN/model_configuration-en.md @@ -162,7 +162,8 @@ optimization { gpu_execution_accelerator : [ { name : "tensorrt" - # Use FP16 inference in TensorRT. You can also choose: trt_fp32, trt_int8 + # Use FP16 inference in TensorRT. 
You can also choose: trt_fp32 + # If the loaded model is a quantized model, this precision will be int8 automatically parameters { key: "precision" value: "trt_fp16" } } ] @@ -203,4 +204,4 @@ optimization { } ] }} -``` \ No newline at end of file +``` diff --git a/serving/docs/zh_CN/model_configuration.md b/serving/docs/zh_CN/model_configuration.md index 60803121c..03f8e09af 100644 --- a/serving/docs/zh_CN/model_configuration.md +++ b/serving/docs/zh_CN/model_configuration.md @@ -162,7 +162,8 @@ optimization { gpu_execution_accelerator : [ { name : "tensorrt" - # 使用TensorRT的FP16推理,其他可选项为: trt_fp32、trt_int8 + # 使用TensorRT的FP16推理,其他可选项为: trt_fp32 + # 如果加载的是量化模型,此精度设置无效,会默认使用int8进行推理 parameters { key: "precision" value: "trt_fp16" } } ] diff --git a/serving/src/fastdeploy_runtime.cc b/serving/src/fastdeploy_runtime.cc index 79479609c..062a8476b 100644 --- a/serving/src/fastdeploy_runtime.cc +++ b/serving/src/fastdeploy_runtime.cc @@ -168,7 +168,10 @@ TRITONSERVER_Error* ModelState::Create(TRITONBACKEND_Model* triton_model, } ModelState::ModelState(TRITONBACKEND_Model* triton_model) - : BackendModel(triton_model), model_load_(false), main_runtime_(nullptr), is_clone_(true) { + : BackendModel(triton_model), + model_load_(false), + main_runtime_(nullptr), + is_clone_(true) { // Create runtime options that will be cloned and used for each // instance when creating that instance's runtime. runtime_options_.reset(new fastdeploy::RuntimeOption()); @@ -227,14 +230,14 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) ParseBoolValue(value_string, &pd_enable_mkldnn)); runtime_options_->SetPaddleMKLDNN(pd_enable_mkldnn); } else if (param_key == "use_paddle_log") { - runtime_options_->EnablePaddleLogInfo(); + runtime_options_->EnablePaddleLogInfo(); } else if (param_key == "num_streams") { - int num_streams; - THROW_IF_BACKEND_MODEL_ERROR( + int num_streams; + THROW_IF_BACKEND_MODEL_ERROR( ParseIntValue(value_string, &num_streams)); - runtime_options_->SetOpenVINOStreams(num_streams); + runtime_options_->openvino_option.num_streams = num_streams; } else if (param_key == "is_clone") { - THROW_IF_BACKEND_MODEL_ERROR( + THROW_IF_BACKEND_MODEL_ERROR( ParseBoolValue(value_string, &is_clone_)); } else if (param_key == "use_ipu") { // runtime_options_->UseIpu(); @@ -271,11 +274,11 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) std::vector shape; FDParseShape(params, input_name, &shape); if (name == "min_shape") { - runtime_options_->trt_min_shape[input_name] = shape; + runtime_options_->trt_option.min_shape[input_name] = shape; } else if (name == "max_shape") { - runtime_options_->trt_max_shape[input_name] = shape; + runtime_options_->trt_option.max_shape[input_name] = shape; } else { - runtime_options_->trt_opt_shape[input_name] = shape; + runtime_options_->trt_option.opt_shape[input_name] = shape; } } } @@ -292,12 +295,10 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) std::transform(value_string.begin(), value_string.end(), value_string.begin(), ::tolower); if (value_string == "trt_fp16") { - runtime_options_->EnableTrtFP16(); - } else if (value_string == "trt_int8") { - // TODO(liqi): use EnableTrtINT8 - runtime_options_->trt_enable_int8 = true; + runtime_options_->trt_option.enable_fp16 = true; } else if (value_string == "pd_fp16") { - // TODO(liqi): paddle inference don't currently have interface for fp16. + // TODO(liqi): paddle inference don't currently have interface + // for fp16. 
} // } else if( param_key == "max_batch_size") { // THROW_IF_BACKEND_MODEL_ERROR(ParseUnsignedLongLongValue( @@ -307,15 +308,15 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) // value_string, // &runtime_options_->trt_max_workspace_size)); } else if (param_key == "cache_file") { - runtime_options_->SetTrtCacheFile(value_string); + runtime_options_->trt_option.serialize_file = value_string; } else if (param_key == "use_paddle") { runtime_options_->EnablePaddleToTrt(); } else if (param_key == "use_paddle_log") { runtime_options_->EnablePaddleLogInfo(); } else if (param_key == "is_clone") { THROW_IF_BACKEND_MODEL_ERROR( - ParseBoolValue(value_string, &is_clone_)); - } + ParseBoolValue(value_string, &is_clone_)); + } } } } @@ -330,17 +331,17 @@ TRITONSERVER_Error* ModelState::LoadModel( const int32_t instance_group_device_id, std::string* model_path, std::string* params_path, fastdeploy::Runtime** runtime, cudaStream_t stream) { - // FastDeploy Runtime creation is not thread-safe, so multiple creations // are serialized with a global lock. // The Clone interface can be invoked only when the main_runtime_ is created. static std::mutex global_context_mu; std::lock_guard glock(global_context_mu); - if(model_load_ && is_clone_) { - if(main_runtime_ == nullptr) { - return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_NOT_FOUND, - std::string("main_runtime is nullptr").c_str()); + if (model_load_ && is_clone_) { + if (main_runtime_ == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_NOT_FOUND, + std::string("main_runtime is nullptr").c_str()); } *runtime = main_runtime_->Clone((void*)stream, instance_group_device_id); } else { @@ -367,21 +368,21 @@ TRITONSERVER_Error* ModelState::LoadModel( if (not exists) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_NOT_FOUND, - std::string("Paddle params should be named as 'model.pdiparams' or " - "not provided.'") + std::string( + "Paddle params should be named as 'model.pdiparams' or " + "not provided.'") .c_str()); } - runtime_options_->model_format = fastdeploy::ModelFormat::PADDLE; - runtime_options_->model_file = *model_path; - runtime_options_->params_file = *params_path; + runtime_options_->SetModelPath(*model_path, *params_path, + fastdeploy::ModelFormat::PADDLE); } else { - runtime_options_->model_format = fastdeploy::ModelFormat::ONNX; - runtime_options_->model_file = *model_path; + runtime_options_->SetModelPath(*model_path, "", + fastdeploy::ModelFormat::ONNX); } } // GPU - #ifdef TRITON_ENABLE_GPU +#ifdef TRITON_ENABLE_GPU if ((instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_GPU) || (instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_AUTO)) { runtime_options_->UseGpu(instance_group_device_id); @@ -389,17 +390,17 @@ TRITONSERVER_Error* ModelState::LoadModel( } else if (runtime_options_->device != fastdeploy::Device::IPU) { runtime_options_->UseCpu(); } - #else +#else if (runtime_options_->device != fastdeploy::Device::IPU) { // If Device is set to IPU, just skip CPU setting. 
runtime_options_->UseCpu(); } - #endif // TRITON_ENABLE_GPU +#endif // TRITON_ENABLE_GPU *runtime = main_runtime_ = new fastdeploy::Runtime(); if (!(*runtime)->Init(*runtime_options_)) { return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_NOT_FOUND, - std::string("Runtime init error").c_str()); + std::string("Runtime init error").c_str()); } model_load_ = true; } @@ -942,8 +943,8 @@ void ModelInstanceState::ProcessRequests(TRITONBACKEND_Request** requests, if (!all_response_failed) { FD_RESPOND_ALL_AND_SET_TRUE_IF_ERROR(responses, request_count, - all_response_failed, - Run(&responses, request_count)); + all_response_failed, + Run(&responses, request_count)); } uint64_t compute_end_ns = 0; @@ -1067,17 +1068,16 @@ TRITONSERVER_Error* ModelInstanceState::SetInputTensors( allowed_input_types; if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { allowed_input_types = {{TRITONSERVER_MEMORY_GPU, DeviceId()}, - {TRITONSERVER_MEMORY_CPU_PINNED, 0}, - {TRITONSERVER_MEMORY_CPU, 0}}; + {TRITONSERVER_MEMORY_CPU_PINNED, 0}, + {TRITONSERVER_MEMORY_CPU, 0}}; } else { allowed_input_types = {{TRITONSERVER_MEMORY_CPU_PINNED, 0}, - {TRITONSERVER_MEMORY_CPU, 0}}; + {TRITONSERVER_MEMORY_CPU, 0}}; } - RETURN_IF_ERROR( - collector->ProcessTensor( - input_name, nullptr, 0, allowed_input_types, &input_buffer, - &batchn_byte_size, &memory_type, &memory_type_id)); + RETURN_IF_ERROR(collector->ProcessTensor( + input_name, nullptr, 0, allowed_input_types, &input_buffer, + &batchn_byte_size, &memory_type, &memory_type_id)); int32_t device_id = -1; fastdeploy::Device device; @@ -1089,9 +1089,9 @@ TRITONSERVER_Error* ModelInstanceState::SetInputTensors( } fastdeploy::FDTensor fdtensor(in_name); - fdtensor.SetExternalData( - batchn_shape, ConvertDataTypeToFD(input_datatype), - const_cast(input_buffer), device, device_id); + fdtensor.SetExternalData(batchn_shape, ConvertDataTypeToFD(input_datatype), + const_cast(input_buffer), device, + device_id); runtime_->BindInputTensor(in_name, fdtensor); } @@ -1130,23 +1130,22 @@ TRITONSERVER_Error* ModelInstanceState::ReadOutputTensors( for (auto& output_name : output_names_) { auto* output_tensor = runtime_->GetOutputTensor(output_name); if (output_tensor == nullptr) { - RETURN_IF_ERROR( - TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - (std::string("output tensor '") + output_name + "' is not found") - .c_str())); + RETURN_IF_ERROR(TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string("output tensor '") + output_name + "' is not found") + .c_str())); } TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; int64_t memory_type_id = 0; - if(output_tensor->device == fastdeploy::Device::GPU) { + if (output_tensor->device == fastdeploy::Device::GPU) { memory_type = TRITONSERVER_MEMORY_GPU; memory_type_id = DeviceId(); } responder.ProcessTensor( output_tensor->name, ConvertFDType(output_tensor->dtype), output_tensor->shape, - reinterpret_cast(output_tensor->MutableData()), - memory_type, memory_type_id); + reinterpret_cast(output_tensor->MutableData()), memory_type, + memory_type_id); } // Finalize and wait for any pending buffer copies.
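
Taken together, these changes move per-backend settings off `RuntimeOption` helper methods and onto the backend option members (`ort_option`, `trt_option`, `paddle_lite_option`, `openvino_option`). Below is a minimal Python sketch of the new-style configuration based on the bindings added in this patch; the tensor name `"x"`, the shapes, and the cache path are illustrative placeholders, and the usual `import fastdeploy as fd` entry point is assumed.

```python
import fastdeploy as fd

option = fd.RuntimeOption()
option.use_trt_backend()

# Old style (now deprecated): option.enable_trt_fp16(), option.set_trt_cache_file(...),
# option.set_trt_input_shape(...). New style: set the TrtBackendOption members directly.
option.trt_option.enable_fp16 = True
option.trt_option.max_workspace_size = 1 << 30
option.trt_option.serialize_file = "./tensorrt_cache/model.trt"  # placeholder path
option.trt_option.set_shape("x", [1, 3, 224, 224], [4, 3, 224, 224], [8, 3, 224, 224])

# The other backends follow the same pattern, e.g.:
# option.ort_option.graph_optimize_level = 99    # instead of set_ort_graph_opt_level(99)
# option.paddle_lite_option.enable_fp16 = True   # instead of enable_lite_fp16()
```

The deprecated setters keep working until v1.2.0: they now forward to these members and print an FDWARNING, so existing code continues to run during the migration.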