[Backend & Serving] Serving and Runtime support Clone (#464)

* Add Clone support to Serving and Runtime

* Support the TRT, OpenVINO and Paddle backends

Co-authored-by: Jason <jiangjiajun@baidu.com>
heliqi
2022-11-04 17:16:40 +08:00
committed by GitHub
parent 61634caf28
commit 277bec38c7
13 changed files with 343 additions and 150 deletions
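
For context, a minimal usage sketch of the Runtime::Clone API added by this commit (not part of the diff; the model paths and the RuntimeOption setters other than SetOpenVINOStreams are assumed from the existing FastDeploy API):

#include <vector>
#include "fastdeploy/runtime.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.SetModelPath("model.pdmodel", "model.pdiparams");  // hypothetical model files
  option.UseCpu();
  option.UseOpenVINOBackend();
  option.SetOpenVINOStreams(2);

  fastdeploy::Runtime runtime;
  runtime.Init(option);

  // Additional instances share the compiled engine instead of loading the model again.
  fastdeploy::Runtime* cloned = runtime.Clone(/*stream=*/nullptr, /*device_id=*/-1);

  std::vector<fastdeploy::FDTensor> inputs, outputs;
  // ... fill `inputs` with named FDTensors matching the model inputs ...
  cloned->Infer(inputs, &outputs);

  delete cloned;
  return 0;
}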

View File

@@ -21,6 +21,7 @@
#include "fastdeploy/backends/common/multiclass_nms.h"
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/core/fd_type.h"
namespace fastdeploy {
@@ -63,6 +64,11 @@ class BaseBackend {
virtual std::vector<TensorInfo> GetOutputInfos() = 0;
virtual bool Infer(std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs) = 0;
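// Clone this backend so multiple Runtime instances can share one loaded
// model/engine; backends that do not override this return nullptr.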
virtual std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
int device_id = -1) {
FDERROR << "Clone is not supported by this backend." << std::endl;
return nullptr;
}
};
} // namespace fastdeploy

View File

@@ -74,6 +74,8 @@ ov::element::Type FDDataTypeToOV(const FDDataType& type) {
return ov::element::f32;
}
ov::Core OpenVINOBackend::core_;
void OpenVINOBackend::InitTensorInfo(
const std::vector<ov::Output<ov::Node>>& ov_outputs,
std::map<std::string, TensorInfo>* tensor_infos) {
@@ -96,10 +98,6 @@ bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
return false;
}
option_ = option;
ov::AnyMap properties;
if (option_.cpu_thread_num > 0) {
properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
}
std::shared_ptr<ov::Model> model = core_.read_model(model_file, params_file);
@@ -149,7 +147,19 @@ bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
output_infos_.push_back(iter->second);
}
ov::AnyMap properties;
if (option_.cpu_thread_num > 0) {
properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
}
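// Map ov_num_streams to OpenVINO NUM_STREAMS:
// -1 -> ov::streams::AUTO, -2 -> ov::streams::NUMA, >0 -> explicit stream count.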
if (option_.ov_num_streams == -1) {
properties["NUM_STREAMS"] = ov::streams::AUTO;
} else if (option_.ov_num_streams == -2) {
properties["NUM_STREAMS"] = ov::streams::NUMA;
} else if (option_.ov_num_streams > 0) {
properties["NUM_STREAMS"] = option_.ov_num_streams;
}
compiled_model_ = core_.compile_model(model, "CPU", properties);
request_ = compiled_model_.create_infer_request();
initialized_ = true;
return true;
@@ -185,10 +195,6 @@ bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,
return false;
}
option_ = option;
ov::AnyMap properties;
if (option_.cpu_thread_num > 0) {
properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
}
std::shared_ptr<ov::Model> model = core_.read_model(model_file);
@@ -238,8 +244,21 @@ bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,
output_infos_.push_back(iter->second);
}
ov::AnyMap properties;
if (option_.cpu_thread_num > 0) {
properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
}
if (option_.ov_num_streams == -1) {
properties["NUM_STREAMS"] = ov::streams::AUTO;
} else if (option_.ov_num_streams == -2) {
properties["NUM_STREAMS"] = ov::streams::NUMA;
} else if (option_.ov_num_streams > 0) {
properties["NUM_STREAMS"] = option_.ov_num_streams;
}
compiled_model_ = core_.compile_model(model, "CPU", properties);
request_ = compiled_model_.create_infer_request();
initialized_ = true;
return true;
}
@@ -281,4 +300,14 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
return true;
}
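// The cloned backend shares the static ov::Core; its infer request is created
// from the original compiled model, so the model is not recompiled.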
std::unique_ptr<BaseBackend> OpenVINOBackend::Clone(void *stream, int device_id) {
std::unique_ptr<BaseBackend> new_backend = utils::make_unique<OpenVINOBackend>();
auto casted_backend = dynamic_cast<OpenVINOBackend*>(new_backend.get());
casted_backend->option_ = option_;
casted_backend->request_ = compiled_model_.create_infer_request();
casted_backend->input_infos_.assign(input_infos_.begin(), input_infos_.end());
casted_backend->output_infos_.assign(output_infos_.begin(), output_infos_.end());
return new_backend;
}
} // namespace fastdeploy

View File

@@ -20,17 +20,20 @@
#include <vector>
#include "fastdeploy/backends/backend.h"
#include "fastdeploy/utils/unique_ptr.h"
#include "openvino/openvino.hpp"
namespace fastdeploy {
struct OpenVINOBackendOption {
int cpu_thread_num = 8;
int cpu_thread_num = -1;
int ov_num_streams = 1;
std::map<std::string, std::vector<int64_t>> shape_infos;
};
class OpenVINOBackend : public BaseBackend {
public:
static ov::Core core_;
OpenVINOBackend() {}
virtual ~OpenVINOBackend() = default;
@@ -54,10 +57,13 @@ class OpenVINOBackend : public BaseBackend {
std::vector<TensorInfo> GetInputInfos() override;
std::vector<TensorInfo> GetOutputInfos() override;
std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
int device_id = -1) override;
private:
void InitTensorInfo(const std::vector<ov::Output<ov::Node>>& ov_outputs,
std::map<std::string, TensorInfo>* tensor_infos);
ov::Core core_;
ov::CompiledModel compiled_model_;
ov::InferRequest request_;
OpenVINOBackendOption option_;

View File

@@ -216,6 +216,30 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
return true;
}
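// Cloning onto a different GPU creates a brand-new predictor from the model
// files (no memory sharing); on the same device the Paddle predictor is cloned
// onto the given stream and shares its weights.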
std::unique_ptr<BaseBackend> PaddleBackend::Clone(void *stream, int device_id) {
std::unique_ptr<BaseBackend> new_backend = utils::make_unique<PaddleBackend>();
auto casted_backend = dynamic_cast<PaddleBackend*>(new_backend.get());
if (device_id > 0 && option_.use_gpu && device_id != option_.gpu_id) {
auto clone_option = option_;
clone_option.gpu_id = device_id;
clone_option.external_stream_ = stream;
casted_backend->InitFromPaddle(clone_option.model_file,
clone_option.params_file,
clone_option);
FDWARNING << "The target device id: "
<< device_id
<< " is different from the current device id: "
<< option_.gpu_id
<< ", so the cloned engine cannot share memory with the current engine."
<< std::endl;
return new_backend;
}
casted_backend->inputs_desc_.assign(inputs_desc_.begin(), inputs_desc_.end());
casted_backend->outputs_desc_.assign(outputs_desc_.begin(), outputs_desc_.end());
casted_backend->predictor_ = std::move(predictor_->Clone(stream));
return new_backend;
}
#ifdef ENABLE_TRT_BACKEND
void PaddleBackend::SetTRTDynamicShapeToConfig(const PaddleBackendOption& option) {
std::map<std::string, std::vector<int>> max_shape;

View File

@@ -24,6 +24,7 @@
#include "paddle2onnx/converter.h"
#endif
#include "paddle_inference_api.h" // NOLINT
#include "fastdeploy/utils/unique_ptr.h"
#ifdef ENABLE_TRT_BACKEND
#include "fastdeploy/backends/tensorrt/trt_backend.h"
@@ -43,6 +44,9 @@ struct IpuOption {
};
struct PaddleBackendOption {
std::string model_file = ""; // Path of model file
std::string params_file = ""; // Path of parameters file, can be empty
#ifdef WITH_GPU
bool use_gpu = true;
#else
@@ -110,6 +114,9 @@ class PaddleBackend : public BaseBackend {
int NumOutputs() const override { return outputs_desc_.size(); }
std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
int device_id = -1) override;
TensorInfo GetInputInfo(int index) override;
TensorInfo GetOutputInfo(int index) override;
std::vector<TensorInfo> GetInputInfos() override;

View File

@@ -285,6 +285,7 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
BuildTrtEngine();
}
cudaSetDevice(option_.gpu_id);
SetInputs(inputs);
AllocateOutputsBuffer(outputs);
@@ -356,13 +357,17 @@ void TrtBackend::GetInputOutputInfo() {
outputs_device_buffer_[name] = FDDeviceBuffer(dtype);
casted_output_tensors_[name] = FDTensor();
}
io_name_index_[name] = i;
}
bindings_.resize(num_binds);
}
void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
for (const auto& item : inputs) {
auto idx = engine_->getBindingIndex(item.name.c_str());
// auto idx = engine_->getBindingIndex(item.name.c_str());
auto iter = io_name_index_.find(item.name);
FDASSERT(iter != io_name_index_.end(), "TRTBackend SetInputs() cannot find input name: %s", item.name.c_str());
auto idx = iter->second;
std::vector<int> shape(item.shape.begin(), item.shape.end());
auto dims = ToDims(shape);
context_->setBindingDimensions(idx, dims);
@@ -410,7 +415,10 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
outputs->resize(outputs_desc_.size());
}
for (size_t i = 0; i < outputs_desc_.size(); ++i) {
auto idx = engine_->getBindingIndex(outputs_desc_[i].name.c_str());
// auto idx = engine_->getBindingIndex(outputs_desc_[i].name.c_str());
auto idx_iter = io_name_index_.find(outputs_desc_[i].name);
FDASSERT(idx_iter != io_name_index_.end(), "TRTBackend AllocateOutputsBuffer() cannot find output name: %s", outputs_desc_[i].name.c_str());
auto idx = idx_iter->second;
auto output_dims = context_->getBindingDimensions(idx);
// find the original index of output
@@ -673,4 +681,47 @@ std::vector<TensorInfo> TrtBackend::GetOutputInfos() {
return infos;
}
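// Cloning onto a different GPU rebuilds the TensorRT engine from the original
// model; on the same GPU the ICudaEngine is shared and only a new execution
// context (and, if no stream is passed, a new CUDA stream) is created.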
std::unique_ptr<BaseBackend> TrtBackend::Clone(void *stream, int device_id) {
std::unique_ptr<BaseBackend> new_backend = utils::make_unique<TrtBackend>();
auto casted_backend = dynamic_cast<TrtBackend*>(new_backend.get());
if (device_id > 0 && device_id != option_.gpu_id) {
auto clone_option = option_;
clone_option.gpu_id = device_id;
clone_option.external_stream_ = stream;
if (option_.model_format == ModelFormat::ONNX) {
FDASSERT(casted_backend->InitFromOnnx(option_.model_file, clone_option),
"Clone model from ONNX failed while initialize TrtBackend.");
} else {
FDASSERT(casted_backend->InitFromPaddle(option_.model_file,
option_.params_file, clone_option),
"Clone model from Paddle failed while initialize TrtBackend.");
}
FDWARNING << "The target device id: "
<< device_id
<< " is different from the current device id: "
<< option_.gpu_id
<< ", so the cloned engine cannot share memory with the current engine."
<< std::endl;
return new_backend;
}
cudaSetDevice(option_.gpu_id);
casted_backend->option_.gpu_id = option_.gpu_id;
if (stream) {
casted_backend->stream_ = reinterpret_cast<cudaStream_t>(stream);
} else {
FDASSERT(cudaStreamCreate(&casted_backend->stream_) == 0,
"[ERROR] Error occurs while clone calling cudaStreamCreate().");
}
casted_backend->inputs_desc_.assign(inputs_desc_.begin(), inputs_desc_.end());
casted_backend->outputs_desc_.assign(outputs_desc_.begin(), outputs_desc_.end());
casted_backend->outputs_order_.insert(outputs_order_.begin(), outputs_order_.end());
casted_backend->shape_range_info_.insert(shape_range_info_.begin(), shape_range_info_.end());
casted_backend->engine_ = engine_;
casted_backend->context_ = std::shared_ptr<nvinfer1::IExecutionContext>(
casted_backend->engine_->createExecutionContext());
casted_backend->GetInputOutputInfo();
FDINFO << "TRTBackend clone finished." << std::endl;
return new_backend;
}
} // namespace fastdeploy

View File

@@ -25,6 +25,7 @@
#include "NvOnnxParser.h"
#include "fastdeploy/backends/backend.h"
#include "fastdeploy/backends/tensorrt/utils.h"
#include "fastdeploy/utils/unique_ptr.h"
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
public:
@@ -45,7 +46,7 @@ class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
void writeCalibrationCache(const void* cache,
size_t length) noexcept override {
std::cout << "NOT IMPLEMENT." << std::endl;
fastdeploy::FDERROR << "NOT IMPLEMENTED." << std::endl;
}
private:
@@ -62,6 +63,11 @@ struct TrtValueInfo {
};
struct TrtBackendOption {
std::string model_file = ""; // Path of model file
std::string params_file = ""; // Path of parameters file, can be empty
// format of input model
ModelFormat model_format = ModelFormat::AUTOREC;
int gpu_id = 0;
bool enable_fp16 = false;
bool enable_int8 = false;
@@ -99,6 +105,8 @@ class TrtBackend : public BaseBackend {
TensorInfo GetOutputInfo(int index);
std::vector<TensorInfo> GetInputInfos() override;
std::vector<TensorInfo> GetOutputInfos() override;
std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
int device_id = -1) override;
~TrtBackend() {
if (parser_) {
@@ -119,6 +127,7 @@ class TrtBackend : public BaseBackend {
std::vector<TrtValueInfo> outputs_desc_;
std::map<std::string, FDDeviceBuffer> inputs_device_buffer_;
std::map<std::string, FDDeviceBuffer> outputs_device_buffer_;
std::map<std::string, int> io_name_index_;
std::string calibration_str_;

View File

@@ -182,4 +182,31 @@ const FDDataType TypeToDataType<uint8_t>::dtype = UINT8;
template <>
const FDDataType TypeToDataType<int8_t>::dtype = INT8;
std::string Str(const ModelFormat& f) {
if (f == ModelFormat::PADDLE) {
return "ModelFormat::PADDLE";
} else if (f == ModelFormat::ONNX) {
return "ModelFormat::ONNX";
} else if (f == ModelFormat::RKNN) {
return "ModelFormat::RKNN";
} else if (f == ModelFormat::TORCHSCRIPT) {
return "ModelFormat::TORCHSCRIPT";
}
return "UNKNOWN-ModelFormat";
}
std::ostream& operator<<(std::ostream& out, const ModelFormat& format) {
if (format == ModelFormat::PADDLE) {
out << "ModelFormat::PADDLE";
} else if (format == ModelFormat::ONNX) {
out << "ModelFormat::ONNX";
} else if (format == ModelFormat::RKNN) {
out << "ModelFormat::RKNN";
} else if (format == ModelFormat::TORCHSCRIPT) {
out << "ModelFormat::TORCHSCRIPT";
} else {
out << "UNKNOWN-ModelFormat";
}
return out;
}
} // namespace fastdeploy

View File

@@ -65,4 +65,16 @@ struct FASTDEPLOY_DECL TypeToDataType {
static const FDDataType dtype;
};
/*! Deep learning model format */
enum ModelFormat {
AUTOREC, ///< Auto recognize the model format by model file name
PADDLE, ///< Model with paddlepaddle format
ONNX, ///< Model with ONNX format
RKNN, ///< Model with RKNN format
TORCHSCRIPT, ///< Model with TorchScript format
};
FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out,
const ModelFormat& format);
} // namespace fastdeploy

View File

@@ -102,19 +102,6 @@ std::string Str(const Backend& b) {
return "UNKNOWN-Backend";
}
std::string Str(const ModelFormat& f) {
if (f == ModelFormat::PADDLE) {
return "ModelFormat::PADDLE";
} else if (f == ModelFormat::ONNX) {
return "ModelFormat::ONNX";
}else if (f == ModelFormat::RKNN) {
return "ModelFormat::RKNN";
} else if (f == ModelFormat::TORCHSCRIPT) {
return "ModelFormat::TORCHSCRIPT";
}
return "UNKNOWN-ModelFormat";
}
std::ostream& operator<<(std::ostream& out, const Backend& backend) {
if (backend == Backend::ORT) {
out << "Backend::ORT";
@@ -135,20 +122,6 @@ std::ostream& operator<<(std::ostream& out, const Backend& backend) {
return out;
}
std::ostream& operator<<(std::ostream& out, const ModelFormat& format) {
if (format == ModelFormat::PADDLE) {
out << "ModelFormat::PADDLE";
} else if (format == ModelFormat::ONNX) {
out << "ModelFormat::ONNX";
} else if (format == ModelFormat::RKNN) {
out << "ModelFormat::RKNN";
} else if (format == ModelFormat::TORCHSCRIPT) {
out << "ModelFormat::TORCHSCRIPT";
}
out << "UNKNOWN-ModelFormat";
return out;
}
bool CheckModelFormat(const std::string& model_file,
const ModelFormat& model_format) {
if (model_format == ModelFormat::PADDLE) {
@@ -411,6 +384,10 @@ void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) {
trt_serialize_file = cache_file_path;
}
void RuntimeOption::SetOpenVINOStreams(int num_streams) {
ov_num_streams = num_streams;
}
bool Runtime::Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors,
const RuntimeOption& _option) {
#ifdef ENABLE_POROS_BACKEND
@@ -582,6 +559,8 @@ bool Runtime::Infer(std::vector<FDTensor>& input_tensors,
void Runtime::CreatePaddleBackend() {
#ifdef ENABLE_PADDLE_BACKEND
auto pd_option = PaddleBackendOption();
pd_option.model_file = option.model_file;
pd_option.params_file = option.params_file;
pd_option.enable_mkldnn = option.pd_enable_mkldnn;
pd_option.enable_log_info = option.pd_enable_log_info;
pd_option.mkldnn_cache_size = option.pd_mkldnn_cache_size;
@@ -642,6 +621,7 @@ void Runtime::CreateOpenVINOBackend() {
#ifdef ENABLE_OPENVINO_BACKEND
auto ov_option = OpenVINOBackendOption();
ov_option.cpu_thread_num = option.cpu_thread_num;
ov_option.ov_num_streams = option.ov_num_streams;
FDASSERT(option.model_format == ModelFormat::PADDLE ||
option.model_format == ModelFormat::ONNX,
"OpenVINOBackend only support model format of ModelFormat::PADDLE / "
@@ -699,6 +679,9 @@ void Runtime::CreateOrtBackend() {
void Runtime::CreateTrtBackend() {
#ifdef ENABLE_TRT_BACKEND
auto trt_option = TrtBackendOption();
trt_option.model_file = option.model_file;
trt_option.params_file = option.params_file;
trt_option.model_format = option.model_format;
trt_option.gpu_id = option.device_id;
trt_option.enable_fp16 = option.trt_enable_fp16;
trt_option.enable_int8 = option.trt_enable_int8;
@@ -771,4 +754,26 @@ void Runtime::CreateRKNPU2Backend() {
#endif
}
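// Only the OpenVINO, Paddle Inference and TensorRT backends support sharing the
// engine between clones; other backends fall back to initializing a new Runtime.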
Runtime* Runtime::Clone(void* stream, int device_id) {
Runtime* runtime = new Runtime();
if (option.backend != Backend::OPENVINO &&
option.backend != Backend::PDINFER &&
option.backend != Backend::TRT) {
runtime->Init(option);
FDWARNING << "Only the OpenVINO/Paddle Inference/TensorRT backends support "
"cloning the engine to reduce CPU/GPU memory usage. For "
<< option.backend
<< ", FastDeploy will create a new engine that "
"will not share memory with the current runtime."
<< std::endl;
return runtime;
}
FDINFO << "Runtime Clone with Backend: " << Str(option.backend) << " on " << Str(option.device)
<< "." << std::endl;
runtime->option = option;
runtime->backend_ = backend_->Clone(stream, device_id);
return runtime;
}
} // namespace fastdeploy

View File

@@ -45,19 +45,8 @@ enum Backend {
RKNPU2, ///< RKNPU2, support RKNN format model, Rockchip NPU only
};
/*! Deep learning model format */
enum ModelFormat {
AUTOREC, ///< Auto recognize the model format by model file name
PADDLE, ///< Model with paddlepaddle format
ONNX, ///< Model with ONNX format
RKNN, ///< Model with RKNN format
TORCHSCRIPT, ///< Model with TorchScript format
};
FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out,
const Backend& backend);
FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out,
const ModelFormat& format);
/*! Paddle Lite power mode for mobile device. */
enum LitePowerMode {
@@ -105,8 +94,10 @@ struct FASTDEPLOY_DECL RuntimeOption {
/// Use Nvidia GPU to inference
void UseGpu(int gpu_id = 0);
void UseRKNPU2(fastdeploy::rknpu2::CpuName rknpu2_name = fastdeploy::rknpu2::CpuName::RK3588,
fastdeploy::rknpu2::CoreMask rknpu2_core = fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_0);
void UseRKNPU2(fastdeploy::rknpu2::CpuName rknpu2_name
= fastdeploy::rknpu2::CpuName::RK3588,
fastdeploy::rknpu2::CoreMask rknpu2_core
= fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_0);
void SetExternalStream(void* external_stream);
@@ -242,6 +233,11 @@ struct FASTDEPLOY_DECL RuntimeOption {
*/
void DisablePaddleTrtCollectShape();
/**
* @brief Set the number of streams used by the OpenVINO backend
*/
void SetOpenVINOStreams(int num_streams);
/** \brief Use Graphcore IPU for inference.
*
* \param[in] device_num the number of IPUs.
@@ -331,13 +327,19 @@ struct FASTDEPLOY_DECL RuntimeOption {
int unconst_ops_thres = -1;
std::string poros_file = "";
// ======Only for OpenVINO Backend=======
int ov_num_streams = 1;
// ======Only for RKNPU2 Backend=======
fastdeploy::rknpu2::CpuName rknpu2_cpu_name_ = fastdeploy::rknpu2::CpuName::RK3588;
fastdeploy::rknpu2::CoreMask rknpu2_core_mask_ = fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_AUTO;
fastdeploy::rknpu2::CpuName rknpu2_cpu_name_
= fastdeploy::rknpu2::CpuName::RK3588;
fastdeploy::rknpu2::CoreMask rknpu2_core_mask_
= fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_AUTO;
std::string model_file = ""; // Path of model file
std::string params_file = ""; // Path of parameters file, can be empty
ModelFormat model_format = ModelFormat::AUTOREC; // format of input model
// format of input model
ModelFormat model_format = ModelFormat::AUTOREC;
};
/*! @brief Runtime object used to inference the loaded model on different devices
@@ -384,6 +386,14 @@ struct FASTDEPLOY_DECL Runtime {
*/
std::vector<TensorInfo> GetOutputInfos();
/** \brief Clone a new Runtime when multiple instances of the same model are needed
*
* \param[in] stream CUDA stream for the new Runtime, default is nullptr
* \param[in] device_id device id for the new Runtime, -1 means using the same device as the current runtime
* \return pointer to the cloned Runtime
*/
Runtime* Clone(void* stream = nullptr,
int device_id = -1);
RuntimeOption option;
private:

View File

@@ -142,8 +142,10 @@ optimization {
cpu_execution_accelerator : [
{
name : "openvino"
# Set the number of parallel inference threads to 4
# Set the number of parallel inference threads to 4 (total across all instances)
parameters { key: "cpu_threads" value: "4" }
# Set OpenVINO num_streams; usually set it equal to the number of instances
parameters { key: "num_streams" value: "1" }
}
]
}
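
For reference, a rough sketch of the RuntimeOption calls these two parameters map to inside the serving backend (based on the parameter parsing added in this commit; UseCpu() and UseOpenVINOBackend() are assumed from the existing API):

fastdeploy::RuntimeOption option;
option.UseCpu();
option.UseOpenVINOBackend();
option.SetCpuThreadNum(4);      // "cpu_threads": threads shared by all instances
option.SetOpenVINOStreams(1);   // "num_streams": usually equal to the instance count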

View File

@@ -91,6 +91,9 @@ class ModelState : public BackendModel {
// Runtime options used when creating a FastDeploy Runtime.
std::unique_ptr<fastdeploy::RuntimeOption> runtime_options_;
bool model_load_;
fastdeploy::Runtime* main_runtime_;
bool is_clone_ = true;
// model_outputs is a map that contains unique outputs that the model must
// provide. In the model configuration, the output in the state configuration
@@ -165,7 +168,7 @@ TRITONSERVER_Error* ModelState::Create(TRITONBACKEND_Model* triton_model,
}
ModelState::ModelState(TRITONBACKEND_Model* triton_model)
: BackendModel(triton_model) {
: BackendModel(triton_model), model_load_(false), main_runtime_(nullptr), is_clone_(true) {
// Create runtime options that will be cloned and used for each
// instance when creating that instance's runtime.
runtime_options_.reset(new fastdeploy::RuntimeOption());
@@ -218,19 +221,6 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
THROW_IF_BACKEND_MODEL_ERROR(
ParseIntValue(value_string, &cpu_thread_num));
runtime_options_->SetCpuThreadNum(cpu_thread_num);
// } else if (param_key == "graph_level") {
// THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue(
// value_string, &runtime_options_->ort_graph_opt_level));
// } else if (param_key == "inter_op_num_threads") {
// THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue(
// value_string,
// &runtime_options_->ort_inter_op_num_threads));
// } else if (param_key == "execution_mode") {
// THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue(
// value_string, &runtime_options_->ort_execution_mode));
// } else if (param_key == "capacity") {
// THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue(
// value_string, &runtime_options_->pd_mkldnn_cache_size));
} else if (param_key == "use_mkldnn") {
bool pd_enable_mkldnn;
THROW_IF_BACKEND_MODEL_ERROR(
@@ -238,8 +228,16 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
runtime_options_->SetPaddleMKLDNN(pd_enable_mkldnn);
} else if (param_key == "use_paddle_log") {
runtime_options_->EnablePaddleLogInfo();
} else if (param_key == "num_streams") {
int num_streams;
THROW_IF_BACKEND_MODEL_ERROR(
ParseIntValue(value_string, &num_streams));
runtime_options_->SetOpenVINOStreams(num_streams);
} else if (param_key == "is_clone") {
THROW_IF_BACKEND_MODEL_ERROR(
ParseBoolValue(value_string, &is_clone_));
} else if (param_key == "use_ipu") {
runtime_options_->UseIpu();
// runtime_options_->UseIpu();
}
}
}
@@ -290,17 +288,6 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
std::string value_string;
THROW_IF_BACKEND_MODEL_ERROR(
params.MemberAsString(param_key.c_str(), &value_string));
// if (param_key == "graph_level") {
// THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue(
// value_string, &runtime_options_->ort_graph_opt_level));
// } else if (param_key == "inter_op_num_threads") {
// THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue(
// value_string,
// &runtime_options_->ort_inter_op_num_threads));
// } else if (param_key == "execution_mode") {
// THROW_IF_BACKEND_MODEL_ERROR(ParseIntValue(
// value_string, &runtime_options_->ort_execution_mode));
// }
if (param_key == "precision") {
std::transform(value_string.begin(), value_string.end(),
value_string.begin(), ::tolower);
@@ -325,6 +312,9 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
runtime_options_->EnablePaddleToTrt();
} else if (param_key == "use_paddle_log") {
runtime_options_->EnablePaddleLogInfo();
} else if (param_key == "is_clone") {
THROW_IF_BACKEND_MODEL_ERROR(
ParseBoolValue(value_string, &is_clone_));
}
}
}
@@ -340,6 +330,20 @@ TRITONSERVER_Error* ModelState::LoadModel(
const int32_t instance_group_device_id, std::string* model_path,
std::string* params_path, fastdeploy::Runtime** runtime,
cudaStream_t stream) {
// FastDeploy Runtime creation is not thread-safe, so multiple creations
// are serialized with a global lock.
// The Clone interface can only be invoked after main_runtime_ has been created.
static std::mutex global_context_mu;
std::lock_guard<std::mutex> glock(global_context_mu);
if (model_load_ && is_clone_) {
if (main_runtime_ == nullptr) {
return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_NOT_FOUND,
std::string("main_runtime is nullptr").c_str());
}
*runtime = main_runtime_->Clone((void*)stream, instance_group_device_id);
} else {
auto dir_path = JoinPath({RepositoryPath(), std::to_string(Version())});
{
// ONNX Format
@@ -382,7 +386,7 @@ TRITONSERVER_Error* ModelState::LoadModel(
(instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_AUTO)) {
runtime_options_->UseGpu(instance_group_device_id);
runtime_options_->SetExternalStream((void*)stream);
} else {
} else if (runtime_options_->device != fastdeploy::Device::IPU) {
runtime_options_->UseCpu();
}
#else
@@ -392,12 +396,13 @@ TRITONSERVER_Error* ModelState::LoadModel(
}
#endif // TRITON_ENABLE_GPU
*runtime = new fastdeploy::Runtime();
*runtime = main_runtime_ = new fastdeploy::Runtime();
if (!(*runtime)->Init(*runtime_options_)) {
return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_NOT_FOUND,
std::string("Runtime init error").c_str());
}
model_load_ = true;
}
return nullptr; // success
}