diff --git a/fastdeploy/backends/backend.h b/fastdeploy/backends/backend.h
index 620aea9f4..652d94cb8 100644
--- a/fastdeploy/backends/backend.h
+++ b/fastdeploy/backends/backend.h
@@ -21,6 +21,7 @@
 
 #include "fastdeploy/backends/common/multiclass_nms.h"
 #include "fastdeploy/core/fd_tensor.h"
+#include "fastdeploy/core/fd_type.h"
 
 namespace fastdeploy {
 
@@ -63,6 +64,11 @@ class BaseBackend {
   virtual std::vector<TensorInfo> GetOutputInfos() = 0;
   virtual bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs) = 0;
+  virtual std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
+                                             int device_id = -1) {
+    FDERROR << "Clone is not supported by this backend." << std::endl;
+    return nullptr;
+  }
 };
 
 }  // namespace fastdeploy
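The hunk above gives every backend a virtual Clone() whose default simply reports failure. A minimal caller-side sketch of that contract (hedged: `backend` stands for any already-initialized concrete backend such as the OpenVINO/Paddle/TensorRT backends changed below; the stream argument only matters for GPU backends):

```cpp
// Sketch of how the new Clone() hook is meant to be consumed (illustrative only).
#include <memory>
#include "fastdeploy/backends/backend.h"

std::unique_ptr<fastdeploy::BaseBackend> MakeWorkerCopy(
    fastdeploy::BaseBackend* backend, void* stream, int device_id) {
  // A cloning backend returns a new instance that shares its engine/weights;
  // the default BaseBackend::Clone logs an error and returns nullptr.
  std::unique_ptr<fastdeploy::BaseBackend> copy = backend->Clone(stream, device_id);
  if (!copy) {
    // Fall back to building a fresh backend (full memory cost) when cloning
    // is not supported.
  }
  return copy;
}
```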
diff --git a/fastdeploy/backends/openvino/ov_backend.cc b/fastdeploy/backends/openvino/ov_backend.cc
index f205b48e2..5a664fc87 100644
--- a/fastdeploy/backends/openvino/ov_backend.cc
+++ b/fastdeploy/backends/openvino/ov_backend.cc
@@ -74,6 +74,8 @@ ov::element::Type FDDataTypeToOV(const FDDataType& type) {
   return ov::element::f32;
 }
 
+ov::Core OpenVINOBackend::core_;
+
 void OpenVINOBackend::InitTensorInfo(
     const std::vector<ov::Output<ov::Node>>& ov_outputs,
     std::map<std::string, TensorInfo>* tensor_infos) {
@@ -96,10 +98,6 @@ bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
     return false;
   }
   option_ = option;
-  ov::AnyMap properties;
-  if (option_.cpu_thread_num > 0) {
-    properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
-  }
 
   std::shared_ptr<ov::Model> model = core_.read_model(model_file, params_file);
 
@@ -149,7 +147,19 @@ bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
     output_infos_.push_back(iter->second);
   }
 
+  ov::AnyMap properties;
+  if (option_.cpu_thread_num > 0) {
+    properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
+  }
+  if (option_.ov_num_streams == -1) {
+    properties["NUM_STREAMS"] = ov::streams::AUTO;
+  } else if (option_.ov_num_streams == -2) {
+    properties["NUM_STREAMS"] = ov::streams::NUMA;
+  } else if (option_.ov_num_streams > 0) {
+    properties["NUM_STREAMS"] = option_.ov_num_streams;
+  }
   compiled_model_ = core_.compile_model(model, "CPU", properties);
+  request_ = compiled_model_.create_infer_request();
   initialized_ = true;
   return true;
@@ -185,10 +195,6 @@ bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,
     return false;
   }
   option_ = option;
-  ov::AnyMap properties;
-  if (option_.cpu_thread_num > 0) {
-    properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
-  }
 
   std::shared_ptr<ov::Model> model = core_.read_model(model_file);
 
@@ -238,8 +244,21 @@ bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,
     output_infos_.push_back(iter->second);
   }
 
+  ov::AnyMap properties;
+  if (option_.cpu_thread_num > 0) {
+    properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
+  }
+  if (option_.ov_num_streams == -1) {
+    properties["NUM_STREAMS"] = ov::streams::AUTO;
+  } else if (option_.ov_num_streams == -2) {
+    properties["NUM_STREAMS"] = ov::streams::NUMA;
+  } else if (option_.ov_num_streams > 0) {
+    properties["NUM_STREAMS"] = option_.ov_num_streams;
+  }
   compiled_model_ = core_.compile_model(model, "CPU", properties);
+  request_ = compiled_model_.create_infer_request();
+
   initialized_ = true;
   return true;
 }
@@ -281,4 +300,14 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
   return true;
 }
 
+std::unique_ptr<BaseBackend> OpenVINOBackend::Clone(void *stream, int device_id) {
+  std::unique_ptr<BaseBackend> new_backend = utils::make_unique<OpenVINOBackend>();
+  auto casted_backend = dynamic_cast<OpenVINOBackend*>(new_backend.get());
+  casted_backend->option_ = option_;
+  casted_backend->request_ = compiled_model_.create_infer_request();
+  casted_backend->input_infos_.assign(input_infos_.begin(), input_infos_.end());
+  casted_backend->output_infos_.assign(output_infos_.begin(), output_infos_.end());
+  return new_backend;
+}
+
 }  // namespace fastdeploy
diff --git a/fastdeploy/backends/openvino/ov_backend.h b/fastdeploy/backends/openvino/ov_backend.h
index 5dd362d52..b7d77e58f 100644
--- a/fastdeploy/backends/openvino/ov_backend.h
+++ b/fastdeploy/backends/openvino/ov_backend.h
@@ -20,17 +20,20 @@ #include
 
 #include "fastdeploy/backends/backend.h"
+#include "fastdeploy/utils/unique_ptr.h"
 #include "openvino/openvino.hpp"
 
 namespace fastdeploy {
 
 struct OpenVINOBackendOption {
-  int cpu_thread_num = 8;
+  int cpu_thread_num = -1;
+  int ov_num_streams = 1;
   std::map<std::string, std::vector<int64_t>> shape_infos;
 };
 
 class OpenVINOBackend : public BaseBackend {
  public:
+  static ov::Core core_;
   OpenVINOBackend() {}
   virtual ~OpenVINOBackend() = default;
 
@@ -54,10 +57,13 @@ class OpenVINOBackend : public BaseBackend {
   std::vector<TensorInfo> GetInputInfos() override;
   std::vector<TensorInfo> GetOutputInfos() override;
 
+  std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
+                                     int device_id = -1) override;
+
  private:
   void InitTensorInfo(const std::vector<ov::Output<ov::Node>>& ov_outputs,
                       std::map<std::string, TensorInfo>* tensor_infos);
-  ov::Core core_;
+
   ov::CompiledModel compiled_model_;
   ov::InferRequest request_;
   OpenVINOBackendOption option_;
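As a reading aid, the ov_num_streams values introduced above map onto OpenVINO's NUM_STREAMS property as follows (a standalone sketch extracted from the two InitFrom* hunks; it does nothing beyond what the patch itself does):

```cpp
// Sketch of the property mapping used in InitFromPaddle / InitFromOnnx.
#include "openvino/openvino.hpp"

ov::AnyMap BuildOpenVINOProperties(int cpu_thread_num, int ov_num_streams) {
  ov::AnyMap properties;
  if (cpu_thread_num > 0) {
    properties["INFERENCE_NUM_THREADS"] = cpu_thread_num;
  }
  if (ov_num_streams == -1) {
    properties["NUM_STREAMS"] = ov::streams::AUTO;  // let OpenVINO decide
  } else if (ov_num_streams == -2) {
    properties["NUM_STREAMS"] = ov::streams::NUMA;  // one stream per NUMA node
  } else if (ov_num_streams > 0) {
    properties["NUM_STREAMS"] = ov_num_streams;     // explicit stream count
  }
  return properties;  // then passed to core_.compile_model(model, "CPU", properties)
}
```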
diff --git a/fastdeploy/backends/paddle/paddle_backend.cc b/fastdeploy/backends/paddle/paddle_backend.cc
index 61e5fb414..70d8305c5 100644
--- a/fastdeploy/backends/paddle/paddle_backend.cc
+++ b/fastdeploy/backends/paddle/paddle_backend.cc
@@ -216,6 +216,30 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
   return true;
 }
 
+std::unique_ptr<BaseBackend> PaddleBackend::Clone(void *stream, int device_id) {
+  std::unique_ptr<BaseBackend> new_backend = utils::make_unique<PaddleBackend>();
+  auto casted_backend = dynamic_cast<PaddleBackend*>(new_backend.get());
+  if (device_id > 0 && option_.use_gpu == true && device_id != option_.gpu_id) {
+    auto clone_option = option_;
+    clone_option.gpu_id = device_id;
+    clone_option.external_stream_ = stream;
+    casted_backend->InitFromPaddle(clone_option.model_file,
+                                   clone_option.params_file,
+                                   clone_option);
+    FDWARNING << "The target device id:"
+              << device_id
+              << " is different from current device id:"
+              << option_.gpu_id
+              << ", cannot share memory with current engine."
+              << std::endl;
+    return new_backend;
+  }
+  casted_backend->inputs_desc_.assign(inputs_desc_.begin(), inputs_desc_.end());
+  casted_backend->outputs_desc_.assign(outputs_desc_.begin(), outputs_desc_.end());
+  casted_backend->predictor_ = std::move(predictor_->Clone(stream));
+  return new_backend;
+}
+
 #ifdef ENABLE_TRT_BACKEND
 void PaddleBackend::SetTRTDynamicShapeToConfig(const PaddleBackendOption& option) {
   std::map<std::string, std::vector<int>> max_shape;
diff --git a/fastdeploy/backends/paddle/paddle_backend.h b/fastdeploy/backends/paddle/paddle_backend.h
index 43f8e67e6..0c674494e 100755
--- a/fastdeploy/backends/paddle/paddle_backend.h
+++ b/fastdeploy/backends/paddle/paddle_backend.h
@@ -24,6 +24,7 @@
 #include "paddle2onnx/converter.h"
 #endif
 #include "paddle_inference_api.h"  // NOLINT
+#include "fastdeploy/utils/unique_ptr.h"
 
 #ifdef ENABLE_TRT_BACKEND
 #include "fastdeploy/backends/tensorrt/trt_backend.h"
@@ -43,6 +44,9 @@ struct IpuOption {
 };
 
 struct PaddleBackendOption {
+  std::string model_file = "";   // Path of model file
+  std::string params_file = "";  // Path of parameters file, can be empty
+
 #ifdef WITH_GPU
   bool use_gpu = true;
 #else
@@ -110,6 +114,9 @@ class PaddleBackend : public BaseBackend {
   int NumOutputs() const override { return outputs_desc_.size(); }
 
+  std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
+                                     int device_id = -1) override;
+
   TensorInfo GetInputInfo(int index) override;
   TensorInfo GetOutputInfo(int index) override;
   std::vector<TensorInfo> GetInputInfos() override;
diff --git a/fastdeploy/backends/tensorrt/trt_backend.cc b/fastdeploy/backends/tensorrt/trt_backend.cc
index ba6c32951..2306cb239 100755
--- a/fastdeploy/backends/tensorrt/trt_backend.cc
+++ b/fastdeploy/backends/tensorrt/trt_backend.cc
@@ -285,6 +285,7 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
     BuildTrtEngine();
   }
 
+  cudaSetDevice(option_.gpu_id);
   SetInputs(inputs);
   AllocateOutputsBuffer(outputs);
 
@@ -356,13 +357,17 @@ void TrtBackend::GetInputOutputInfo() {
       outputs_device_buffer_[name] = FDDeviceBuffer(dtype);
       casted_output_tensors_[name] = FDTensor();
     }
+    io_name_index_[name] = i;
   }
   bindings_.resize(num_binds);
 }
 
 void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
   for (const auto& item : inputs) {
-    auto idx = engine_->getBindingIndex(item.name.c_str());
+    // auto idx = engine_->getBindingIndex(item.name.c_str());
+    auto iter = io_name_index_.find(item.name);
+    FDASSERT(iter != io_name_index_.end(), "TRTBackend SetInputs cannot find name:%s", item.name.c_str());
+    auto idx = iter->second;
     std::vector<int64_t> shape(item.shape.begin(), item.shape.end());
     auto dims = ToDims(shape);
     context_->setBindingDimensions(idx, dims);
@@ -410,7 +415,10 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
     outputs->resize(outputs_desc_.size());
   }
   for (size_t i = 0; i < outputs_desc_.size(); ++i) {
-    auto idx = engine_->getBindingIndex(outputs_desc_[i].name.c_str());
+    // auto idx = engine_->getBindingIndex(outputs_desc_[i].name.c_str());
+    auto idx_iter = io_name_index_.find(outputs_desc_[i].name);
+    FDASSERT(idx_iter != io_name_index_.end(), "TRTBackend Outputs cannot find name:%s", outputs_desc_[i].name.c_str());
+    auto idx = idx_iter->second;
     auto output_dims = context_->getBindingDimensions(idx);
 
     // find the original index of output
@@ -673,4 +681,47 @@ std::vector<TensorInfo> TrtBackend::GetOutputInfos() {
   return infos;
 }
 
+std::unique_ptr<BaseBackend> TrtBackend::Clone(void *stream, int device_id) {
+  std::unique_ptr<BaseBackend> new_backend = utils::make_unique<TrtBackend>();
+  auto casted_backend = dynamic_cast<TrtBackend*>(new_backend.get());
+  if (device_id > 0 && device_id != option_.gpu_id) {
+    auto clone_option = option_;
+    clone_option.gpu_id = device_id;
+    clone_option.external_stream_ = stream;
+    if (option_.model_format == ModelFormat::ONNX) {
+      FDASSERT(casted_backend->InitFromOnnx(option_.model_file, clone_option),
+               "Clone model from ONNX failed while initializing TrtBackend.");
+    } else {
+      FDASSERT(casted_backend->InitFromPaddle(option_.model_file,
+                                              option_.params_file, clone_option),
+               "Clone model from Paddle failed while initializing TrtBackend.");
+    }
+    FDWARNING << "The target device id:"
+              << device_id
+              << " is different from current device id:"
+              << option_.gpu_id
+              << ", cannot share memory with current engine."
+              << std::endl;
+    return new_backend;
+  }
+  cudaSetDevice(option_.gpu_id);
+  casted_backend->option_.gpu_id = option_.gpu_id;
+  if (stream) {
+    casted_backend->stream_ = reinterpret_cast<cudaStream_t>(stream);
+  } else {
+    FDASSERT(cudaStreamCreate(&casted_backend->stream_) == 0,
+             "[ERROR] Error occurs while calling cudaStreamCreate() in Clone.");
+  }
+  casted_backend->inputs_desc_.assign(inputs_desc_.begin(), inputs_desc_.end());
+  casted_backend->outputs_desc_.assign(outputs_desc_.begin(), outputs_desc_.end());
+  casted_backend->outputs_order_.insert(outputs_order_.begin(), outputs_order_.end());
+  casted_backend->shape_range_info_.insert(shape_range_info_.begin(), shape_range_info_.end());
+  casted_backend->engine_ = engine_;
+  casted_backend->context_ = std::shared_ptr<nvinfer1::IExecutionContext>(
+      casted_backend->engine_->createExecutionContext());
+  casted_backend->GetInputOutputInfo();
+  FDINFO << "TRTBackend clone finished." << std::endl;
+  return new_backend;
+}
+
 }  // namespace fastdeploy
diff --git a/fastdeploy/backends/tensorrt/trt_backend.h b/fastdeploy/backends/tensorrt/trt_backend.h
index cb107af49..7ef931f90 100755
--- a/fastdeploy/backends/tensorrt/trt_backend.h
+++ b/fastdeploy/backends/tensorrt/trt_backend.h
@@ -25,6 +25,7 @@
 #include "NvOnnxParser.h"
 #include "fastdeploy/backends/backend.h"
 #include "fastdeploy/backends/tensorrt/utils.h"
+#include "fastdeploy/utils/unique_ptr.h"
 
 class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
  public:
@@ -45,7 +46,7 @@ class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
 
   void writeCalibrationCache(const void* cache,
                              size_t length) noexcept override {
-    std::cout << "NOT IMPLEMENT." << std::endl;
+    fastdeploy::FDERROR << "NOT IMPLEMENTED." << std::endl;
   }
 
  private:
@@ -62,6 +63,11 @@ struct TrtValueInfo {
 };
 
 struct TrtBackendOption {
+  std::string model_file = "";   // Path of model file
+  std::string params_file = "";  // Path of parameters file, can be empty
+  // format of input model
+  ModelFormat model_format = ModelFormat::AUTOREC;
+
   int gpu_id = 0;
   bool enable_fp16 = false;
   bool enable_int8 = false;
@@ -99,6 +105,8 @@ class TrtBackend : public BaseBackend {
   TensorInfo GetOutputInfo(int index);
   std::vector<TensorInfo> GetInputInfos() override;
   std::vector<TensorInfo> GetOutputInfos() override;
+  std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
+                                     int device_id = -1) override;
 
   ~TrtBackend() {
     if (parser_) {
@@ -119,6 +127,7 @@ class TrtBackend : public BaseBackend {
   std::vector<TrtValueInfo> outputs_desc_;
   std::map<std::string, FDDeviceBuffer> inputs_device_buffer_;
   std::map<std::string, FDDeviceBuffer> outputs_device_buffer_;
+  std::map<std::string, int> io_name_index_;
 
   std::string calibration_str_;
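The TensorRT changes above cache binding indices by name and let Clone() reuse the already-built engine: each clone gets its own IExecutionContext and CUDA stream while the weights stay in one shared ICudaEngine. A minimal sketch of that ownership layout (hedged; error handling omitted, the engine pointer type matches the shared engine_ member used above):

```cpp
// Sketch: one shared engine (weights), one context + stream per clone.
#include <memory>
#include <cuda_runtime_api.h>
#include "NvInfer.h"

struct CloneState {
  std::shared_ptr<nvinfer1::ICudaEngine> engine;         // shared with the parent
  std::shared_ptr<nvinfer1::IExecutionContext> context;  // private to this clone
  cudaStream_t stream = nullptr;                         // private to this clone
};

CloneState MakeCloneState(std::shared_ptr<nvinfer1::ICudaEngine> engine,
                          cudaStream_t external_stream) {
  CloneState s;
  s.engine = engine;  // only bumps the reference count, no extra GPU copy
  s.context.reset(engine->createExecutionContext());
  if (external_stream != nullptr) {
    s.stream = external_stream;   // caller-provided stream
  } else {
    cudaStreamCreate(&s.stream);  // otherwise create a dedicated one
  }
  return s;
}
```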
diff --git a/fastdeploy/core/fd_type.cc b/fastdeploy/core/fd_type.cc
index 45ca90a1b..5712bb278 100644
--- a/fastdeploy/core/fd_type.cc
+++ b/fastdeploy/core/fd_type.cc
@@ -182,4 +182,32 @@ const FDDataType TypeToDataType<uint8_t>::dtype = UINT8;
 template <>
 const FDDataType TypeToDataType<int8_t>::dtype = INT8;
 
+std::string Str(const ModelFormat& f) {
+  if (f == ModelFormat::PADDLE) {
+    return "ModelFormat::PADDLE";
+  } else if (f == ModelFormat::ONNX) {
+    return "ModelFormat::ONNX";
+  } else if (f == ModelFormat::RKNN) {
+    return "ModelFormat::RKNN";
+  } else if (f == ModelFormat::TORCHSCRIPT) {
+    return "ModelFormat::TORCHSCRIPT";
+  }
+  return "UNKNOWN-ModelFormat";
+}
+
+std::ostream& operator<<(std::ostream& out, const ModelFormat& format) {
+  if (format == ModelFormat::PADDLE) {
+    out << "ModelFormat::PADDLE";
+  } else if (format == ModelFormat::ONNX) {
+    out << "ModelFormat::ONNX";
+  } else if (format == ModelFormat::RKNN) {
+    out << "ModelFormat::RKNN";
+  } else if (format == ModelFormat::TORCHSCRIPT) {
+    out << "ModelFormat::TORCHSCRIPT";
+  } else {
+    out << "UNKNOWN-ModelFormat";
+  }
+  return out;
+}
+
 }  // namespace fastdeploy
diff --git a/fastdeploy/core/fd_type.h b/fastdeploy/core/fd_type.h
index 5236601b0..131de20d4 100644
--- a/fastdeploy/core/fd_type.h
+++ b/fastdeploy/core/fd_type.h
@@ -65,4 +65,16 @@ struct FASTDEPLOY_DECL TypeToDataType {
   static const FDDataType dtype;
 };
 
+/*! Deep learning model format */
+enum ModelFormat {
+  AUTOREC,      ///< Auto recognize the model format by model file name
+  PADDLE,       ///< Model with paddlepaddle format
+  ONNX,         ///< Model with ONNX format
+  RKNN,         ///< Model with RKNN format
+  TORCHSCRIPT,  ///< Model with TorchScript format
+};
+
+FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out,
+                                         const ModelFormat& format);
+
 }  // namespace fastdeploy
diff --git a/fastdeploy/runtime.cc b/fastdeploy/runtime.cc
index 0a9dff535..94ea9de0b 100755
--- a/fastdeploy/runtime.cc
+++ b/fastdeploy/runtime.cc
@@ -102,19 +102,6 @@ std::string Str(const Backend& b) {
   return "UNKNOWN-Backend";
 }
 
-std::string Str(const ModelFormat& f) {
-  if (f == ModelFormat::PADDLE) {
-    return "ModelFormat::PADDLE";
-  } else if (f == ModelFormat::ONNX) {
-    return "ModelFormat::ONNX";
-  }else if (f == ModelFormat::RKNN) {
-    return "ModelFormat::RKNN";
-  } else if (f == ModelFormat::TORCHSCRIPT) {
-    return "ModelFormat::TORCHSCRIPT";
-  }
-  return "UNKNOWN-ModelFormat";
-}
-
 std::ostream& operator<<(std::ostream& out, const Backend& backend) {
   if (backend == Backend::ORT) {
     out << "Backend::ORT";
@@ -135,20 +122,6 @@ std::ostream& operator<<(std::ostream& out, const Backend& backend) {
   return out;
 }
 
-std::ostream& operator<<(std::ostream& out, const ModelFormat& format) {
-  if (format == ModelFormat::PADDLE) {
-    out << "ModelFormat::PADDLE";
-  } else if (format == ModelFormat::ONNX) {
-    out << "ModelFormat::ONNX";
-  } else if (format == ModelFormat::RKNN) {
-    out << "ModelFormat::RKNN";
-  } else if (format == ModelFormat::TORCHSCRIPT) {
-    out << "ModelFormat::TORCHSCRIPT";
-  }
-  out << "UNKNOWN-ModelFormat";
-  return out;
-}
-
 bool CheckModelFormat(const std::string& model_file,
                       const ModelFormat& model_format) {
   if (model_format == ModelFormat::PADDLE) {
@@ -411,6 +384,10 @@ void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) {
   trt_serialize_file = cache_file_path;
 }
 
+void RuntimeOption::SetOpenVINOStreams(int num_streams) {
+  ov_num_streams = num_streams;
+}
+
 bool Runtime::Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors,
                       const RuntimeOption& _option) {
 #ifdef ENABLE_POROS_BACKEND
@@ -582,6 +559,8 @@ bool Runtime::Infer(std::vector<FDTensor>& input_tensors,
 void Runtime::CreatePaddleBackend() {
 #ifdef ENABLE_PADDLE_BACKEND
   auto pd_option = PaddleBackendOption();
+  pd_option.model_file = option.model_file;
+  pd_option.params_file = option.params_file;
   pd_option.enable_mkldnn = option.pd_enable_mkldnn;
   pd_option.enable_log_info = option.pd_enable_log_info;
   pd_option.mkldnn_cache_size = option.pd_mkldnn_cache_size;
@@ -642,6 +621,7 @@ void Runtime::CreateOpenVINOBackend() {
 #ifdef ENABLE_OPENVINO_BACKEND
   auto ov_option = OpenVINOBackendOption();
   ov_option.cpu_thread_num = option.cpu_thread_num;
+  ov_option.ov_num_streams = option.ov_num_streams;
   FDASSERT(option.model_format == ModelFormat::PADDLE ||
                option.model_format == ModelFormat::ONNX,
            "OpenVINOBackend only support model format of ModelFormat::PADDLE / "
@@ -699,6 +679,9 @@ void Runtime::CreateOrtBackend() {
 void Runtime::CreateTrtBackend() {
 #ifdef ENABLE_TRT_BACKEND
   auto trt_option = TrtBackendOption();
+  trt_option.model_file = option.model_file;
+  trt_option.params_file = option.params_file;
+  trt_option.model_format = option.model_format;
   trt_option.gpu_id = option.device_id;
   trt_option.enable_fp16 = option.trt_enable_fp16;
   trt_option.enable_int8 = option.trt_enable_int8;
@@ -771,4 +754,26 @@ void Runtime::CreateRKNPU2Backend() {
 #endif
 }
 
+Runtime* Runtime::Clone(void* stream, int device_id) {
+  Runtime* runtime = new Runtime();
+  if (option.backend != Backend::OPENVINO
+      && option.backend != Backend::PDINFER
+      && option.backend != Backend::TRT
+      ) {
+    runtime->Init(option);
+    FDWARNING << "Only OpenVINO/Paddle Inference/TensorRT support \
+                  clone engine to reduce CPU/GPU memory usage now. For "
+              << option.backend
+              << ", FastDeploy will create a new engine which \
+                  will not share memory with the current runtime."
+              << std::endl;
+    return runtime;
+  }
+  FDINFO << "Runtime Clone with Backend: " << Str(option.backend) << " in " << Str(option.device)
+         << "." << std::endl;
+  runtime->option = option;
+  runtime->backend_ = backend_->Clone(stream, device_id);
+  return runtime;
+}
+
 }  // namespace fastdeploy
diff --git a/fastdeploy/runtime.h b/fastdeploy/runtime.h
index e50e262c2..7ab6f1fb2 100644
--- a/fastdeploy/runtime.h
+++ b/fastdeploy/runtime.h
@@ -35,38 +35,27 @@ namespace fastdeploy {
 
 /*! Inference backend supported in FastDeploy */
 enum Backend {
-  UNKNOWN,   ///< Unknown inference backend
+  UNKNOWN,  ///< Unknown inference backend
   ORT,  ///< ONNX Runtime, support Paddle/ONNX format model, CPU / Nvidia GPU
-  TRT,       ///< TensorRT, support Paddle/ONNX format model, Nvidia GPU only
-  PDINFER,   ///< Paddle Inference, support Paddle format model, CPU / Nvidia GPU
-  POROS,     ///< Poros, support TorchScript format model, CPU / Nvidia GPU
-  OPENVINO,  ///< Intel OpenVINO, support Paddle/ONNX format, CPU only
+  TRT,  ///< TensorRT, support Paddle/ONNX format model, Nvidia GPU only
+  PDINFER,  ///< Paddle Inference, support Paddle format model, CPU / Nvidia GPU
+  POROS,  ///< Poros, support TorchScript format model, CPU / Nvidia GPU
+  OPENVINO,  ///< Intel OpenVINO, support Paddle/ONNX format, CPU only
   LITE,  ///< Paddle Lite, support Paddle format model, ARM CPU only
   RKNPU2,  ///< RKNPU2, support RKNN format model, Rockchip NPU only
 };
 
-/*! Deep learning model format */
-enum ModelFormat {
-  AUTOREC,      ///< Auto recognize the model format by model file name
-  PADDLE,       ///< Model with paddlepaddle format
-  ONNX,         ///< Model with ONNX format
-  RKNN,         ///< Model with RKNN format
-  TORCHSCRIPT,  ///< Model with TorchScript format
-};
-
 FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out,
                                          const Backend& backend);
-FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out,
-                                         const ModelFormat& format);
 
 /*! Paddle Lite power mode for mobile device. */
 enum LitePowerMode {
-  LITE_POWER_HIGH = 0,       ///< Use Lite Backend with high power mode
-  LITE_POWER_LOW = 1,        ///< Use Lite Backend with low power mode
-  LITE_POWER_FULL = 2,       ///< Use Lite Backend with full power mode
-  LITE_POWER_NO_BIND = 3,    ///< Use Lite Backend with no bind power mode
-  LITE_POWER_RAND_HIGH = 4,  ///< Use Lite Backend with rand high mode
-  LITE_POWER_RAND_LOW = 5    ///< Use Lite Backend with rand low power mode
+  LITE_POWER_HIGH = 0,  ///< Use Lite Backend with high power mode
+  LITE_POWER_LOW = 1,  ///< Use Lite Backend with low power mode
+  LITE_POWER_FULL = 2,  ///< Use Lite Backend with full power mode
+  LITE_POWER_NO_BIND = 3,  ///< Use Lite Backend with no bind power mode
+  LITE_POWER_RAND_HIGH = 4,  ///< Use Lite Backend with rand high mode
+  LITE_POWER_RAND_LOW = 5  ///< Use Lite Backend with rand low power mode
 };
 
 FASTDEPLOY_DECL std::string Str(const Backend& b);
@@ -105,8 +94,10 @@ struct FASTDEPLOY_DECL RuntimeOption {
   /// Use Nvidia GPU to inference
   void UseGpu(int gpu_id = 0);
 
-  void UseRKNPU2(fastdeploy::rknpu2::CpuName rknpu2_name = fastdeploy::rknpu2::CpuName::RK3588,
-                 fastdeploy::rknpu2::CoreMask rknpu2_core = fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_0);
+  void UseRKNPU2(fastdeploy::rknpu2::CpuName rknpu2_name
+                     = fastdeploy::rknpu2::CpuName::RK3588,
+                 fastdeploy::rknpu2::CoreMask rknpu2_core
+                     = fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_0);
 
   void SetExternalStream(void* external_stream);
 
@@ -242,6 +233,11 @@
    */
   void DisablePaddleTrtCollectShape();
 
+  /**
+   * \brief Set the number of OpenVINO execution streams used for inference
+   */
+  void SetOpenVINOStreams(int num_streams);
+
   /** \Use Graphcore IPU to inference.
    *
    * \param[in] device_num the number of IPUs.
@@ -331,13 +327,19 @@ struct FASTDEPLOY_DECL RuntimeOption {
   int unconst_ops_thres = -1;
   std::string poros_file = "";
 
+  // ======Only for OpenVINO Backend=======
+  int ov_num_streams = 1;
+
   // ======Only for RKNPU2 Backend=======
-  fastdeploy::rknpu2::CpuName rknpu2_cpu_name_ = fastdeploy::rknpu2::CpuName::RK3588;
-  fastdeploy::rknpu2::CoreMask rknpu2_core_mask_ = fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_AUTO;
+  fastdeploy::rknpu2::CpuName rknpu2_cpu_name_
+      = fastdeploy::rknpu2::CpuName::RK3588;
+  fastdeploy::rknpu2::CoreMask rknpu2_core_mask_
+      = fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_AUTO;
 
   std::string model_file = "";  // Path of model file
-  std::string params_file = ""; // Path of parameters file, can be empty
-  ModelFormat model_format = ModelFormat::AUTOREC;  // format of input model
+  std::string params_file = "";  // Path of parameters file, can be empty
+  // format of input model
+  ModelFormat model_format = ModelFormat::AUTOREC;
 };
 
 /*! @brief Runtime object used to inference the loaded model on different devices
@@ -384,6 +386,15 @@ struct FASTDEPLOY_DECL Runtime {
    */
   std::vector<TensorInfo> GetOutputInfos();
 
+  /** \brief Clone a new Runtime when multiple instances of the same model are created
+   *
+   * \param[in] stream CUDA Stream, default is nullptr
+   * \param[in] device_id GPU device id for the new Runtime, -1 means the same device as the current one
+   * \return the new Runtime* created by this clone
+   */
+  Runtime* Clone(void* stream = nullptr,
+                 int device_id = -1);
+
   RuntimeOption option;
 
  private:
@@ -395,4 +405,4 @@ struct FASTDEPLOY_DECL Runtime {
   void CreateRKNPU2Backend();
   std::unique_ptr<BaseBackend> backend_;
 };
-} // namespace fastdeploy
+}  // namespace fastdeploy
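With the declaration above in place, a typical multi-instance deployment initializes one Runtime and clones it for every additional worker instead of calling Init() repeatedly. A usage sketch (hedged: the model paths are placeholders; the option setters shown are the existing RuntimeOption APIs):

```cpp
// Usage sketch for Runtime::Clone (file names below are placeholders).
#include "fastdeploy/runtime.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.SetModelPath("model.pdmodel", "model.pdiparams");  // placeholder files
  option.UseCpu();
  option.UseOpenVINOBackend();
  option.SetCpuThreadNum(4);     // total threads shared by all instances
  option.SetOpenVINOStreams(2);  // typically equal to the number of instances

  fastdeploy::Runtime main_runtime;
  if (!main_runtime.Init(option)) return -1;

  // The clone shares the compiled model with main_runtime, so an extra
  // instance costs far less memory than a second Init().
  fastdeploy::Runtime* worker = main_runtime.Clone();
  // ... call worker->Infer(inputs, &outputs) from another thread ...
  delete worker;
  return 0;
}
```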
diff --git a/serving/docs/zh_CN/model_configuration.md b/serving/docs/zh_CN/model_configuration.md
index 7a19aa8fa..ce3abc075 100644
--- a/serving/docs/zh_CN/model_configuration.md
+++ b/serving/docs/zh_CN/model_configuration.md
@@ -142,8 +142,10 @@ optimization {
   cpu_execution_accelerator : [
     {
       name : "openvino"
-      # Set the number of inference threads to 4
+      # Set the number of inference threads to 4 (total shared by all instances)
       parameters { key: "cpu_threads" value: "4" }
+      # Set OpenVINO num_streams (usually equal to the number of instances)
+      parameters { key: "num_streams" value: "1" }
     }
   ]
 }
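The serving backend below additionally reads an is_clone switch from the same parameters block, so a complete accelerator entry could look roughly like this (a hedged sketch; values are illustrative and the keys follow what fastdeploy_runtime.cc parses):

```
cpu_execution_accelerator : [
  {
    name : "openvino"
    # total CPU threads shared by every instance of this model
    parameters { key: "cpu_threads" value: "4" }
    # usually set equal to the number of model instances
    parameters { key: "num_streams" value: "1" }
    # let additional instances Clone() the first runtime instead of re-initializing
    parameters { key: "is_clone" value: "true" }
  }
]
```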
diff --git a/serving/src/fastdeploy_runtime.cc b/serving/src/fastdeploy_runtime.cc
index 32d0127e0..2e839b5ac 100644
--- a/serving/src/fastdeploy_runtime.cc
+++ b/serving/src/fastdeploy_runtime.cc
@@ -91,6 +91,9 @@ class ModelState : public BackendModel {
   // Runtime options used when creating a FastDeploy Runtime.
   std::unique_ptr<fastdeploy::RuntimeOption> runtime_options_;
+  bool model_load_;
+  fastdeploy::Runtime* main_runtime_;
+  bool is_clone_ = true;
 
   // model_outputs is a map that contains unique outputs that the model must
   // provide. In the model configuration, the output in the state configuration
@@ -165,7 +168,7 @@ TRITONSERVER_Error* ModelState::Create(TRITONBACKEND_Model* triton_model,
 }
 
 ModelState::ModelState(TRITONBACKEND_Model* triton_model)
-    : BackendModel(triton_model) {
+    : BackendModel(triton_model), model_load_(false), main_runtime_(nullptr), is_clone_(true) {
   // Create runtime options that will be cloned and used for each
   // instance when creating that instance's runtime.
   runtime_options_.reset(new fastdeploy::RuntimeOption());
@@ -218,19 +221,6 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
           THROW_IF_BACKEND_MODEL_ERROR(
               ParseIntValue(value_string, &cpu_thread_num));
           runtime_options_->SetCpuThreadNum(cpu_thread_num);
-          // } else if (param_key == "graph_level") {
-          //   THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue(
-          //       value_string, &runtime_options_->ort_graph_opt_level));
-          // } else if (param_key == "inter_op_num_threads") {
-          //   THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue(
-          //       value_string,
-          //       &runtime_options_->ort_inter_op_num_threads));
-          // } else if (param_key == "execution_mode") {
-          //   THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue(
-          //       value_string, &runtime_options_->ort_execution_mode));
-          // } else if (param_key == "capacity") {
-          //   THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue(
-          //       value_string, &runtime_options_->pd_mkldnn_cache_size));
         } else if (param_key == "use_mkldnn") {
           bool pd_enable_mkldnn;
           THROW_IF_BACKEND_MODEL_ERROR(
@@ -238,8 +228,16 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
           runtime_options_->SetPaddleMKLDNN(pd_enable_mkldnn);
         } else if (param_key == "use_paddle_log") {
           runtime_options_->EnablePaddleLogInfo();
+        } else if (param_key == "num_streams") {
+          int num_streams;
+          THROW_IF_BACKEND_MODEL_ERROR(
+              ParseIntValue(value_string, &num_streams));
+          runtime_options_->SetOpenVINOStreams(num_streams);
+        } else if (param_key == "is_clone") {
+          THROW_IF_BACKEND_MODEL_ERROR(
+              ParseBoolValue(value_string, &is_clone_));
         } else if (param_key == "use_ipu") {
-          runtime_options_->UseIpu();
+          // runtime_options_->UseIpu();
         }
       }
     }
@@ -290,17 +288,6 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
           std::string value_string;
           THROW_IF_BACKEND_MODEL_ERROR(
               params.MemberAsString(param_key.c_str(), &value_string));
-          // if (param_key == "graph_level") {
-          //   THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue(
-          //       value_string, &runtime_options_->ort_graph_opt_level));
-          // } else if (param_key == "inter_op_num_threads") {
-          //   THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue(
-          //       value_string,
-          //       &runtime_options_->ort_inter_op_num_threads));
-          // } else if (param_key == "execution_mode") {
-          //   THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue(
-          //       value_string, &runtime_options_->ort_execution_mode));
-          // }
           if (param_key == "precision") {
             std::transform(value_string.begin(), value_string.end(),
                            value_string.begin(), ::tolower);
@@ -325,7 +312,10 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
             runtime_options_->EnablePaddleToTrt();
           } else if (param_key == "use_paddle_log") {
             runtime_options_->EnablePaddleLogInfo();
-          }
+          } else if (param_key == "is_clone") {
+            THROW_IF_BACKEND_MODEL_ERROR(
+                ParseBoolValue(value_string, &is_clone_));
+          }
         }
       }
     }
@@ -340,64 +330,79 @@ TRITONSERVER_Error* ModelState::LoadModel(
     const int32_t instance_group_device_id, std::string* model_path,
     std::string* params_path, fastdeploy::Runtime** runtime,
     cudaStream_t stream) {
-  auto dir_path = JoinPath({RepositoryPath(), std::to_string(Version())});
-  {
-    // ONNX Format
-    bool exists;
-    *model_path = JoinPath({dir_path, "model.onnx"});
-    RETURN_IF_ERROR(FileExists(*model_path, &exists));
+
+  // FastDeploy Runtime creation is not thread-safe, so multiple creations
+  // are serialized with a global lock.
+  // The Clone interface can be invoked only after the main_runtime_ is created.
+  static std::mutex global_context_mu;
+  std::lock_guard<std::mutex> glock(global_context_mu);
 
-    // Paddle Formax
-    if (not exists) {
-      *model_path = JoinPath({dir_path, "model.pdmodel"});
-      RETURN_IF_ERROR(FileExists(*model_path, &exists));
-      if (not exists) {
-        return TRITONSERVER_ErrorNew(
-            TRITONSERVER_ERROR_NOT_FOUND,
-            std::string(
-                "Model should be named as 'model.onnx' or 'model.pdmodel'")
-                .c_str());
-      }
-      *params_path = JoinPath({dir_path, "model.pdiparams"});
-      RETURN_IF_ERROR(FileExists(*params_path, &exists));
-      if (not exists) {
-        return TRITONSERVER_ErrorNew(
-            TRITONSERVER_ERROR_NOT_FOUND,
-            std::string("Paddle params should be named as 'model.pdiparams' or "
-                        "not provided.'")
-                .c_str());
-      }
-      runtime_options_->model_format = fastdeploy::ModelFormat::PADDLE;
-      runtime_options_->model_file = *model_path;
-      runtime_options_->params_file = *params_path;
-    } else {
-      runtime_options_->model_format = fastdeploy::ModelFormat::ONNX;
-      runtime_options_->model_file = *model_path;
+  if (model_load_ && is_clone_) {
+    if (main_runtime_ == nullptr) {
+      return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_NOT_FOUND,
+                                   std::string("main_runtime is nullptr").c_str());
     }
-  }
-
-  // GPU
-#ifdef TRITON_ENABLE_GPU
-  if ((instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
-      (instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_AUTO)) {
-    runtime_options_->UseGpu(instance_group_device_id);
-    runtime_options_->SetExternalStream((void*)stream);
+    *runtime = main_runtime_->Clone((void*)stream, instance_group_device_id);
   } else {
-    runtime_options_->UseCpu();
-  }
-#else
-  if (runtime_options_->device != fastdeploy::Device::IPU) {
-    // If Device is set to IPU, just skip CPU setting.
-    runtime_options_->UseCpu();
-  }
-#endif  // TRITON_ENABLE_GPU
+    auto dir_path = JoinPath({RepositoryPath(), std::to_string(Version())});
+    {
+      // ONNX Format
+      bool exists;
+      *model_path = JoinPath({dir_path, "model.onnx"});
+      RETURN_IF_ERROR(FileExists(*model_path, &exists));
 
-  *runtime = new fastdeploy::Runtime();
-  if (!(*runtime)->Init(*runtime_options_)) {
-    return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_NOT_FOUND,
-                                 std::string("Runtime init error").c_str());
-  }
+      // Paddle Format
+      if (not exists) {
+        *model_path = JoinPath({dir_path, "model.pdmodel"});
+        RETURN_IF_ERROR(FileExists(*model_path, &exists));
+        if (not exists) {
+          return TRITONSERVER_ErrorNew(
+              TRITONSERVER_ERROR_NOT_FOUND,
+              std::string(
+                  "Model should be named as 'model.onnx' or 'model.pdmodel'")
+                  .c_str());
+        }
+        *params_path = JoinPath({dir_path, "model.pdiparams"});
+        RETURN_IF_ERROR(FileExists(*params_path, &exists));
+        if (not exists) {
+          return TRITONSERVER_ErrorNew(
+              TRITONSERVER_ERROR_NOT_FOUND,
+              std::string("Paddle params should be named as 'model.pdiparams' or "
+                          "not provided.'")
+                  .c_str());
+        }
+        runtime_options_->model_format = fastdeploy::ModelFormat::PADDLE;
+        runtime_options_->model_file = *model_path;
+        runtime_options_->params_file = *params_path;
+      } else {
+        runtime_options_->model_format = fastdeploy::ModelFormat::ONNX;
+        runtime_options_->model_file = *model_path;
+      }
+    }
+    // GPU
+  #ifdef TRITON_ENABLE_GPU
+    if ((instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
+        (instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_AUTO)) {
+      runtime_options_->UseGpu(instance_group_device_id);
+      runtime_options_->SetExternalStream((void*)stream);
+    } else if (runtime_options_->device != fastdeploy::Device::IPU) {
+      runtime_options_->UseCpu();
+    }
+  #else
+    if (runtime_options_->device != fastdeploy::Device::IPU) {
+      // If Device is set to IPU, just skip CPU setting.
+      runtime_options_->UseCpu();
+    }
+  #endif  // TRITON_ENABLE_GPU
+
+    *runtime = main_runtime_ = new fastdeploy::Runtime();
+    if (!(*runtime)->Init(*runtime_options_)) {
+      return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_NOT_FOUND,
+                                   std::string("Runtime init error").c_str());
+    }
+    model_load_ = true;
+  }
   return nullptr;  // success
 }