mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 09:07:10 +08:00
[Backend] Support Intel GPU with heterogeneous mode (#701)
* Add some comments for python api * support openvino gpu * Add cpu operators * add interface to specify hetero operators * remove useless dir * format code * remove debug code * Support GPU for ONNX
This commit is contained in:
@@ -1,8 +0,0 @@
|
||||
{
|
||||
"image": "mcr.microsoft.com/devcontainers/universal:2",
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/nvidia-cuda:1": {
|
||||
"version": "latest"
|
||||
}
|
||||
}
|
||||
}
|
@@ -32,6 +32,14 @@ std::vector<int64_t> PartialShapeToVec(const ov::PartialShape& shape) {
|
||||
return res;
|
||||
}
|
||||
|
||||
ov::PartialShape VecToPartialShape(const std::vector<int64_t>& shape) {
|
||||
std::vector<ov::Dimension> dims;
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
dims.emplace_back(ov::Dimension(shape[i]));
|
||||
}
|
||||
return ov::PartialShape(dims);
|
||||
}
|
||||
|
||||
FDDataType OpenVINODataTypeToFD(const ov::element::Type& type) {
|
||||
if (type == ov::element::f32) {
|
||||
return FDDataType::FP32;
|
||||
@@ -100,6 +108,26 @@ bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
|
||||
option_ = option;
|
||||
|
||||
std::shared_ptr<ov::Model> model = core_.read_model(model_file, params_file);
|
||||
if (option_.shape_infos.size() > 0) {
|
||||
std::map<std::string, ov::PartialShape> shape_infos;
|
||||
for (const auto& item : option_.shape_infos) {
|
||||
shape_infos[item.first] = VecToPartialShape(item.second);
|
||||
}
|
||||
model->reshape(shape_infos);
|
||||
}
|
||||
|
||||
if (option_.device.find("HETERO") != std::string::npos) {
|
||||
auto supported_ops = core_.query_model(model, option_.device);
|
||||
for (auto&& op : model->get_ops()) {
|
||||
auto& affinity = supported_ops[op->get_friendly_name()];
|
||||
if (option_.cpu_operators.find(op->description()) !=
|
||||
option_.cpu_operators.end()) {
|
||||
op->get_rt_info()["affinity"] = "CPU";
|
||||
} else {
|
||||
op->get_rt_info()["affinity"] = affinity;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get inputs/outputs information from loaded model
|
||||
const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
|
||||
@@ -151,14 +179,25 @@ bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
|
||||
if (option_.cpu_thread_num > 0) {
|
||||
properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
|
||||
}
|
||||
if (option_.ov_num_streams == -1) {
|
||||
properties["NUM_STREAMS"] = ov::streams::AUTO;
|
||||
} else if (option_.ov_num_streams == -2) {
|
||||
properties["NUM_STREAMS"] = ov::streams::NUMA;
|
||||
} else if (option_.ov_num_streams > 0) {
|
||||
properties["NUM_STREAMS"] = option_.ov_num_streams;
|
||||
if (option_.device == "CPU") {
|
||||
if (option_.num_streams == -1) {
|
||||
properties["NUM_STREAMS"] = ov::streams::AUTO;
|
||||
} else if (option_.num_streams == -2) {
|
||||
properties["NUM_STREAMS"] = ov::streams::NUMA;
|
||||
} else if (option_.num_streams > 0) {
|
||||
properties["NUM_STREAMS"] = option_.num_streams;
|
||||
}
|
||||
} else {
|
||||
if (option_.num_streams != 0) {
|
||||
FDWARNING << "NUM_STREAMS only available on device CPU, currently the "
|
||||
"device is set as "
|
||||
<< option_.device << ", the NUM_STREAMS will be ignored."
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
FDINFO << "Compile OpenVINO model on device_name:" << option.device << "." << std::endl;
|
||||
|
||||
FDINFO << "Compile OpenVINO model on device_name:" << option.device << "."
|
||||
<< std::endl;
|
||||
compiled_model_ = core_.compile_model(model, option.device, properties);
|
||||
|
||||
request_ = compiled_model_.create_infer_request();
|
||||
@@ -199,6 +238,27 @@ bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,
|
||||
|
||||
std::shared_ptr<ov::Model> model = core_.read_model(model_file);
|
||||
|
||||
if (option_.shape_infos.size() > 0) {
|
||||
std::map<std::string, ov::PartialShape> shape_infos;
|
||||
for (const auto& item : option_.shape_infos) {
|
||||
shape_infos[item.first] = VecToPartialShape(item.second);
|
||||
}
|
||||
model->reshape(shape_infos);
|
||||
}
|
||||
|
||||
if (option_.device.find("HETERO") != std::string::npos) {
|
||||
auto supported_ops = core_.query_model(model, option_.device);
|
||||
for (auto&& op : model->get_ops()) {
|
||||
auto& affinity = supported_ops[op->get_friendly_name()];
|
||||
if (option_.cpu_operators.find(op->description()) !=
|
||||
option_.cpu_operators.end()) {
|
||||
op->get_rt_info()["affinity"] = "CPU";
|
||||
} else {
|
||||
op->get_rt_info()["affinity"] = affinity;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get inputs/outputs information from loaded model
|
||||
const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
|
||||
std::map<std::string, TensorInfo> input_infos;
|
||||
@@ -249,14 +309,25 @@ bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,
|
||||
if (option_.cpu_thread_num > 0) {
|
||||
properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
|
||||
}
|
||||
if (option_.ov_num_streams == -1) {
|
||||
properties["NUM_STREAMS"] = ov::streams::AUTO;
|
||||
} else if (option_.ov_num_streams == -2) {
|
||||
properties["NUM_STREAMS"] = ov::streams::NUMA;
|
||||
} else if (option_.ov_num_streams > 0) {
|
||||
properties["NUM_STREAMS"] = option_.ov_num_streams;
|
||||
if (option_.device == "CPU") {
|
||||
if (option_.num_streams == -1) {
|
||||
properties["NUM_STREAMS"] = ov::streams::AUTO;
|
||||
} else if (option_.num_streams == -2) {
|
||||
properties["NUM_STREAMS"] = ov::streams::NUMA;
|
||||
} else if (option_.num_streams > 0) {
|
||||
properties["NUM_STREAMS"] = option_.num_streams;
|
||||
}
|
||||
} else {
|
||||
if (option_.num_streams != 0) {
|
||||
FDWARNING << "NUM_STREAMS only available on device CPU, currently the "
|
||||
"device is set as "
|
||||
<< option_.device << ", the NUM_STREAMS will be ignored."
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
FDINFO << "Compile OpenVINO model on device_name:" << option.device << "." << std::endl;
|
||||
|
||||
FDINFO << "Compile OpenVINO model on device_name:" << option.device << "."
|
||||
<< std::endl;
|
||||
compiled_model_ = core_.compile_model(model, option.device, properties);
|
||||
|
||||
request_ = compiled_model_.create_infer_request();
|
||||
@@ -302,13 +373,16 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseBackend> OpenVINOBackend::Clone(void *stream, int device_id) {
|
||||
std::unique_ptr<BaseBackend> new_backend = utils::make_unique<OpenVINOBackend>();
|
||||
std::unique_ptr<BaseBackend> OpenVINOBackend::Clone(void* stream,
|
||||
int device_id) {
|
||||
std::unique_ptr<BaseBackend> new_backend =
|
||||
utils::make_unique<OpenVINOBackend>();
|
||||
auto casted_backend = dynamic_cast<OpenVINOBackend*>(new_backend.get());
|
||||
casted_backend->option_ = option_;
|
||||
casted_backend->request_ = compiled_model_.create_infer_request();
|
||||
casted_backend->input_infos_.assign(input_infos_.begin(), input_infos_.end());
|
||||
casted_backend->output_infos_.assign(output_infos_.begin(), output_infos_.end());
|
||||
casted_backend->output_infos_.assign(output_infos_.begin(),
|
||||
output_infos_.end());
|
||||
return new_backend;
|
||||
}
|
||||
|
||||
|
@@ -28,8 +28,9 @@ namespace fastdeploy {
|
||||
struct OpenVINOBackendOption {
|
||||
std::string device = "CPU";
|
||||
int cpu_thread_num = -1;
|
||||
int ov_num_streams = 1;
|
||||
int num_streams = 0;
|
||||
std::map<std::string, std::vector<int64_t>> shape_infos;
|
||||
std::set<std::string> cpu_operators{"MulticlassNms"};
|
||||
};
|
||||
|
||||
class OpenVINOBackend : public BaseBackend {
|
||||
@@ -38,13 +39,13 @@ class OpenVINOBackend : public BaseBackend {
|
||||
OpenVINOBackend() {}
|
||||
virtual ~OpenVINOBackend() = default;
|
||||
|
||||
bool InitFromPaddle(
|
||||
const std::string& model_file, const std::string& params_file,
|
||||
const OpenVINOBackendOption& option = OpenVINOBackendOption());
|
||||
bool
|
||||
InitFromPaddle(const std::string& model_file, const std::string& params_file,
|
||||
const OpenVINOBackendOption& option = OpenVINOBackendOption());
|
||||
|
||||
bool InitFromOnnx(
|
||||
const std::string& model_file,
|
||||
const OpenVINOBackendOption& option = OpenVINOBackendOption());
|
||||
bool
|
||||
InitFromOnnx(const std::string& model_file,
|
||||
const OpenVINOBackendOption& option = OpenVINOBackendOption());
|
||||
|
||||
bool Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs) override;
|
||||
@@ -58,7 +59,7 @@ class OpenVINOBackend : public BaseBackend {
|
||||
std::vector<TensorInfo> GetInputInfos() override;
|
||||
std::vector<TensorInfo> GetOutputInfos() override;
|
||||
|
||||
std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
|
||||
std::unique_ptr<BaseBackend> Clone(void* stream = nullptr,
|
||||
int device_id = -1) override;
|
||||
|
||||
private:
|
||||
@@ -71,4 +72,5 @@ class OpenVINOBackend : public BaseBackend {
|
||||
std::vector<TensorInfo> input_infos_;
|
||||
std::vector<TensorInfo> output_infos_;
|
||||
};
|
||||
|
||||
} // namespace fastdeploy
|
||||
|
@@ -34,6 +34,8 @@ void BindRuntime(pybind11::module& m) {
|
||||
.def("use_lite_backend", &RuntimeOption::UseLiteBackend)
|
||||
.def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN)
|
||||
.def("set_openvino_device", &RuntimeOption::SetOpenVINODevice)
|
||||
.def("set_openvino_shape_info", &RuntimeOption::SetOpenVINOShapeInfo)
|
||||
.def("set_openvino_cpu_operators", &RuntimeOption::SetOpenVINOCpuOperators)
|
||||
.def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo)
|
||||
.def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo)
|
||||
.def("set_paddle_mkldnn_cache_size",
|
||||
|
@@ -646,7 +646,11 @@ void Runtime::CreateOpenVINOBackend() {
|
||||
auto ov_option = OpenVINOBackendOption();
|
||||
ov_option.cpu_thread_num = option.cpu_thread_num;
|
||||
ov_option.device = option.openvino_device;
|
||||
ov_option.ov_num_streams = option.ov_num_streams;
|
||||
ov_option.shape_infos = option.ov_shape_infos;
|
||||
ov_option.num_streams = option.ov_num_streams;
|
||||
for (const auto& op : option.ov_cpu_operators) {
|
||||
ov_option.cpu_operators.insert(op);
|
||||
}
|
||||
FDASSERT(option.model_format == ModelFormat::PADDLE ||
|
||||
option.model_format == ModelFormat::ONNX,
|
||||
"OpenVINOBackend only support model format of ModelFormat::PADDLE / "
|
||||
|
@@ -171,7 +171,22 @@ struct FASTDEPLOY_DECL RuntimeOption {
|
||||
/**
|
||||
* @brief Set device name for OpenVINO, default 'CPU', can also be 'AUTO', 'GPU', 'GPU.1'....
|
||||
*/
|
||||
void SetOpenVINODevice(const std::string& name = "CPU");
|
||||
void SetOpenVINODevice(const std::string& name = "CPU");
|
||||
|
||||
/**
|
||||
* @brief Set shape info for OpenVINO
|
||||
*/
|
||||
void SetOpenVINOShapeInfo(
|
||||
const std::map<std::string, std::vector<int64_t>>& shape_info) {
|
||||
ov_shape_infos = shape_info;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief While use OpenVINO backend with intel GPU, use this interface to specify operators run on CPU
|
||||
*/
|
||||
void SetOpenVINOCpuOperators(const std::vector<std::string>& operators) {
|
||||
ov_cpu_operators = operators;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set optimzed model dir for Paddle Lite backend.
|
||||
@@ -349,9 +364,6 @@ struct FASTDEPLOY_DECL RuntimeOption {
|
||||
size_t trt_max_batch_size = 32;
|
||||
size_t trt_max_workspace_size = 1 << 30;
|
||||
|
||||
// ======Only for OpenVINO Backend======
|
||||
std::string openvino_device = "CPU";
|
||||
|
||||
// ======Only for Poros Backend=======
|
||||
bool is_dynamic = false;
|
||||
bool long_to_int = true;
|
||||
@@ -360,7 +372,10 @@ struct FASTDEPLOY_DECL RuntimeOption {
|
||||
std::string poros_file = "";
|
||||
|
||||
// ======Only for OpenVINO Backend=======
|
||||
int ov_num_streams = 1;
|
||||
int ov_num_streams = 0;
|
||||
std::string openvino_device = "CPU";
|
||||
std::map<std::string, std::vector<int64_t>> ov_shape_infos;
|
||||
std::vector<std::string> ov_cpu_operators;
|
||||
|
||||
// ======Only for RKNPU2 Backend=======
|
||||
fastdeploy::rknpu2::CpuName rknpu2_cpu_name_
|
||||
|
@@ -35,7 +35,7 @@ class Runtime:
|
||||
self.runtime_option._option), "Initialize Runtime Failed!"
|
||||
|
||||
def forward(self, *inputs):
|
||||
"""Inference with input data for poros
|
||||
"""[Only for Poros backend] Inference with input data for poros
|
||||
|
||||
:param data: (list[str : numpy.ndarray])The input data list
|
||||
:return list of numpy.ndarray
|
||||
@@ -60,7 +60,7 @@ class Runtime:
|
||||
return self._runtime.infer(data)
|
||||
|
||||
def compile(self, warm_datas):
|
||||
"""compile with prewarm data for poros
|
||||
"""[Only for Poros backend] compile with prewarm data for poros
|
||||
|
||||
:param data: (list[str : numpy.ndarray])The prewarm data list
|
||||
:return TorchScript Model
|
||||
@@ -122,6 +122,9 @@ class RuntimeOption:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize a FastDeploy RuntimeOption object.
|
||||
"""
|
||||
|
||||
self._option = C.RuntimeOption()
|
||||
|
||||
@property
|
||||
@@ -210,8 +213,6 @@ class RuntimeOption:
|
||||
def use_rknpu2(self,
|
||||
rknpu2_name=rknpu2.CpuName.RK3588,
|
||||
rknpu2_core=rknpu2.CoreMask.RKNN_NPU_CORE_0):
|
||||
"""Inference with CPU
|
||||
"""
|
||||
return self._option.use_rknpu2(rknpu2_name, rknpu2_core)
|
||||
|
||||
def set_cpu_thread_num(self, thread_num=-1):
|
||||
@@ -222,6 +223,10 @@ class RuntimeOption:
|
||||
return self._option.set_cpu_thread_num(thread_num)
|
||||
|
||||
def set_ort_graph_opt_level(self, level=-1):
|
||||
"""Set graph optimization level for ONNX Runtime backend
|
||||
|
||||
:param level: (int)Optimization level, -1 means the default setting
|
||||
"""
|
||||
return self._option.set_ort_graph_opt_level(level)
|
||||
|
||||
def use_paddle_backend(self):
|
||||
@@ -274,6 +279,20 @@ class RuntimeOption:
|
||||
"""
|
||||
return self._option.set_openvino_device(name)
|
||||
|
||||
def set_openvino_shape_info(self, shape_info):
|
||||
"""Set shape information of the models' inputs, used for GPU to fix the shape
|
||||
|
||||
:param shape_info: (dict{str, list of int})Shape information of model's inputs, e.g {"image": [1, 3, 640, 640], "scale_factor": [1, 2]}
|
||||
"""
|
||||
return self._option.set_openvino_shape_info(shape_info)
|
||||
|
||||
def set_openvino_cpu_operators(self, operators):
|
||||
"""While using OpenVINO backend and intel GPU, this interface specifies unsupported operators to run on CPU
|
||||
|
||||
:param operators: (list of string)list of operators' name, e.g ["MulticlasNms"]
|
||||
"""
|
||||
return self._option.set_openvino_cpu_operators(operators)
|
||||
|
||||
def enable_paddle_log_info(self):
|
||||
"""Enable print out the debug log information while using Paddle Inference backend, the log information is disabled by default.
|
||||
"""
|
||||
@@ -367,9 +386,13 @@ class RuntimeOption:
|
||||
return self._option.set_trt_max_batch_size(trt_max_batch_size)
|
||||
|
||||
def enable_paddle_trt_collect_shape(self):
|
||||
"""Enable collect subgraph shape information while using Paddle Inference with TensorRT
|
||||
"""
|
||||
return self._option.enable_paddle_trt_collect_shape()
|
||||
|
||||
def disable_paddle_trt_collect_shape(self):
|
||||
"""Disable collect subgraph shape information while using Paddle Inference with TensorRT
|
||||
"""
|
||||
return self._option.disable_paddle_trt_collect_shape()
|
||||
|
||||
def use_ipu(self,
|
||||
|
Reference in New Issue
Block a user