[Backend] Support Intel GPU with heterogeneous mode (#701)

* Add some comments for Python API

* Support OpenVINO GPU

* Add CPU operators

* Add interface to specify hetero operators

* Remove useless dir

* Format code

* Remove debug code

* Support GPU for ONNX
Author: Jason
Date: 2022-11-25 19:40:56 +08:00
Committed by: GitHub
Parent commit: ad5c9c08b2
Commit: 5b3fd9dd88
7 changed files with 156 additions and 44 deletions
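
A rough usage sketch (not part of this commit's diff): with these changes, the heterogeneous mode can be configured from Python roughly as follows. The input names and shapes are illustrative only; set_openvino_shape_info and set_openvino_cpu_operators are the methods added by this commit, while use_openvino_backend and set_openvino_device already existed.

import fastdeploy as fd

option = fd.RuntimeOption()
option.use_openvino_backend()
# Run on the Intel GPU, falling back to CPU for operators the GPU plugin
# cannot handle (heterogeneous mode).
option.set_openvino_device("HETERO:GPU,CPU")
# GPU execution needs fixed input shapes, so pin them explicitly
# (names/shapes below are hypothetical examples).
option.set_openvino_shape_info({"image": [1, 3, 640, 640],
                                "scale_factor": [1, 2]})
# Operators listed here are pinned to the CPU; MulticlassNms is the default.
option.set_openvino_cpu_operators(["MulticlassNms"])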


@@ -1,8 +0,0 @@
-{
-  "image": "mcr.microsoft.com/devcontainers/universal:2",
-  "features": {
-    "ghcr.io/devcontainers/features/nvidia-cuda:1": {
-      "version": "latest"
-    }
-  }
-}


@@ -32,6 +32,14 @@ std::vector<int64_t> PartialShapeToVec(const ov::PartialShape& shape) {
   return res;
 }
 
+ov::PartialShape VecToPartialShape(const std::vector<int64_t>& shape) {
+  std::vector<ov::Dimension> dims;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    dims.emplace_back(ov::Dimension(shape[i]));
+  }
+  return ov::PartialShape(dims);
+}
+
 FDDataType OpenVINODataTypeToFD(const ov::element::Type& type) {
   if (type == ov::element::f32) {
     return FDDataType::FP32;
@@ -100,6 +108,26 @@ bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
   option_ = option;
   std::shared_ptr<ov::Model> model = core_.read_model(model_file, params_file);
+  if (option_.shape_infos.size() > 0) {
+    std::map<std::string, ov::PartialShape> shape_infos;
+    for (const auto& item : option_.shape_infos) {
+      shape_infos[item.first] = VecToPartialShape(item.second);
+    }
+    model->reshape(shape_infos);
+  }
+
+  if (option_.device.find("HETERO") != std::string::npos) {
+    auto supported_ops = core_.query_model(model, option_.device);
+    for (auto&& op : model->get_ops()) {
+      auto& affinity = supported_ops[op->get_friendly_name()];
+      if (option_.cpu_operators.find(op->description()) !=
+          option_.cpu_operators.end()) {
+        op->get_rt_info()["affinity"] = "CPU";
+      } else {
+        op->get_rt_info()["affinity"] = affinity;
+      }
+    }
+  }
 
   // Get inputs/outputs information from loaded model
   const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
@@ -151,14 +179,25 @@ bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
   if (option_.cpu_thread_num > 0) {
     properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
   }
-  if (option_.ov_num_streams == -1) {
-    properties["NUM_STREAMS"] = ov::streams::AUTO;
-  } else if (option_.ov_num_streams == -2) {
-    properties["NUM_STREAMS"] = ov::streams::NUMA;
-  } else if (option_.ov_num_streams > 0) {
-    properties["NUM_STREAMS"] = option_.ov_num_streams;
+  if (option_.device == "CPU") {
+    if (option_.num_streams == -1) {
+      properties["NUM_STREAMS"] = ov::streams::AUTO;
+    } else if (option_.num_streams == -2) {
+      properties["NUM_STREAMS"] = ov::streams::NUMA;
+    } else if (option_.num_streams > 0) {
+      properties["NUM_STREAMS"] = option_.num_streams;
+    }
+  } else {
+    if (option_.num_streams != 0) {
+      FDWARNING << "NUM_STREAMS only available on device CPU, currently the "
+                   "device is set as "
+                << option_.device << ", the NUM_STREAMS will be ignored."
+                << std::endl;
+    }
   }
-  FDINFO << "Compile OpenVINO model on device_name:" << option.device << "." << std::endl;
+  FDINFO << "Compile OpenVINO model on device_name:" << option.device << "."
+         << std::endl;
   compiled_model_ = core_.compile_model(model, option.device, properties);
   request_ = compiled_model_.create_infer_request();
@@ -199,6 +238,27 @@ bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,
   std::shared_ptr<ov::Model> model = core_.read_model(model_file);
+  if (option_.shape_infos.size() > 0) {
+    std::map<std::string, ov::PartialShape> shape_infos;
+    for (const auto& item : option_.shape_infos) {
+      shape_infos[item.first] = VecToPartialShape(item.second);
+    }
+    model->reshape(shape_infos);
+  }
+
+  if (option_.device.find("HETERO") != std::string::npos) {
+    auto supported_ops = core_.query_model(model, option_.device);
+    for (auto&& op : model->get_ops()) {
+      auto& affinity = supported_ops[op->get_friendly_name()];
+      if (option_.cpu_operators.find(op->description()) !=
+          option_.cpu_operators.end()) {
+        op->get_rt_info()["affinity"] = "CPU";
+      } else {
+        op->get_rt_info()["affinity"] = affinity;
+      }
+    }
+  }
 
   // Get inputs/outputs information from loaded model
   const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
   std::map<std::string, TensorInfo> input_infos;
@@ -249,18 +309,29 @@ bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,
   if (option_.cpu_thread_num > 0) {
     properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
   }
-  if (option_.ov_num_streams == -1) {
-    properties["NUM_STREAMS"] = ov::streams::AUTO;
-  } else if (option_.ov_num_streams == -2) {
-    properties["NUM_STREAMS"] = ov::streams::NUMA;
-  } else if (option_.ov_num_streams > 0) {
-    properties["NUM_STREAMS"] = option_.ov_num_streams;
+  if (option_.device == "CPU") {
+    if (option_.num_streams == -1) {
+      properties["NUM_STREAMS"] = ov::streams::AUTO;
+    } else if (option_.num_streams == -2) {
+      properties["NUM_STREAMS"] = ov::streams::NUMA;
+    } else if (option_.num_streams > 0) {
+      properties["NUM_STREAMS"] = option_.num_streams;
+    }
+  } else {
+    if (option_.num_streams != 0) {
+      FDWARNING << "NUM_STREAMS only available on device CPU, currently the "
+                   "device is set as "
+                << option_.device << ", the NUM_STREAMS will be ignored."
+                << std::endl;
+    }
   }
-  FDINFO << "Compile OpenVINO model on device_name:" << option.device << "." << std::endl;
+  FDINFO << "Compile OpenVINO model on device_name:" << option.device << "."
+         << std::endl;
   compiled_model_ = core_.compile_model(model, option.device, properties);
   request_ = compiled_model_.create_infer_request();
 
   initialized_ = true;
   return true;
 }
@@ -302,13 +373,16 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
   return true;
 }
 
-std::unique_ptr<BaseBackend> OpenVINOBackend::Clone(void *stream, int device_id) {
-  std::unique_ptr<BaseBackend> new_backend = utils::make_unique<OpenVINOBackend>();
+std::unique_ptr<BaseBackend> OpenVINOBackend::Clone(void* stream,
+                                                    int device_id) {
+  std::unique_ptr<BaseBackend> new_backend =
+      utils::make_unique<OpenVINOBackend>();
   auto casted_backend = dynamic_cast<OpenVINOBackend*>(new_backend.get());
   casted_backend->option_ = option_;
   casted_backend->request_ = compiled_model_.create_infer_request();
   casted_backend->input_infos_.assign(input_infos_.begin(), input_infos_.end());
-  casted_backend->output_infos_.assign(output_infos_.begin(), output_infos_.end());
+  casted_backend->output_infos_.assign(output_infos_.begin(),
+                                       output_infos_.end());
   return new_backend;
 }
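
For reference, the affinity-override logic added above follows OpenVINO's documented hetero workflow. A minimal Python sketch of the same pattern (illustrative only; the model path is hypothetical and FastDeploy itself does this in C++ as shown in the hunk above):

from openvino.runtime import Core

core = Core()
model = core.read_model("model.xml")  # hypothetical model path
# Ask the HETERO device which plugin would run each operator.
supported_ops = core.query_model(model=model, device_name="HETERO:GPU,CPU")
for op in model.get_ops():
    if op.get_type_name() in {"MulticlassNms"}:
        # Force the listed operator types onto the CPU plugin.
        op.get_rt_info()["affinity"] = "CPU"
    else:
        # Keep the affinity suggested by query_model for everything else.
        op.get_rt_info()["affinity"] = supported_ops[op.get_friendly_name()]
compiled = core.compile_model(model, "HETERO:GPU,CPU")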


@@ -28,8 +28,9 @@ namespace fastdeploy {
 struct OpenVINOBackendOption {
   std::string device = "CPU";
   int cpu_thread_num = -1;
-  int ov_num_streams = 1;
+  int num_streams = 0;
   std::map<std::string, std::vector<int64_t>> shape_infos;
+  std::set<std::string> cpu_operators{"MulticlassNms"};
 };
 
 class OpenVINOBackend : public BaseBackend {
@@ -38,13 +39,13 @@ class OpenVINOBackend : public BaseBackend {
   OpenVINOBackend() {}
   virtual ~OpenVINOBackend() = default;
 
-  bool InitFromPaddle(
-      const std::string& model_file, const std::string& params_file,
+  bool
+  InitFromPaddle(const std::string& model_file, const std::string& params_file,
       const OpenVINOBackendOption& option = OpenVINOBackendOption());
 
-  bool InitFromOnnx(
-      const std::string& model_file,
+  bool
+  InitFromOnnx(const std::string& model_file,
       const OpenVINOBackendOption& option = OpenVINOBackendOption());
 
   bool Infer(std::vector<FDTensor>& inputs,
              std::vector<FDTensor>* outputs) override;
@@ -58,7 +59,7 @@ class OpenVINOBackend : public BaseBackend {
   std::vector<TensorInfo> GetInputInfos() override;
   std::vector<TensorInfo> GetOutputInfos() override;
-  std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
+  std::unique_ptr<BaseBackend> Clone(void* stream = nullptr,
                                      int device_id = -1) override;
 
  private:
@@ -71,4 +72,5 @@ class OpenVINOBackend : public BaseBackend {
   std::vector<TensorInfo> input_infos_;
   std::vector<TensorInfo> output_infos_;
 };
 
 }  // namespace fastdeploy


@@ -34,6 +34,8 @@ void BindRuntime(pybind11::module& m) {
.def("use_lite_backend", &RuntimeOption::UseLiteBackend) .def("use_lite_backend", &RuntimeOption::UseLiteBackend)
.def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN) .def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN)
.def("set_openvino_device", &RuntimeOption::SetOpenVINODevice) .def("set_openvino_device", &RuntimeOption::SetOpenVINODevice)
.def("set_openvino_shape_info", &RuntimeOption::SetOpenVINOShapeInfo)
.def("set_openvino_cpu_operators", &RuntimeOption::SetOpenVINOCpuOperators)
.def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo) .def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo)
.def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo) .def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo)
.def("set_paddle_mkldnn_cache_size", .def("set_paddle_mkldnn_cache_size",


@@ -646,7 +646,11 @@ void Runtime::CreateOpenVINOBackend() {
   auto ov_option = OpenVINOBackendOption();
   ov_option.cpu_thread_num = option.cpu_thread_num;
   ov_option.device = option.openvino_device;
-  ov_option.ov_num_streams = option.ov_num_streams;
+  ov_option.shape_infos = option.ov_shape_infos;
+  ov_option.num_streams = option.ov_num_streams;
+  for (const auto& op : option.ov_cpu_operators) {
+    ov_option.cpu_operators.insert(op);
+  }
   FDASSERT(option.model_format == ModelFormat::PADDLE ||
                option.model_format == ModelFormat::ONNX,
            "OpenVINOBackend only support model format of ModelFormat::PADDLE / "


@@ -171,7 +171,22 @@ struct FASTDEPLOY_DECL RuntimeOption {
   /**
    * @brief Set device name for OpenVINO, default 'CPU', can also be 'AUTO', 'GPU', 'GPU.1'....
    */
   void SetOpenVINODevice(const std::string& name = "CPU");
 
+  /**
+   * @brief Set shape info for OpenVINO
+   */
+  void SetOpenVINOShapeInfo(
+      const std::map<std::string, std::vector<int64_t>>& shape_info) {
+    ov_shape_infos = shape_info;
+  }
+
+  /**
+   * @brief While use OpenVINO backend with intel GPU, use this interface to specify operators run on CPU
+   */
+  void SetOpenVINOCpuOperators(const std::vector<std::string>& operators) {
+    ov_cpu_operators = operators;
+  }
+
   /**
    * @brief Set optimzed model dir for Paddle Lite backend.
@@ -349,9 +364,6 @@ struct FASTDEPLOY_DECL RuntimeOption {
   size_t trt_max_batch_size = 32;
   size_t trt_max_workspace_size = 1 << 30;
 
-  // ======Only for OpenVINO Backend======
-  std::string openvino_device = "CPU";
-
   // ======Only for Poros Backend=======
   bool is_dynamic = false;
   bool long_to_int = true;
@@ -360,7 +372,10 @@ struct FASTDEPLOY_DECL RuntimeOption {
   std::string poros_file = "";
 
   // ======Only for OpenVINO Backend=======
-  int ov_num_streams = 1;
+  int ov_num_streams = 0;
+  std::string openvino_device = "CPU";
+  std::map<std::string, std::vector<int64_t>> ov_shape_infos;
+  std::vector<std::string> ov_cpu_operators;
 
   // ======Only for RKNPU2 Backend=======
   fastdeploy::rknpu2::CpuName rknpu2_cpu_name_


@@ -35,7 +35,7 @@ class Runtime:
             self.runtime_option._option), "Initialize Runtime Failed!"
 
     def forward(self, *inputs):
-        """Inference with input data for poros
+        """[Only for Poros backend] Inference with input data for poros
 
         :param data: (list[str : numpy.ndarray])The input data list
         :return list of numpy.ndarray
@@ -60,7 +60,7 @@ class Runtime:
         return self._runtime.infer(data)
 
     def compile(self, warm_datas):
-        """compile with prewarm data for poros
+        """[Only for Poros backend] compile with prewarm data for poros
 
         :param data: (list[str : numpy.ndarray])The prewarm data list
         :return TorchScript Model
@@ -122,6 +122,9 @@ class RuntimeOption:
""" """
def __init__(self): def __init__(self):
"""Initialize a FastDeploy RuntimeOption object.
"""
self._option = C.RuntimeOption() self._option = C.RuntimeOption()
@property @property
@@ -210,8 +213,6 @@ class RuntimeOption:
     def use_rknpu2(self,
                    rknpu2_name=rknpu2.CpuName.RK3588,
                    rknpu2_core=rknpu2.CoreMask.RKNN_NPU_CORE_0):
-        """Inference with CPU
-        """
         return self._option.use_rknpu2(rknpu2_name, rknpu2_core)
 
     def set_cpu_thread_num(self, thread_num=-1):
@@ -222,6 +223,10 @@ class RuntimeOption:
         return self._option.set_cpu_thread_num(thread_num)
 
     def set_ort_graph_opt_level(self, level=-1):
+        """Set graph optimization level for ONNX Runtime backend
+
+        :param level: (int)Optimization level, -1 means the default setting
+        """
         return self._option.set_ort_graph_opt_level(level)
 
     def use_paddle_backend(self):
@@ -274,6 +279,20 @@ class RuntimeOption:
""" """
return self._option.set_openvino_device(name) return self._option.set_openvino_device(name)
def set_openvino_shape_info(self, shape_info):
"""Set shape information of the models' inputs, used for GPU to fix the shape
:param shape_info: (dict{str, list of int})Shape information of model's inputs, e.g {"image": [1, 3, 640, 640], "scale_factor": [1, 2]}
"""
return self._option.set_openvino_shape_info(shape_info)
def set_openvino_cpu_operators(self, operators):
"""While using OpenVINO backend and intel GPU, this interface specifies unsupported operators to run on CPU
:param operators: (list of string)list of operators' name, e.g ["MulticlasNms"]
"""
return self._option.set_openvino_cpu_operators(operators)
def enable_paddle_log_info(self): def enable_paddle_log_info(self):
"""Enable print out the debug log information while using Paddle Inference backend, the log information is disabled by default. """Enable print out the debug log information while using Paddle Inference backend, the log information is disabled by default.
""" """
@@ -367,9 +386,13 @@ class RuntimeOption:
         return self._option.set_trt_max_batch_size(trt_max_batch_size)
 
     def enable_paddle_trt_collect_shape(self):
+        """Enable collect subgraph shape information while using Paddle Inference with TensorRT
+        """
         return self._option.enable_paddle_trt_collect_shape()
 
     def disable_paddle_trt_collect_shape(self):
+        """Disable collect subgraph shape information while using Paddle Inference with TensorRT
+        """
        return self._option.disable_paddle_trt_collect_shape()
 
     def use_ipu(self,