diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
deleted file mode 100644
index 3b48e36a1..000000000
--- a/.devcontainer/devcontainer.json
+++ /dev/null
@@ -1,8 +0,0 @@
-{
-  "image": "mcr.microsoft.com/devcontainers/universal:2",
-  "features": {
-    "ghcr.io/devcontainers/features/nvidia-cuda:1": {
-      "version": "latest"
-    }
-  }
-}
diff --git a/fastdeploy/backends/openvino/ov_backend.cc b/fastdeploy/backends/openvino/ov_backend.cc
index 608e9199f..da3ec5404 100644
--- a/fastdeploy/backends/openvino/ov_backend.cc
+++ b/fastdeploy/backends/openvino/ov_backend.cc
@@ -32,6 +32,14 @@ std::vector<int64_t> PartialShapeToVec(const ov::PartialShape& shape) {
   return res;
 }

+ov::PartialShape VecToPartialShape(const std::vector<int64_t>& shape) {
+  std::vector<ov::Dimension> dims;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    dims.emplace_back(ov::Dimension(shape[i]));
+  }
+  return ov::PartialShape(dims);
+}
+
 FDDataType OpenVINODataTypeToFD(const ov::element::Type& type) {
   if (type == ov::element::f32) {
     return FDDataType::FP32;
@@ -100,6 +108,26 @@ bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
   option_ = option;

   std::shared_ptr<ov::Model> model = core_.read_model(model_file, params_file);
+  if (option_.shape_infos.size() > 0) {
+    std::map<std::string, ov::PartialShape> shape_infos;
+    for (const auto& item : option_.shape_infos) {
+      shape_infos[item.first] = VecToPartialShape(item.second);
+    }
+    model->reshape(shape_infos);
+  }
+
+  if (option_.device.find("HETERO") != std::string::npos) {
+    auto supported_ops = core_.query_model(model, option_.device);
+    for (auto&& op : model->get_ops()) {
+      auto& affinity = supported_ops[op->get_friendly_name()];
+      if (option_.cpu_operators.find(op->description()) !=
+          option_.cpu_operators.end()) {
+        op->get_rt_info()["affinity"] = "CPU";
+      } else {
+        op->get_rt_info()["affinity"] = affinity;
+      }
+    }
+  }

   // Get inputs/outputs information from loaded model
   const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
@@ -151,14 +179,25 @@ bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
   if (option_.cpu_thread_num > 0) {
     properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
   }
-  if (option_.ov_num_streams == -1) {
-    properties["NUM_STREAMS"] = ov::streams::AUTO;
-  } else if (option_.ov_num_streams == -2) {
-    properties["NUM_STREAMS"] = ov::streams::NUMA;
-  } else if (option_.ov_num_streams > 0) {
-    properties["NUM_STREAMS"] = option_.ov_num_streams;
+  if (option_.device == "CPU") {
+    if (option_.num_streams == -1) {
+      properties["NUM_STREAMS"] = ov::streams::AUTO;
+    } else if (option_.num_streams == -2) {
+      properties["NUM_STREAMS"] = ov::streams::NUMA;
+    } else if (option_.num_streams > 0) {
+      properties["NUM_STREAMS"] = option_.num_streams;
+    }
+  } else {
+    if (option_.num_streams != 0) {
+      FDWARNING << "NUM_STREAMS only available on device CPU, currently the "
+                   "device is set as "
+                << option_.device << ", the NUM_STREAMS will be ignored."
+                << std::endl;
+    }
   }
-  FDINFO << "Compile OpenVINO model on device_name:" << option.device << "." << std::endl;
+
+  FDINFO << "Compile OpenVINO model on device_name:" << option.device << "."
+         << std::endl;
   compiled_model_ = core_.compile_model(model, option.device, properties);
   request_ = compiled_model_.create_infer_request();

@@ -199,6 +238,27 @@ bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,

   std::shared_ptr<ov::Model> model = core_.read_model(model_file);
+  if (option_.shape_infos.size() > 0) {
+    std::map<std::string, ov::PartialShape> shape_infos;
+    for (const auto& item : option_.shape_infos) {
+      shape_infos[item.first] = VecToPartialShape(item.second);
+    }
+    model->reshape(shape_infos);
+  }
+
+  if (option_.device.find("HETERO") != std::string::npos) {
+    auto supported_ops = core_.query_model(model, option_.device);
+    for (auto&& op : model->get_ops()) {
+      auto& affinity = supported_ops[op->get_friendly_name()];
+      if (option_.cpu_operators.find(op->description()) !=
+          option_.cpu_operators.end()) {
+        op->get_rt_info()["affinity"] = "CPU";
+      } else {
+        op->get_rt_info()["affinity"] = affinity;
+      }
+    }
+  }
+
   // Get inputs/outputs information from loaded model
   const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
   std::map<std::string, TensorInfo> input_infos;
@@ -249,18 +309,29 @@
   if (option_.cpu_thread_num > 0) {
     properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
   }
-  if (option_.ov_num_streams == -1) {
-    properties["NUM_STREAMS"] = ov::streams::AUTO;
-  } else if (option_.ov_num_streams == -2) {
-    properties["NUM_STREAMS"] = ov::streams::NUMA;
-  } else if (option_.ov_num_streams > 0) {
-    properties["NUM_STREAMS"] = option_.ov_num_streams;
+  if (option_.device == "CPU") {
+    if (option_.num_streams == -1) {
+      properties["NUM_STREAMS"] = ov::streams::AUTO;
+    } else if (option_.num_streams == -2) {
+      properties["NUM_STREAMS"] = ov::streams::NUMA;
+    } else if (option_.num_streams > 0) {
+      properties["NUM_STREAMS"] = option_.num_streams;
+    }
+  } else {
+    if (option_.num_streams != 0) {
+      FDWARNING << "NUM_STREAMS only available on device CPU, currently the "
+                   "device is set as "
+                << option_.device << ", the NUM_STREAMS will be ignored."
+                << std::endl;
+    }
   }
-  FDINFO << "Compile OpenVINO model on device_name:" << option.device << "." << std::endl;
+
+  FDINFO << "Compile OpenVINO model on device_name:" << option.device << "."
+         << std::endl;
   compiled_model_ = core_.compile_model(model, option.device, properties);
   request_ = compiled_model_.create_infer_request();
-  
+
   initialized_ = true;
   return true;
 }

@@ -302,13 +373,16 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
   return true;
 }

-std::unique_ptr<BaseBackend> OpenVINOBackend::Clone(void *stream, int device_id) {
-  std::unique_ptr<BaseBackend> new_backend = utils::make_unique<OpenVINOBackend>();
+std::unique_ptr<BaseBackend> OpenVINOBackend::Clone(void* stream,
+                                                    int device_id) {
+  std::unique_ptr<BaseBackend> new_backend =
+      utils::make_unique<OpenVINOBackend>();
   auto casted_backend = dynamic_cast<OpenVINOBackend*>(new_backend.get());
   casted_backend->option_ = option_;
   casted_backend->request_ = compiled_model_.create_infer_request();
   casted_backend->input_infos_.assign(input_infos_.begin(), input_infos_.end());
-  casted_backend->output_infos_.assign(output_infos_.begin(), output_infos_.end());
+  casted_backend->output_infos_.assign(output_infos_.begin(),
+                                       output_infos_.end());
   return new_backend;
 }

diff --git a/fastdeploy/backends/openvino/ov_backend.h b/fastdeploy/backends/openvino/ov_backend.h
index f28459f03..e224cdca5 100644
--- a/fastdeploy/backends/openvino/ov_backend.h
+++ b/fastdeploy/backends/openvino/ov_backend.h
@@ -28,8 +28,9 @@ namespace fastdeploy {
 struct OpenVINOBackendOption {
   std::string device = "CPU";
   int cpu_thread_num = -1;
-  int ov_num_streams = 1;
+  int num_streams = 0;
   std::map<std::string, std::vector<int64_t>> shape_infos;
+  std::set<std::string> cpu_operators{"MulticlassNms"};
 };

 class OpenVINOBackend : public BaseBackend {
@@ -38,13 +39,13 @@ class OpenVINOBackend : public BaseBackend {
   OpenVINOBackend() {}
   virtual ~OpenVINOBackend() = default;

-  bool InitFromPaddle(
-      const std::string& model_file, const std::string& params_file,
-      const OpenVINOBackendOption& option = OpenVINOBackendOption());
+  bool
+  InitFromPaddle(const std::string& model_file, const std::string& params_file,
+                 const OpenVINOBackendOption& option = OpenVINOBackendOption());

-  bool InitFromOnnx(
-      const std::string& model_file,
-      const OpenVINOBackendOption& option = OpenVINOBackendOption());
+  bool
+  InitFromOnnx(const std::string& model_file,
+               const OpenVINOBackendOption& option = OpenVINOBackendOption());

   bool Infer(std::vector<FDTensor>& inputs,
              std::vector<FDTensor>* outputs) override;
@@ -58,7 +59,7 @@ class OpenVINOBackend : public BaseBackend {
   std::vector<TensorInfo> GetInputInfos() override;
   std::vector<TensorInfo> GetOutputInfos() override;

-  std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
+  std::unique_ptr<BaseBackend> Clone(void* stream = nullptr,
                                      int device_id = -1) override;

  private:
@@ -71,4 +72,5 @@ class OpenVINOBackend : public BaseBackend {
   std::vector<TensorInfo> input_infos_;
   std::vector<TensorInfo> output_infos_;
 };
+
 }  // namespace fastdeploy
diff --git a/fastdeploy/pybind/runtime.cc b/fastdeploy/pybind/runtime.cc
index c46f83dbc..53e2e6266 100644
--- a/fastdeploy/pybind/runtime.cc
+++ b/fastdeploy/pybind/runtime.cc
@@ -34,6 +34,8 @@ void BindRuntime(pybind11::module& m) {
       .def("use_lite_backend", &RuntimeOption::UseLiteBackend)
       .def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN)
       .def("set_openvino_device", &RuntimeOption::SetOpenVINODevice)
+      .def("set_openvino_shape_info", &RuntimeOption::SetOpenVINOShapeInfo)
+      .def("set_openvino_cpu_operators", &RuntimeOption::SetOpenVINOCpuOperators)
       .def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo)
       .def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo)
       .def("set_paddle_mkldnn_cache_size",
diff --git a/fastdeploy/runtime.cc b/fastdeploy/runtime.cc
index 6a8be7dd7..b9924cb0b 100755
--- a/fastdeploy/runtime.cc
+++ b/fastdeploy/runtime.cc
@@ -646,7 +646,11 @@ void Runtime::CreateOpenVINOBackend() {
   auto ov_option = OpenVINOBackendOption();
   ov_option.cpu_thread_num = option.cpu_thread_num;
   ov_option.device = option.openvino_device;
-  ov_option.ov_num_streams = option.ov_num_streams;
+  ov_option.shape_infos = option.ov_shape_infos;
+  ov_option.num_streams = option.ov_num_streams;
+  for (const auto& op : option.ov_cpu_operators) {
+    ov_option.cpu_operators.insert(op);
+  }
   FDASSERT(option.model_format == ModelFormat::PADDLE ||
                option.model_format == ModelFormat::ONNX,
            "OpenVINOBackend only support model format of ModelFormat::PADDLE / "
diff --git a/fastdeploy/runtime.h b/fastdeploy/runtime.h
index 9c8a12976..6ea584026 100755
--- a/fastdeploy/runtime.h
+++ b/fastdeploy/runtime.h
@@ -171,7 +171,22 @@ struct FASTDEPLOY_DECL RuntimeOption {
   /**
    * @brief Set device name for OpenVINO, default 'CPU', can also be 'AUTO', 'GPU', 'GPU.1'....
    */
-  void SetOpenVINODevice(const std::string& name = "CPU");
+  void SetOpenVINODevice(const std::string& name = "CPU");
+
+  /**
+   * @brief Set shape info for OpenVINO
+   */
+  void SetOpenVINOShapeInfo(
+      const std::map<std::string, std::vector<int64_t>>& shape_info) {
+    ov_shape_infos = shape_info;
+  }
+
+  /**
+   * @brief While using the OpenVINO backend with an Intel GPU, use this interface to specify the operators to run on CPU
+   */
+  void SetOpenVINOCpuOperators(const std::vector<std::string>& operators) {
+    ov_cpu_operators = operators;
+  }

   /**
    * @brief Set optimzed model dir for Paddle Lite backend.
@@ -349,9 +364,6 @@ struct FASTDEPLOY_DECL RuntimeOption {
   size_t trt_max_batch_size = 32;
   size_t trt_max_workspace_size = 1 << 30;

-  // ======Only for OpenVINO Backend======
-  std::string openvino_device = "CPU";
-
   // ======Only for Poros Backend=======
   bool is_dynamic = false;
   bool long_to_int = true;
@@ -360,7 +372,10 @@
   std::string poros_file = "";

   // ======Only for OpenVINO Backend=======
-  int ov_num_streams = 1;
+  int ov_num_streams = 0;
+  std::string openvino_device = "CPU";
+  std::map<std::string, std::vector<int64_t>> ov_shape_infos;
+  std::vector<std::string> ov_cpu_operators;

   // ======Only for RKNPU2 Backend=======
   fastdeploy::rknpu2::CpuName rknpu2_cpu_name_
diff --git a/python/fastdeploy/runtime.py b/python/fastdeploy/runtime.py
index 27b99b760..5352b6b75 100755
--- a/python/fastdeploy/runtime.py
+++ b/python/fastdeploy/runtime.py
@@ -35,7 +35,7 @@ class Runtime:
             self.runtime_option._option), "Initialize Runtime Failed!"

     def forward(self, *inputs):
-        """Inference with input data for poros
+        """[Only for Poros backend] Inference with input data for poros

         :param data: (list[str : numpy.ndarray])The input data list
         :return list of numpy.ndarray
@@ -60,7 +60,7 @@ class Runtime:
         return self._runtime.infer(data)

     def compile(self, warm_datas):
-        """compile with prewarm data for poros
+        """[Only for Poros backend] Compile with prewarm data for poros

         :param data: (list[str : numpy.ndarray])The prewarm data list
         :return TorchScript Model
@@ -122,6 +122,9 @@ class RuntimeOption:
     """

     def __init__(self):
+        """Initialize a FastDeploy RuntimeOption object.
+ """ + self._option = C.RuntimeOption() @property @@ -210,8 +213,6 @@ class RuntimeOption: def use_rknpu2(self, rknpu2_name=rknpu2.CpuName.RK3588, rknpu2_core=rknpu2.CoreMask.RKNN_NPU_CORE_0): - """Inference with CPU - """ return self._option.use_rknpu2(rknpu2_name, rknpu2_core) def set_cpu_thread_num(self, thread_num=-1): @@ -222,6 +223,10 @@ class RuntimeOption: return self._option.set_cpu_thread_num(thread_num) def set_ort_graph_opt_level(self, level=-1): + """Set graph optimization level for ONNX Runtime backend + + :param level: (int)Optimization level, -1 means the default setting + """ return self._option.set_ort_graph_opt_level(level) def use_paddle_backend(self): @@ -274,6 +279,20 @@ class RuntimeOption: """ return self._option.set_openvino_device(name) + def set_openvino_shape_info(self, shape_info): + """Set shape information of the models' inputs, used for GPU to fix the shape + + :param shape_info: (dict{str, list of int})Shape information of model's inputs, e.g {"image": [1, 3, 640, 640], "scale_factor": [1, 2]} + """ + return self._option.set_openvino_shape_info(shape_info) + + def set_openvino_cpu_operators(self, operators): + """While using OpenVINO backend and intel GPU, this interface specifies unsupported operators to run on CPU + + :param operators: (list of string)list of operators' name, e.g ["MulticlasNms"] + """ + return self._option.set_openvino_cpu_operators(operators) + def enable_paddle_log_info(self): """Enable print out the debug log information while using Paddle Inference backend, the log information is disabled by default. """ @@ -367,9 +386,13 @@ class RuntimeOption: return self._option.set_trt_max_batch_size(trt_max_batch_size) def enable_paddle_trt_collect_shape(self): + """Enable collect subgraph shape information while using Paddle Inference with TensorRT + """ return self._option.enable_paddle_trt_collect_shape() def disable_paddle_trt_collect_shape(self): + """Disable collect subgraph shape information while using Paddle Inference with TensorRT + """ return self._option.disable_paddle_trt_collect_shape() def use_ipu(self,