Add OpenVINO backend support (#148)

* Add OpenVINO backend support

* fix pybind

* fix python library path
Authored by Jason on 2022-08-24 16:22:38 +08:00, committed by GitHub
parent a1260d7937
commit cf4afa4220
20 changed files with 479 additions and 38 deletions

View File

@@ -42,6 +42,7 @@ option(WITH_GPU "Whether WITH_GPU=ON, will enable onnxruntime-gpu/paddle-infernc
option(ENABLE_ORT_BACKEND "Whether to enable onnxruntime backend." OFF)
option(ENABLE_TRT_BACKEND "Whether to enable tensorrt backend." OFF)
option(ENABLE_PADDLE_BACKEND "Whether to enable paddle backend." OFF)
option(ENABLE_OPENVINO_BACKEND "Whether to enable openvino backend." OFF)
option(CUDA_DIRECTORY "If build tensorrt backend, need to define path of cuda library.")
option(TRT_DIRECTORY "If build tensorrt backend, need to define path of tensorrt library.")
option(ENABLE_VISION "Whether to enable vision models usage." OFF)
@@ -117,10 +118,11 @@ file(GLOB_RECURSE FDTENSOR_FUNC_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fas
file(GLOB_RECURSE DEPLOY_ORT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/ort/*.cc)
file(GLOB_RECURSE DEPLOY_PADDLE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/paddle/*.cc)
file(GLOB_RECURSE DEPLOY_TRT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/tensorrt/*.cc ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/tensorrt/*.cpp)
file(GLOB_RECURSE DEPLOY_OPENVINO_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/openvino/*.cc)
file(GLOB_RECURSE DEPLOY_VISION_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cc)
file(GLOB_RECURSE DEPLOY_TEXT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/text/*.cc)
file(GLOB_RECURSE DEPLOY_PYBIND_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/*.cc ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*_pybind.cc)
-list(REMOVE_ITEM ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS} ${DEPLOY_PADDLE_SRCS} ${DEPLOY_TRT_SRCS} ${DEPLOY_VISION_SRCS} ${DEPLOY_TEXT_SRCS} ${FDTENSOR_FUNC_SRCS})
+list(REMOVE_ITEM ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS} ${DEPLOY_PADDLE_SRCS} ${DEPLOY_TRT_SRCS} ${DEPLOY_OPENVINO_SRCS} ${DEPLOY_VISION_SRCS} ${DEPLOY_TEXT_SRCS} ${FDTENSOR_FUNC_SRCS})
set(DEPEND_LIBS "")
@@ -157,6 +159,13 @@ if(ENABLE_PADDLE_BACKEND)
endif()
endif()
if(ENABLE_OPENVINO_BACKEND)
add_definitions(-DENABLE_OPENVINO_BACKEND)
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_OPENVINO_SRCS})
include(external/openvino.cmake)
list(APPEND DEPEND_LIBS external_openvino)
endif()
if(WITH_GPU)
if(APPLE)
message(FATAL_ERROR "Cannot enable GPU while compling in Mac OSX.")

View File

@@ -3,7 +3,9 @@ CMAKE_MINIMUM_REQUIRED (VERSION 3.12)
set(WITH_GPU @WITH_GPU@)
set(ENABLE_ORT_BACKEND @ENABLE_ORT_BACKEND@)
set(ENABLE_PADDLE_BACKEND @ENABLE_PADDLE_BACKEND@)
set(ENABLE_OPENVINO_BACKEND @ENABLE_OPENVINO_BACKEND@)
set(PADDLEINFERENCE_VERSION @PADDLEINFERENCE_VERSION@)
set(OPENVINO_VERSION @OPENVINO_VERSION@)
set(ENABLE_TRT_BACKEND @ENABLE_TRT_BACKEND@)
set(ENABLE_PADDLE_FRONTEND @ENABLE_PADDLE_FRONTEND@)
set(ENABLE_VISION @ENABLE_VISION@)
@@ -45,6 +47,11 @@ if(ENABLE_PADDLE_BACKEND)
endif()
endif()
if(ENABLE_OPENVINO_BACKEND)
find_library(OPENVINO_LIB openvino ${CMAKE_CURRENT_LIST_DIR}/third_libs/install/openvino/lib/ NO_DEFAULT_PATH)
list(APPEND FASTDEPLOY_LIBS ${OPENVINO_LIB})
endif()
if(WITH_GPU)
if (NOT CUDA_DIRECTORY)
set(CUDA_DIRECTORY "/usr/local/cuda")
@@ -124,6 +131,10 @@ message(STATUS " ENABLE_PADDLE_BACKEND : ${ENABLE_PADDLE_BACKEND}")
if(ENABLE_PADDLE_BACKEND)
message(STATUS " Paddle Inference version : ${PADDLEINFERENCE_VERSION}")
endif()
message(STATUS " ENABLE_OPENVINO_BACKEND : ${ENABLE_OPENVINO_BACKEND}")
if(ENABLE_OPENVINO_BACKEND)
message(STATUS " OpenVINO version : ${OPENVINO_VERSION}")
endif()
message(STATUS " ENABLE_TRT_BACKEND : ${ENABLE_TRT_BACKEND}")
message(STATUS " ENABLE_VISION : ${ENABLE_VISION}")
message(STATUS " ENABLE_TEXT : ${ENABLE_TEXT}")

View File

@@ -1 +1 @@
-0.2.0
+0.2.1

View File

@@ -14,12 +14,12 @@
#pragma once
#include "fastdeploy/backends/common/multiclass_nms.h"
#include "fastdeploy/core/fd_tensor.h"
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "fastdeploy/backends/common/multiclass_nms.h"
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
@@ -27,6 +27,20 @@ struct TensorInfo {
std::string name;
std::vector<int> shape;
FDDataType dtype;
friend std::ostream& operator<<(std::ostream& output,
const TensorInfo& info) {
output << "TensorInfo(name: " << info.name << ", shape: [";
for (size_t i = 0; i < info.shape.size(); ++i) {
if (i == info.shape.size() - 1) {
output << info.shape[i];
} else {
output << info.shape[i] << ", ";
}
}
output << "], dtype: " << Str(info.dtype) << ")";
return output;
}
};
class BaseBackend {
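
For illustration, the new operator<< renders one line per tensor. A dynamic-batch FP32 input would print roughly as below (the name and dims are hypothetical, and the exact dtype text comes from Str()):

TensorInfo(name: image, shape: [-1, 3, 224, 224], dtype: FP32)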

View File

@@ -0,0 +1,199 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/backends/openvino/ov_backend.h"
namespace fastdeploy {
std::vector<int64_t> PartialShapeToVec(const ov::PartialShape& shape) {
std::vector<int64_t> res;
for (size_t i = 0; i < shape.size(); ++i) {
auto dim = shape[i];
if (dim.is_dynamic()) {
res.push_back(-1);
} else {
res.push_back(dim.get_length());
}
}
return res;
}
FDDataType OpenVINODataTypeToFD(const ov::element::Type& type) {
if (type == ov::element::f32) {
return FDDataType::FP32;
} else if (type == ov::element::f64) {
return FDDataType::FP64;
} else if (type == ov::element::i8) {
return FDDataType::INT8;
} else if (type == ov::element::i32) {
return FDDataType::INT32;
} else if (type == ov::element::i64) {
return FDDataType::INT64;
} else {
FDASSERT(false, "Only support float/double/int8/int32/int64 now.");
}
return FDDataType::FP32;
}
ov::element::Type FDDataTypeToOV(const FDDataType& type) {
if (type == FDDataType::FP32) {
return ov::element::f32;
} else if (type == FDDataType::FP64) {
return ov::element::f64;
} else if (type == FDDataType::INT8) {
return ov::element::i8;
} else if (type == FDDataType::INT32) {
return ov::element::i32;
} else if (type == FDDataType::INT64) {
return ov::element::i64;
}
FDASSERT(false, "Only support float/double/int8/int32/int64 now.");
return ov::element::f32;
}
bool OpenVINOBackend::InitFromPaddle(const std::string& model_file,
const std::string& params_file,
const OpenVINOBackendOption& option) {
if (initialized_) {
FDERROR << "OpenVINOBackend is already initlized, cannot initialize again."
<< std::endl;
return false;
}
option_ = option;
ov::AnyMap properties;
if (option_.cpu_thread_num > 0) {
properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
}
std::shared_ptr<ov::Model> model = core_.read_model(model_file, params_file);
// Get inputs/outputs information from loaded model
const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
for (size_t i = 0; i < inputs.size(); ++i) {
TensorInfo info;
auto partial_shape = PartialShapeToVec(inputs[i].get_partial_shape());
info.shape.assign(partial_shape.begin(), partial_shape.end());
info.name = inputs[i].get_any_name();
info.dtype = OpenVINODataTypeToFD(inputs[i].get_element_type());
input_infos_.emplace_back(info);
}
const std::vector<ov::Output<ov::Node>> outputs = model->outputs();
for (size_t i = 0; i < outputs.size(); ++i) {
TensorInfo info;
auto partial_shape = PartialShapeToVec(outputs[i].get_partial_shape());
info.shape.assign(partial_shape.begin(), partial_shape.end());
info.name = outputs[i].get_any_name();
info.dtype = OpenVINODataTypeToFD(outputs[i].get_element_type());
output_infos_.emplace_back(info);
}
compiled_model_ = core_.compile_model(model, "CPU", properties);
request_ = compiled_model_.create_infer_request();
initialized_ = true;
return true;
}
TensorInfo OpenVINOBackend::GetInputInfo(int index) {
FDASSERT(index < NumInputs(),
"The index: %d should be less than the number of inputs: %d.", index,
NumInputs());
return input_infos_[index];
}
TensorInfo OpenVINOBackend::GetOutputInfo(int index) {
FDASSERT(index < NumOutputs(),
"The index: %d should be less than the number of outputs: %d.", index,
NumOutputs());
return output_infos_[index];
}
bool OpenVINOBackend::InitFromOnnx(const std::string& model_file,
const OpenVINOBackendOption& option) {
if (initialized_) {
FDERROR << "OpenVINOBackend is already initlized, cannot initialize again."
<< std::endl;
return false;
}
option_ = option;
ov::AnyMap properties;
if (option_.cpu_thread_num > 0) {
properties["INFERENCE_NUM_THREADS"] = option_.cpu_thread_num;
}
std::shared_ptr<ov::Model> model = core_.read_model(model_file);
// Get inputs/outputs information from loaded model
const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
for (size_t i = 0; i < inputs.size(); ++i) {
TensorInfo info;
auto partial_shape = PartialShapeToVec(inputs[i].get_partial_shape());
info.shape.assign(partial_shape.begin(), partial_shape.end());
info.name = inputs[i].get_any_name();
info.dtype = OpenVINODataTypeToFD(inputs[i].get_element_type());
input_infos_.emplace_back(info);
}
const std::vector<ov::Output<ov::Node>> outputs = model->outputs();
for (size_t i = 0; i < outputs.size(); ++i) {
TensorInfo info;
auto partial_shape = PartialShapeToVec(outputs[i].get_partial_shape());
info.shape.assign(partial_shape.begin(), partial_shape.end());
info.name = outputs[i].get_any_name();
info.dtype = OpenVINODataTypeToFD(outputs[i].get_element_type());
output_infos_.emplace_back(info);
}
compiled_model_ = core_.compile_model(model, "CPU", properties);
request_ = compiled_model_.create_infer_request();
initialized_ = true;
return true;
}
int OpenVINOBackend::NumInputs() const { return input_infos_.size(); }
int OpenVINOBackend::NumOutputs() const { return output_infos_.size(); }
bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs) {
if (inputs.size() != input_infos_.size()) {
FDERROR << "[OpenVINOBackend] Size of the inputs(" << inputs.size()
<< ") should keep same with the inputs of this model("
<< input_infos_.size() << ")." << std::endl;
return false;
}
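// Wrap each input FDTensor in an ov::Tensor (a zero-copy view over the same buffer) and bind it by name.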
for (size_t i = 0; i < inputs.size(); ++i) {
ov::Shape shape(inputs[i].shape.begin(), inputs[i].shape.end());
ov::Tensor ov_tensor(FDDataTypeToOV(inputs[i].dtype), shape,
inputs[i].Data());
request_.set_tensor(inputs[i].name, ov_tensor);
}
request_.infer();
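// Copy each output of the completed request back into the caller's FDTensors.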
outputs->resize(output_infos_.size());
for (size_t i = 0; i < output_infos_.size(); ++i) {
auto out_tensor = request_.get_output_tensor(i);
auto out_tensor_shape = out_tensor.get_shape();
std::vector<int64_t> shape(out_tensor_shape.begin(),
out_tensor_shape.end());
(*outputs)[i].Allocate(shape,
OpenVINODataTypeToFD(out_tensor.get_element_type()),
output_infos_[i].name);
memcpy((*outputs)[i].MutableData(), out_tensor.data(),
(*outputs)[i].Nbytes());
}
return true;
}
} // namespace fastdeploy

View File

@@ -0,0 +1,62 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "fastdeploy/backends/backend.h"
#include "openvino/openvino.hpp"
namespace fastdeploy {
struct OpenVINOBackendOption {
int cpu_thread_num = 8;
std::map<std::string, std::vector<int64_t>> shape_infos;
};
class OpenVINOBackend : public BaseBackend {
public:
OpenVINOBackend() {}
virtual ~OpenVINOBackend() = default;
bool
InitFromPaddle(const std::string& model_file, const std::string& params_file,
const OpenVINOBackendOption& option = OpenVINOBackendOption());
bool
InitFromOnnx(const std::string& model_file,
const OpenVINOBackendOption& option = OpenVINOBackendOption());
bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs);
int NumInputs() const;
int NumOutputs() const;
TensorInfo GetInputInfo(int index);
TensorInfo GetOutputInfo(int index);
private:
ov::Core core_;
ov::CompiledModel compiled_model_;
ov::InferRequest request_;
OpenVINOBackendOption option_;
std::vector<TensorInfo> input_infos_;
std::vector<TensorInfo> output_infos_;
};
} // namespace fastdeploy
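
As orientation for the interface above, here is a minimal sketch of driving the backend directly; the model path, thread count, and zero-filled inputs are hypothetical, and real callers normally go through Runtime (see the runtime changes below). FDTensor's Allocate/MutableData/Nbytes are used the same way the Infer implementation above uses them for outputs.

#include <cstdint>
#include <cstring>
#include <vector>
#include "fastdeploy/backends/openvino/ov_backend.h"

int main() {
  fastdeploy::OpenVINOBackendOption option;
  option.cpu_thread_num = 4;  // mapped to OpenVINO's INFERENCE_NUM_THREADS

  fastdeploy::OpenVINOBackend backend;
  if (!backend.InitFromOnnx("model.onnx", option)) {  // hypothetical model file
    return -1;
  }

  // Allocate one zero-filled FDTensor per model input; dynamic (-1) dims
  // would need to be replaced with concrete values first.
  std::vector<fastdeploy::FDTensor> inputs(backend.NumInputs());
  for (int i = 0; i < backend.NumInputs(); ++i) {
    auto info = backend.GetInputInfo(i);
    std::vector<int64_t> shape(info.shape.begin(), info.shape.end());
    inputs[i].Allocate(shape, info.dtype, info.name);
    std::memset(inputs[i].MutableData(), 0, inputs[i].Nbytes());
  }

  std::vector<fastdeploy::FDTensor> outputs;
  return backend.Infer(inputs, &outputs) ? 0 : -1;
}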

View File

@@ -32,10 +32,10 @@ void OrtBackend::BuildOption(const OrtBackendOption& option) {
session_options_.SetGraphOptimizationLevel(
GraphOptimizationLevel(option.graph_optimization_level));
}
-if (option.intra_op_num_threads >= 0) {
+if (option.intra_op_num_threads > 0) {
session_options_.SetIntraOpNumThreads(option.intra_op_num_threads);
}
-if (option.inter_op_num_threads >= 0) {
+if (option.inter_op_num_threads > 0) {
session_options_.SetInterOpNumThreads(option.inter_op_num_threads);
}
if (option.execution_mode >= 0) {

View File

@@ -29,8 +29,12 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
if (!option.enable_log_info) {
config_.DisableGlogInfo();
}
if (option.cpu_thread_num <= 0) {
config_.SetCpuMathLibraryNumThreads(8);
} else {
config_.SetCpuMathLibraryNumThreads(option.cpu_thread_num);
}
}
bool PaddleBackend::InitFromPaddle(const std::string& model_file,
const std::string& params_file,

View File

@@ -33,6 +33,10 @@
#cmakedefine ENABLE_PADDLE_BACKEND
#endif
#ifndef ENABLE_OPENVINO_BACKEND
#cmakedefine ENABLE_OPENVINO_BACKEND
#endif
#ifndef WITH_GPU
#cmakedefine WITH_GPU
#endif

View File

@@ -26,31 +26,8 @@ bool FastDeployModel::InitRuntime() {
return false;
}
if (runtime_option.backend != Backend::UNKNOWN) {
-if (runtime_option.backend == Backend::ORT) {
-if (!IsBackendAvailable(Backend::ORT)) {
-FDERROR
-<< "Backend::ORT is not complied with current FastDeploy library."
-<< std::endl;
-return false;
-}
-} else if (runtime_option.backend == Backend::TRT) {
-if (!IsBackendAvailable(Backend::TRT)) {
-FDERROR
-<< "Backend::TRT is not complied with current FastDeploy library."
-<< std::endl;
-return false;
-}
-} else if (runtime_option.backend == Backend::PDINFER) {
-if (!IsBackendAvailable(Backend::PDINFER)) {
-FDERROR << "Backend::PDINFER is not compiled with current FastDeploy "
-"library."
-<< std::endl;
-return false;
-}
-} else {
-FDERROR
-<< "Only support Backend::ORT / Backend::TRT / Backend::PDINFER now."
-<< std::endl;
+if (!IsBackendAvailable(runtime_option.backend)) {
+FDERROR << Str(runtime_option.backend) << " is not compiled with current FastDeploy library." << std::endl;
return false;
}

View File

@@ -28,6 +28,10 @@
#include "fastdeploy/backends/paddle/paddle_backend.h"
#endif
#ifdef ENABLE_OPENVINO_BACKEND
#include "fastdeploy/backends/openvino/ov_backend.h"
#endif
namespace fastdeploy {
std::vector<Backend> GetAvailableBackends() {
@@ -40,6 +44,9 @@ std::vector<Backend> GetAvailableBackends() {
#endif
#ifdef ENABLE_PADDLE_BACKEND
backends.push_back(Backend::PDINFER);
#endif
#ifdef ENABLE_OPENVINO_BACKEND
backends.push_back(Backend::OPENVINO);
#endif
return backends;
}
@@ -61,6 +68,8 @@ std::string Str(const Backend& b) {
return "Backend::TRT";
} else if (b == Backend::PDINFER) {
return "Backend::PDINFER";
} else if (b == Backend::OPENVINO) {
return "Backend::OPENVINO";
}
return "UNKNOWN-Backend";
}
@@ -177,6 +186,13 @@ void RuntimeOption::UseTrtBackend() {
#endif
}
void RuntimeOption::UseOpenVINOBackend() {
#ifdef ENABLE_OPENVINO_BACKEND
backend = Backend::OPENVINO;
#else
FDASSERT(false, "The FastDeploy didn't compile with OpenVINO.");
#endif
}
void RuntimeOption::EnablePaddleMKLDNN() { pd_enable_mkldnn = true; }
void RuntimeOption::DisablePaddleMKLDNN() { pd_enable_mkldnn = false; }
@@ -228,21 +244,26 @@ bool Runtime::Init(const RuntimeOption& _option) {
option.backend = Backend::ORT;
} else if (IsBackendAvailable(Backend::PDINFER)) {
option.backend = Backend::PDINFER;
-} else {
+} else if (IsBackendAvailable(Backend::OPENVINO)) {
+option.backend = Backend::OPENVINO;
+} else {
FDERROR << "Please define backend in RuntimeOption, current it's "
"Backend::UNKNOWN."
<< std::endl;
return false;
}
}
if (option.backend == Backend::ORT) {
FDASSERT(option.device == Device::CPU || option.device == Device::GPU,
"Backend::TRT only supports Device::CPU/Device::GPU.");
"Backend::ORT only supports Device::CPU/Device::GPU.");
CreateOrtBackend();
FDINFO << "Runtime initialized with Backend::ORT." << std::endl;
} else if (option.backend == Backend::TRT) {
FDASSERT(option.device == Device::GPU,
"Backend::TRT only supports Device::GPU.");
CreateTrtBackend();
FDINFO << "Runtime initialized with Backend::TRT." << std::endl;
} else if (option.backend == Backend::PDINFER) {
FDASSERT(option.device == Device::CPU || option.device == Device::GPU,
"Backend::TRT only supports Device::CPU/Device::GPU.");
@@ -250,6 +271,11 @@ bool Runtime::Init(const RuntimeOption& _option) {
option.model_format == Frontend::PADDLE,
"Backend::PDINFER only supports model format of Frontend::PADDLE.");
CreatePaddleBackend();
FDINFO << "Runtime initialized with Backend::PDINFER." << std::endl;
} else if (option.backend == Backend::OPENVINO) {
FDASSERT(option.device == Device::CPU, "Backend::OPENVINO only supports Device::CPU");
CreateOpenVINOBackend();
FDINFO << "Runtime initialized with Backend::OPENVINO." << std::endl;
} else {
FDERROR << "Runtime only support "
"Backend::ORT/Backend::TRT/Backend::PDINFER as backend now."
@@ -295,6 +321,32 @@ void Runtime::CreatePaddleBackend() {
#endif
}
void Runtime::CreateOpenVINOBackend() {
#ifdef ENABLE_OPENVINO_BACKEND
auto ov_option = OpenVINOBackendOption();
ov_option.cpu_thread_num = option.cpu_thread_num;
FDASSERT(option.model_format == Frontend::PADDLE ||
option.model_format == Frontend::ONNX,
"OpenVINOBackend only support model format of Frontend::PADDLE / "
"Frontend::ONNX.");
backend_ = utils::make_unique<OpenVINOBackend>();
auto casted_backend = dynamic_cast<OpenVINOBackend*>(backend_.get());
if (option.model_format == Frontend::ONNX) {
FDASSERT(casted_backend->InitFromOnnx(option.model_file, ov_option),
"Load model from ONNX failed while initliazing OrtBackend.");
} else {
FDASSERT(casted_backend->InitFromPaddle(option.model_file,
option.params_file, ov_option),
"Load model from Paddle failed while initliazing OrtBackend.");
}
#else
FDASSERT(false,
"OpenVINOBackend is not available, please compiled with "
"ENABLE_OPENVINO_BACKEND=ON.");
#endif
}
void Runtime::CreateOrtBackend() {
#ifdef ENABLE_ORT_BACKEND
auto ort_option = OrtBackendOption();
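
Putting the runtime plumbing together, a hedged sketch of selecting the new backend through RuntimeOption; the include path and model file are assumptions, while model_file, model_format, cpu_thread_num, UseOpenVINOBackend(), Runtime::Init() and GetInputInfo() are the fields and calls shown in this diff:

#include <iostream>
#include "fastdeploy/fastdeploy_runtime.h"  // assumed header exposing Runtime/RuntimeOption

int main() {
  fastdeploy::RuntimeOption option;
  option.model_file = "model.onnx";  // hypothetical ONNX model
  option.model_format = fastdeploy::Frontend::ONNX;
  option.cpu_thread_num = 4;   // forwarded to OpenVINO's INFERENCE_NUM_THREADS
  option.UseOpenVINOBackend(); // FDASSERTs when built without ENABLE_OPENVINO_BACKEND

  fastdeploy::Runtime runtime;
  if (!runtime.Init(option)) {
    return -1;
  }
  // Inspect the loaded model through the backend-agnostic interface.
  for (int i = 0; i < runtime.NumInputs(); ++i) {
    std::cout << runtime.GetInputInfo(i) << std::endl;  // new TensorInfo printer
  }
  return 0;
}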

View File

@@ -21,7 +21,7 @@
namespace fastdeploy {
-enum FASTDEPLOY_DECL Backend { UNKNOWN, ORT, TRT, PDINFER };
+enum FASTDEPLOY_DECL Backend { UNKNOWN, ORT, TRT, PDINFER, OPENVINO };
// AUTOREC will infer the Frontend from the name of the model file
enum FASTDEPLOY_DECL Frontend { AUTOREC, PADDLE, ONNX };
@@ -63,6 +63,9 @@ struct FASTDEPLOY_DECL RuntimeOption {
// use tensorrt backend
void UseTrtBackend();
// use openvino backend
void UseOpenVINOBackend();
// enable mkldnn while use paddle inference in CPU
void EnablePaddleMKLDNN();
// disable mkldnn while use paddle inference in CPU
@@ -97,7 +100,8 @@ struct FASTDEPLOY_DECL RuntimeOption {
Backend backend = Backend::UNKNOWN;
// for cpu inference and preprocess
-int cpu_thread_num = 8;
+// -1 lets each backend choose its own default value
+int cpu_thread_num = -1;
int device_id = 0;
Device device = Device::CPU;
@@ -152,6 +156,8 @@ struct FASTDEPLOY_DECL Runtime {
void CreateTrtBackend();
void CreateOpenVINOBackend();
int NumInputs() { return backend_->NumInputs(); }
int NumOutputs() { return backend_->NumOutputs(); }
TensorInfo GetInputInfo(int index);

View File

@@ -26,6 +26,7 @@ void BindRuntime(pybind11::module& m) {
.def("use_paddle_backend", &RuntimeOption::UsePaddleBackend)
.def("use_ort_backend", &RuntimeOption::UseOrtBackend)
.def("use_trt_backend", &RuntimeOption::UseTrtBackend)
.def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend)
.def("enable_paddle_mkldnn", &RuntimeOption::EnablePaddleMKLDNN)
.def("disable_paddle_mkldnn", &RuntimeOption::DisablePaddleMKLDNN)
.def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo)

View File

@@ -26,7 +26,7 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file,
const RuntimeOption& custom_option,
const Frontend& model_format) {
config_file_ = config_file;
-valid_cpu_backends = {Backend::ORT, Backend::PDINFER};
+valid_cpu_backends = {Backend::ORT, Backend::OPENVINO, Backend::PDINFER};
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
runtime_option = custom_option;
runtime_option.model_format = model_format;

View File

@@ -14,7 +14,7 @@ PPYOLOE::PPYOLOE(const std::string& model_file, const std::string& params_file,
const RuntimeOption& custom_option,
const Frontend& model_format) {
config_file_ = config_file;
-valid_cpu_backends = {Backend::ORT, Backend::PDINFER};
+valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER};
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
runtime_option = custom_option;
runtime_option.model_format = model_format;

external/openvino.cmake (new vendored file, 91 lines)
View File

@@ -0,0 +1,91 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include(ExternalProject)
set(OPENVINO_PROJECT "extern_openvino")
set(OPENVINO_PREFIX_DIR ${THIRD_PARTY_PATH}/openvino)
set(OPENVINO_SOURCE_DIR
${THIRD_PARTY_PATH}/openvino/src/${OPENVINO_PROJECT})
set(OPENVINO_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openvino)
set(OPENVINO_INC_DIR
"${OPENVINO_INSTALL_DIR}/include"
CACHE PATH "openvino include directory." FORCE)
set(OPENVINO_LIB_DIR
"${OPENVINO_INSTALL_DIR}/lib/"
CACHE PATH "openvino lib directory." FORCE)
set(CMAKE_BUILD_RPATH "${CMAKE_BUILD_RPATH}" "${OPENVINO_LIB_DIR}")
set(OPENVINO_VERSION "2022.3.0")
set(OPENVINO_URL_PREFIX "https://bj.bcebos.com/fastdeploy/third_libs/")
if(WIN32)
message(FATAL_ERROR "FastDeploy cannot ENABLE_OPENVINO_BACKEND in windows now.")
set(OPENVINO_FILENAME "openvino-win-x64-${OPENVINO_VERSION}.zip")
if(NOT CMAKE_CL_64)
message(FATAL_ERROR "FastDeploy cannot ENABLE_OPENVINO_BACKEND in win32 now.")
endif()
elseif(APPLE)
message(FATAL_ERROR "FastDeploy cannot ENABLE_OPENVINO_BACKEND in Mac OSX now.")
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "arm64")
set(OPENVINO_FILENAME "openvino-osx-arm64-${OPENVINO_VERSION}.tgz")
else()
set(OPENVINO_FILENAME "openvino-osx-x86_64-${OPENVINO_VERSION}.tgz")
endif()
else()
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64")
message("Cannot compile with openvino while in linux-aarch64 platform")
else()
set(OPENVINO_FILENAME "openvino-linux-x64-${OPENVINO_VERSION}.tgz")
endif()
endif()
set(OPENVINO_URL "${OPENVINO_URL_PREFIX}${OPENVINO_FILENAME}")
include_directories(${OPENVINO_INC_DIR}) # Allow FastDeploy sources to include OpenVINO headers.
if(WIN32)
set(OPENVINO_LIB
"${OPENVINO_INSTALL_DIR}/lib/openvino.lib"
CACHE FILEPATH "OpenVINO library." FORCE)
elseif(APPLE)
set(OPENVINO_LIB
"${OPENVINO_INSTALL_DIR}/lib/libopenvino.dylib"
CACHE FILEPATH "OpenVINO library." FORCE)
else()
set(OPENVINO_LIB
"${OPENVINO_INSTALL_DIR}/lib/libopenvino.so"
CACHE FILEPATH "OpenVINO library." FORCE)
endif()
ExternalProject_Add(
${OPENVINO_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
URL ${OPENVINO_URL}
PREFIX ${OPENVINO_PREFIX_DIR}
DOWNLOAD_NO_PROGRESS 1
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
UPDATE_COMMAND ""
INSTALL_COMMAND
${CMAKE_COMMAND} -E remove_directory ${OPENVINO_INSTALL_DIR} &&
${CMAKE_COMMAND} -E make_directory ${OPENVINO_INSTALL_DIR} &&
${CMAKE_COMMAND} -E rename ${OPENVINO_SOURCE_DIR}/lib/intel64 ${OPENVINO_INSTALL_DIR}/lib &&
${CMAKE_COMMAND} -E copy_directory ${OPENVINO_SOURCE_DIR}/include
${OPENVINO_INC_DIR}
BUILD_BYPRODUCTS ${OPENVINO_LIB})
add_library(external_openvino STATIC IMPORTED GLOBAL)
set_property(TARGET external_openvino PROPERTY IMPORTED_LOCATION ${OPENVINO_LIB})
add_dependencies(external_openvino ${OPENVINO_PROJECT})

View File

@@ -33,12 +33,16 @@ function(fastdeploy_summary)
message(STATUS " ENABLE_ORT_BACKEND : ${ENABLE_ORT_BACKEND}")
message(STATUS " ENABLE_PADDLE_BACKEND : ${ENABLE_PADDLE_BACKEND}")
message(STATUS " ENABLE_TRT_BACKEND : ${ENABLE_TRT_BACKEND}")
message(STATUS " ENABLE_OPENVINO_BACKEND : ${ENABLE_OPENVINO_BACKEND}")
if(ENABLE_ORT_BACKEND)
message(STATUS " ONNXRuntime version : ${ONNXRUNTIME_VERSION}")
endif()
if(ENABLE_PADDLE_BACKEND)
message(STATUS " Paddle Inference version : ${PADDLEINFERENCE_VERSION}")
endif()
if(ENABLE_OPENVINO_BACKEND)
message(STATUS " OpenVINO version : ${OPENVINO_VERSION}")
endif()
if(WITH_GPU)
message(STATUS " WITH_GPU : ${WITH_GPU}")
message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}")

View File

@@ -32,6 +32,9 @@ def is_built_with_trt() -> bool:
def is_built_with_paddle() -> bool:
return True if "@ENABLE_PADDLE_BACKEND@" == "ON" else False
def is_built_with_openvino() -> bool:
return True if "@ENABLE_OPENVINO_BACKEND@" == "ON" else False
def get_default_cuda_directory() -> str:
if not is_built_with_gpu():

View File

@@ -76,6 +76,9 @@ class RuntimeOption:
def use_trt_backend(self):
return self._option.use_trt_backend()
def use_openvino_backend(self):
return self._option.use_openvino_backend()
def enable_paddle_mkldnn(self):
return self._option.enable_paddle_mkldnn()

View File

@@ -47,6 +47,7 @@ setup_configs = dict()
setup_configs["ENABLE_PADDLE_FRONTEND"] = os.getenv("ENABLE_PADDLE_FRONTEND",
"ON")
setup_configs["ENABLE_ORT_BACKEND"] = os.getenv("ENABLE_ORT_BACKEND", "ON")
setup_configs["ENABLE_OPENVINO_BACKEND"] = os.getenv("ENABLE_OPENVINO_BACKEND", "OFF")
setup_configs["ENABLE_PADDLE_BACKEND"] = os.getenv("ENABLE_PADDLE_BACKEND",
"OFF")
setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "ON")