[Backend] Support Intel GPU with heterogeneous mode (#701)

* Add some comments for python api

* support openvino gpu

* Add cpu operators

* add interface to specify hetero operators

* remove useless dir

* format code

* remove debug code

* Support GPU for ONNX
This commit is contained in:
Jason
2022-11-25 19:40:56 +08:00
committed by GitHub
parent ad5c9c08b2
commit 5b3fd9dd88
7 changed files with 156 additions and 44 deletions

View File

@@ -1,8 +0,0 @@
{
"image": "mcr.microsoft.com/devcontainers/universal:2",
"features": {
"ghcr.io/devcontainers/features/nvidia-cuda:1": {
"version": "latest"
}
}
}

View File

@@ -32,6 +32,14 @@ std::vector<int64_t> PartialShapeToVec(const ov::PartialShape& shape) {
return res;
}
ov::PartialShape VecToPartialShape(const std::vector<int64_t>& shape) {
std::vector<ov::Dimension> dims;
for (size_t i = 0; i < shape.size(); ++i) {
dims.emplace_back(ov::Dimension(shape[i]));
}
return ov::PartialShape(dims);
}
FDDataType OpenVINODataTypeToFD(const ov::element::Type& type) {
if (type == ov::element::f32) {
return FDDataType::FP32;
@@ -100,6 +108,26 @@ bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
option_ = option;
std::shared_ptr<ov::Model> model = core_.read_model(model_file, params_file);
if (option_.shape_infos.size() > 0) {
std::map<std::string, ov::PartialShape> shape_infos;
for (const auto& item : option_.shape_infos) {
shape_infos[item.first] = VecToPartialShape(item.second);
}
model->reshape(shape_infos);
}
if (option_.device.find("HETERO") != std::string::npos) {
auto supported_ops = core_.query_model(model, option_.device);
for (auto&& op : model->get_ops()) {
auto& affinity = supported_ops[op->get_friendly_name()];
if (option_.cpu_operators.find(op->description()) !=
option_.cpu_operators.end()) {
op->get_rt_info()["affinity"] = "CPU";
} else {
op->get_rt_info()["affinity"] = affinity;
}
}
}
// Get inputs/outputs information from loaded model
const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
@@ -151,14 +179,25 @@ bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
if (option_.cpu_thread_num > 0) {
properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
}
if (option_.ov_num_streams == -1) {
properties["NUM_STREAMS"] = ov::streams::AUTO;
} else if (option_.ov_num_streams == -2) {
properties["NUM_STREAMS"] = ov::streams::NUMA;
} else if (option_.ov_num_streams > 0) {
properties["NUM_STREAMS"] = option_.ov_num_streams;
if (option_.device == "CPU") {
if (option_.num_streams == -1) {
properties["NUM_STREAMS"] = ov::streams::AUTO;
} else if (option_.num_streams == -2) {
properties["NUM_STREAMS"] = ov::streams::NUMA;
} else if (option_.num_streams > 0) {
properties["NUM_STREAMS"] = option_.num_streams;
}
} else {
if (option_.num_streams != 0) {
FDWARNING << "NUM_STREAMS only available on device CPU, currently the "
"device is set as "
<< option_.device << ", the NUM_STREAMS will be ignored."
<< std::endl;
}
}
FDINFO << "Compile OpenVINO model on device_name:" << option.device << "." << std::endl;
FDINFO << "Compile OpenVINO model on device_name:" << option.device << "."
<< std::endl;
compiled_model_ = core_.compile_model(model, option.device, properties);
request_ = compiled_model_.create_infer_request();
@@ -199,6 +238,27 @@ bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,
std::shared_ptr<ov::Model> model = core_.read_model(model_file);
if (option_.shape_infos.size() > 0) {
std::map<std::string, ov::PartialShape> shape_infos;
for (const auto& item : option_.shape_infos) {
shape_infos[item.first] = VecToPartialShape(item.second);
}
model->reshape(shape_infos);
}
if (option_.device.find("HETERO") != std::string::npos) {
auto supported_ops = core_.query_model(model, option_.device);
for (auto&& op : model->get_ops()) {
auto& affinity = supported_ops[op->get_friendly_name()];
if (option_.cpu_operators.find(op->description()) !=
option_.cpu_operators.end()) {
op->get_rt_info()["affinity"] = "CPU";
} else {
op->get_rt_info()["affinity"] = affinity;
}
}
}
// Get inputs/outputs information from loaded model
const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
std::map<std::string, TensorInfo> input_infos;
@@ -249,14 +309,25 @@ bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,
if (option_.cpu_thread_num > 0) {
properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
}
if (option_.ov_num_streams == -1) {
properties["NUM_STREAMS"] = ov::streams::AUTO;
} else if (option_.ov_num_streams == -2) {
properties["NUM_STREAMS"] = ov::streams::NUMA;
} else if (option_.ov_num_streams > 0) {
properties["NUM_STREAMS"] = option_.ov_num_streams;
if (option_.device == "CPU") {
if (option_.num_streams == -1) {
properties["NUM_STREAMS"] = ov::streams::AUTO;
} else if (option_.num_streams == -2) {
properties["NUM_STREAMS"] = ov::streams::NUMA;
} else if (option_.num_streams > 0) {
properties["NUM_STREAMS"] = option_.num_streams;
}
} else {
if (option_.num_streams != 0) {
FDWARNING << "NUM_STREAMS only available on device CPU, currently the "
"device is set as "
<< option_.device << ", the NUM_STREAMS will be ignored."
<< std::endl;
}
}
FDINFO << "Compile OpenVINO model on device_name:" << option.device << "." << std::endl;
FDINFO << "Compile OpenVINO model on device_name:" << option.device << "."
<< std::endl;
compiled_model_ = core_.compile_model(model, option.device, properties);
request_ = compiled_model_.create_infer_request();
@@ -302,13 +373,16 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
return true;
}
std::unique_ptr<BaseBackend> OpenVINOBackend::Clone(void *stream, int device_id) {
std::unique_ptr<BaseBackend> new_backend = utils::make_unique<OpenVINOBackend>();
std::unique_ptr<BaseBackend> OpenVINOBackend::Clone(void* stream,
int device_id) {
std::unique_ptr<BaseBackend> new_backend =
utils::make_unique<OpenVINOBackend>();
auto casted_backend = dynamic_cast<OpenVINOBackend*>(new_backend.get());
casted_backend->option_ = option_;
casted_backend->request_ = compiled_model_.create_infer_request();
casted_backend->input_infos_.assign(input_infos_.begin(), input_infos_.end());
casted_backend->output_infos_.assign(output_infos_.begin(), output_infos_.end());
casted_backend->output_infos_.assign(output_infos_.begin(),
output_infos_.end());
return new_backend;
}

View File

@@ -28,8 +28,9 @@ namespace fastdeploy {
struct OpenVINOBackendOption {
std::string device = "CPU";
int cpu_thread_num = -1;
int ov_num_streams = 1;
int num_streams = 0;
std::map<std::string, std::vector<int64_t>> shape_infos;
std::set<std::string> cpu_operators{"MulticlassNms"};
};
class OpenVINOBackend : public BaseBackend {
@@ -38,13 +39,13 @@ class OpenVINOBackend : public BaseBackend {
OpenVINOBackend() {}
virtual ~OpenVINOBackend() = default;
bool InitFromPaddle(
const std::string& model_file, const std::string& params_file,
const OpenVINOBackendOption& option = OpenVINOBackendOption());
bool
InitFromPaddle(const std::string& model_file, const std::string& params_file,
const OpenVINOBackendOption& option = OpenVINOBackendOption());
bool InitFromOnnx(
const std::string& model_file,
const OpenVINOBackendOption& option = OpenVINOBackendOption());
bool
InitFromOnnx(const std::string& model_file,
const OpenVINOBackendOption& option = OpenVINOBackendOption());
bool Infer(std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs) override;
@@ -58,7 +59,7 @@ class OpenVINOBackend : public BaseBackend {
std::vector<TensorInfo> GetInputInfos() override;
std::vector<TensorInfo> GetOutputInfos() override;
std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
std::unique_ptr<BaseBackend> Clone(void* stream = nullptr,
int device_id = -1) override;
private:
@@ -71,4 +72,5 @@ class OpenVINOBackend : public BaseBackend {
std::vector<TensorInfo> input_infos_;
std::vector<TensorInfo> output_infos_;
};
} // namespace fastdeploy

View File

@@ -34,6 +34,8 @@ void BindRuntime(pybind11::module& m) {
.def("use_lite_backend", &RuntimeOption::UseLiteBackend)
.def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN)
.def("set_openvino_device", &RuntimeOption::SetOpenVINODevice)
.def("set_openvino_shape_info", &RuntimeOption::SetOpenVINOShapeInfo)
.def("set_openvino_cpu_operators", &RuntimeOption::SetOpenVINOCpuOperators)
.def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo)
.def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo)
.def("set_paddle_mkldnn_cache_size",

View File

@@ -646,7 +646,11 @@ void Runtime::CreateOpenVINOBackend() {
auto ov_option = OpenVINOBackendOption();
ov_option.cpu_thread_num = option.cpu_thread_num;
ov_option.device = option.openvino_device;
ov_option.ov_num_streams = option.ov_num_streams;
ov_option.shape_infos = option.ov_shape_infos;
ov_option.num_streams = option.ov_num_streams;
for (const auto& op : option.ov_cpu_operators) {
ov_option.cpu_operators.insert(op);
}
FDASSERT(option.model_format == ModelFormat::PADDLE ||
option.model_format == ModelFormat::ONNX,
"OpenVINOBackend only support model format of ModelFormat::PADDLE / "

View File

@@ -171,7 +171,22 @@ struct FASTDEPLOY_DECL RuntimeOption {
/**
* @brief Set device name for OpenVINO, default 'CPU', can also be 'AUTO', 'GPU', 'GPU.1'....
*/
void SetOpenVINODevice(const std::string& name = "CPU");
void SetOpenVINODevice(const std::string& name = "CPU");
/**
* @brief Set shape info for OpenVINO
*/
void SetOpenVINOShapeInfo(
const std::map<std::string, std::vector<int64_t>>& shape_info) {
ov_shape_infos = shape_info;
}
/**
* @brief While use OpenVINO backend with intel GPU, use this interface to specify operators run on CPU
*/
void SetOpenVINOCpuOperators(const std::vector<std::string>& operators) {
ov_cpu_operators = operators;
}
/**
* @brief Set optimzed model dir for Paddle Lite backend.
@@ -349,9 +364,6 @@ struct FASTDEPLOY_DECL RuntimeOption {
size_t trt_max_batch_size = 32;
size_t trt_max_workspace_size = 1 << 30;
// ======Only for OpenVINO Backend======
std::string openvino_device = "CPU";
// ======Only for Poros Backend=======
bool is_dynamic = false;
bool long_to_int = true;
@@ -360,7 +372,10 @@ struct FASTDEPLOY_DECL RuntimeOption {
std::string poros_file = "";
// ======Only for OpenVINO Backend=======
int ov_num_streams = 1;
int ov_num_streams = 0;
std::string openvino_device = "CPU";
std::map<std::string, std::vector<int64_t>> ov_shape_infos;
std::vector<std::string> ov_cpu_operators;
// ======Only for RKNPU2 Backend=======
fastdeploy::rknpu2::CpuName rknpu2_cpu_name_

View File

@@ -35,7 +35,7 @@ class Runtime:
self.runtime_option._option), "Initialize Runtime Failed!"
def forward(self, *inputs):
"""Inference with input data for poros
"""[Only for Poros backend] Inference with input data for poros
:param data: (list[str : numpy.ndarray])The input data list
:return list of numpy.ndarray
@@ -60,7 +60,7 @@ class Runtime:
return self._runtime.infer(data)
def compile(self, warm_datas):
"""compile with prewarm data for poros
"""[Only for Poros backend] compile with prewarm data for poros
:param data: (list[str : numpy.ndarray])The prewarm data list
:return TorchScript Model
@@ -122,6 +122,9 @@ class RuntimeOption:
"""
def __init__(self):
"""Initialize a FastDeploy RuntimeOption object.
"""
self._option = C.RuntimeOption()
@property
@@ -210,8 +213,6 @@ class RuntimeOption:
def use_rknpu2(self,
rknpu2_name=rknpu2.CpuName.RK3588,
rknpu2_core=rknpu2.CoreMask.RKNN_NPU_CORE_0):
"""Inference with CPU
"""
return self._option.use_rknpu2(rknpu2_name, rknpu2_core)
def set_cpu_thread_num(self, thread_num=-1):
@@ -222,6 +223,10 @@ class RuntimeOption:
return self._option.set_cpu_thread_num(thread_num)
def set_ort_graph_opt_level(self, level=-1):
"""Set graph optimization level for ONNX Runtime backend
:param level: (int)Optimization level, -1 means the default setting
"""
return self._option.set_ort_graph_opt_level(level)
def use_paddle_backend(self):
@@ -274,6 +279,20 @@ class RuntimeOption:
"""
return self._option.set_openvino_device(name)
def set_openvino_shape_info(self, shape_info):
"""Set shape information of the models' inputs, used for GPU to fix the shape
:param shape_info: (dict{str, list of int})Shape information of model's inputs, e.g {"image": [1, 3, 640, 640], "scale_factor": [1, 2]}
"""
return self._option.set_openvino_shape_info(shape_info)
def set_openvino_cpu_operators(self, operators):
"""While using OpenVINO backend and intel GPU, this interface specifies unsupported operators to run on CPU
:param operators: (list of string)list of operators' name, e.g ["MulticlasNms"]
"""
return self._option.set_openvino_cpu_operators(operators)
def enable_paddle_log_info(self):
"""Enable print out the debug log information while using Paddle Inference backend, the log information is disabled by default.
"""
@@ -367,9 +386,13 @@ class RuntimeOption:
return self._option.set_trt_max_batch_size(trt_max_batch_size)
def enable_paddle_trt_collect_shape(self):
"""Enable collect subgraph shape information while using Paddle Inference with TensorRT
"""
return self._option.enable_paddle_trt_collect_shape()
def disable_paddle_trt_collect_shape(self):
"""Disable collect subgraph shape information while using Paddle Inference with TensorRT
"""
return self._option.disable_paddle_trt_collect_shape()
def use_ipu(self,