[Other] Optimize paddle backend (#1265)

* Optimize paddle backend

* optimize paddle backend

* add version support
Authored by Jason on 2023-02-08 19:12:03 +08:00, committed by GitHub
parent 60ba4b06c1
commit a4b0565b9a
10 changed files with 265 additions and 174 deletions
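The core of this change is collapsing the scattered pd_* and ipu_* fields on RuntimeOption into a single PaddleBackendOption exposed as paddle_infer_option. Below is a minimal sketch of the new-style configuration in Python, assembled from the replacement snippets quoted in the deprecation warnings later in this diff; the exact attribute set depends on how FastDeploy was built.

import fastdeploy as fd

option = fd.RuntimeOption()
option.use_paddle_infer_backend()
# New style: configure Paddle Inference directly through paddle_infer_option
option.paddle_infer_option.enable_mkldnn = True
option.paddle_infer_option.enable_log_info = False
option.paddle_infer_option.mkldnn_cache_size = 10

The old helpers such as set_paddle_mkldnn() keep working for now, but they only forward to these members and emit the deprecation warnings added in this commit.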

View File

@@ -24,54 +24,71 @@
namespace fastdeploy {
/*! @brief Option object to configure GraphCore IPU
*/
struct IpuOption {
/// IPU device id
int ipu_device_num;
/// the batch size in the graph, only works when the graph has no batch shape info
int ipu_micro_batch_size;
/// enable pipelining
bool ipu_enable_pipelining;
/// the number of batches per run in pipelining
int ipu_batches_per_step;
/// enable fp16
bool ipu_enable_fp16;
/// the number of graph replications
int ipu_replica_num;
/// the available memory proportion for matmul/conv
float ipu_available_memory_proportion;
/// enable fp16 partial for matmul, only works with fp16
bool ipu_enable_half_partial;
};
/*! @brief Option object to configure Paddle Inference backend
*/
struct PaddleBackendOption {
/// Print log information while initializing Paddle Inference backend
bool enable_log_info = false;
/// Enable MKLDNN while running inference on CPU
bool enable_mkldnn = true;
/// Use Paddle Inference + TensorRT to run the model on GPU
bool enable_trt = false;
/*
* @brief IPU option, this will configure the IPU hardware when running the model on IPU
*/
IpuOption ipu_option;
/// Collect shape for model while enable_trt is true
bool collect_trt_shape = false;
/// Cache input shape for mkldnn while the input data changes dynamically
int mkldnn_cache_size = -1;
/// Initial memory size (MB) for GPU
int gpu_mem_init_size = 100;
void DisableTrtOps(const std::vector<std::string>& ops) {
trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
}
void DeletePass(const std::string& pass_name) {
delete_pass_names.push_back(pass_name);
}
// The following parameters may be removed, please do not
// read or write them directly
TrtBackendOption trt_option;
bool enable_pinned_memory = false;
void* external_stream_ = nullptr;
Device device = Device::CPU;
int device_id = 0;
std::vector<std::string> trt_disabled_ops_{};
int cpu_thread_num = 8;
std::vector<std::string> delete_pass_names = {};
std::string model_file = ""; // Path of model file
std::string params_file = ""; // Path of parameters file, can be empty
// Load model and parameters from memory
bool model_from_memory_ = false;
#ifdef WITH_GPU
bool use_gpu = true;
#else
bool use_gpu = false;
#endif
bool enable_mkldnn = true;
bool enable_log_info = false;
bool enable_trt = false;
TrtBackendOption trt_option;
bool collect_shape = false;
std::vector<std::string> trt_disabled_ops_{};
#ifdef WITH_IPU
bool use_ipu = true;
IpuOption ipu_option;
#else
bool use_ipu = false;
#endif
int mkldnn_cache_size = 1;
int cpu_thread_num = 8;
// initialize memory size(MB) for GPU
int gpu_mem_init_size = 100;
// gpu device id
int gpu_id = 0;
bool enable_pinned_memory = false;
void* external_stream_ = nullptr;
std::vector<std::string> delete_pass_names = {};
};
} // namespace fastdeploy

View File

@@ -0,0 +1,53 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
#include "fastdeploy/runtime/backends/paddle/option.h"
namespace fastdeploy {
void BindIpuOption(pybind11::module& m) {
pybind11::class_<IpuOption>(m, "IpuOption")
.def(pybind11::init())
.def_readwrite("ipu_device_num", &IpuOption::ipu_device_num)
.def_readwrite("ipu_micro_batch_size", &IpuOption::ipu_micro_batch_size)
.def_readwrite("ipu_enable_pipelining", &IpuOption::ipu_enable_pipelining)
.def_readwrite("ipu_batches_per_step", &IpuOption::ipu_batches_per_step)
.def_readwrite("ipu_enable_fp16", &IpuOption::ipu_enable_fp16)
.def_readwrite("ipu_replica_num", &IpuOption::ipu_replica_num)
.def_readwrite("ipu_available_memory_proportion",
&IpuOption::ipu_available_memory_proportion)
.def_readwrite("ipu_enable_half_partial",
&IpuOption::ipu_enable_half_partial);
}
void BindPaddleOption(pybind11::module& m) {
BindIpuOption(m);
pybind11::class_<PaddleBackendOption>(m, "PaddleBackendOption")
.def(pybind11::init())
.def_readwrite("enable_log_info", &PaddleBackendOption::enable_log_info)
.def_readwrite("enable_mkldnn", &PaddleBackendOption::enable_mkldnn)
.def_readwrite("enable_trt", &PaddleBackendOption::enable_trt)
.def_readwrite("ipu_option", &PaddleBackendOption::ipu_option)
.def_readwrite("collect_trt_shape",
&PaddleBackendOption::collect_trt_shape)
.def_readwrite("mkldnn_cache_size",
&PaddleBackendOption::mkldnn_cache_size)
.def_readwrite("gpu_mem_init_size",
&PaddleBackendOption::gpu_mem_init_size)
.def("disable_trt_ops", &PaddleBackendOption::DisableTrtOps)
.def("delete_pass", &PaddleBackendOption::DeletePass);
}
} // namespace fastdeploy
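BindPaddleOption above also exposes IpuOption and the DisableTrtOps/DeletePass helpers to Python. A hedged sketch of how these bindings might be used, assuming a build with IPU support; the op and pass names are placeholders:

import fastdeploy as fd

option = fd.RuntimeOption()
# IPU fields bound by BindIpuOption above
option.paddle_infer_option.ipu_option.ipu_device_num = 1
option.paddle_infer_option.ipu_option.ipu_micro_batch_size = 1
option.paddle_infer_option.ipu_option.ipu_enable_fp16 = False
# Methods bound on PaddleBackendOption
option.paddle_infer_option.disable_trt_ops(["pool2d"])  # placeholder op name
option.paddle_infer_option.delete_pass("some_fuse_pass")  # placeholder pass name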

View File

@@ -22,8 +22,8 @@ namespace fastdeploy {
void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
option_ = option;
if (option.use_gpu) {
config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id);
if (option.device == Device::GPU) {
config_.EnableUseGpu(option.gpu_mem_init_size, option.device_id);
if (option_.external_stream_) {
config_.SetExecStream(option_.external_stream_);
}
@@ -50,7 +50,7 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
precision, use_static);
SetTRTDynamicShapeToConfig(option);
}
} else if (option.use_ipu) {
} else if (option.device == Device::IPU) {
#ifdef WITH_IPU
config_.EnableIpu(option.ipu_option.ipu_device_num,
option.ipu_option.ipu_micro_batch_size,
@@ -104,11 +104,12 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_buffer,
// The input/output information obtained from the predictor is not correct, use
// PaddleReader instead for now
auto reader = paddle2onnx::PaddleReader(model_buffer.c_str(), model_buffer.size());
auto reader =
paddle2onnx::PaddleReader(model_buffer.c_str(), model_buffer.size());
// If it's a quantized model and running on CPU with MKLDNN, automatically switch to
// int8 mode
if (reader.is_quantize_model) {
if (option.use_gpu) {
if (option.device == Device::GPU) {
FDWARNING << "The loaded model is a quantized model, while inference on "
"GPU, please use TensorRT backend to get better performance."
<< std::endl;
@@ -158,7 +159,7 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_buffer,
outputs_desc_[i].shape.assign(shape.begin(), shape.end());
outputs_desc_[i].dtype = ReaderDataTypeToFD(reader.outputs[i].dtype);
}
if (option.collect_shape) {
if (option.collect_trt_shape) {
// Set the shape info file.
std::string curr_model_dir = "./";
if (!option.model_from_memory_) {
@@ -233,7 +234,7 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
RUNTIME_PROFILE_LOOP_END
// Sharing backend memory for outputs only supports CPU or GPU
if (option_.use_ipu) {
if (option_.device == Device::IPU) {
copy_to_fd = true;
}
outputs->resize(outputs_desc_.size());
@@ -253,9 +254,10 @@ std::unique_ptr<BaseBackend> PaddleBackend::Clone(RuntimeOption& runtime_option,
std::unique_ptr<BaseBackend> new_backend =
utils::make_unique<PaddleBackend>();
auto casted_backend = dynamic_cast<PaddleBackend*>(new_backend.get());
if (device_id > 0 && option_.use_gpu == true && device_id != option_.gpu_id) {
if (device_id > 0 && (option_.device == Device::GPU) &&
device_id != option_.device_id) {
auto clone_option = option_;
clone_option.gpu_id = device_id;
clone_option.device_id = device_id;
clone_option.external_stream_ = stream;
if (runtime_option.model_from_memory_) {
FDASSERT(
@@ -279,7 +281,7 @@ std::unique_ptr<BaseBackend> PaddleBackend::Clone(RuntimeOption& runtime_option,
}
FDWARNING << "The target device id:" << device_id
<< " is different from current device id:" << option_.gpu_id
<< " is different from current device id:" << option_.device_id
<< ", cannot share memory with current engine." << std::endl;
return new_backend;
}
@@ -347,10 +349,13 @@ void PaddleBackend::CollectShapeRun(
const std::map<std::string, std::vector<int>>& shape) const {
auto input_names = predictor->GetInputNames();
auto input_type = predictor->GetInputTypes();
for (auto name : input_names) {
for (const auto& name : input_names) {
FDASSERT(shape.find(name) != shape.end() &&
input_type.find(name) != input_type.end(),
"Paddle Input name [%s] is not one of the trt dynamic shape.",
"When collect_trt_shape is true, please define max/opt/min shape "
"for model's input:[\"%s\"] by "
"(C++)RuntimeOption.trt_option.SetShape/"
"(Python)RuntimeOption.trt_option.set_shape.",
name.c_str());
auto tensor = predictor->GetInputHandle(name);
auto shape_value = shape.at(name);
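The reworked assertion above asks users to predefine min/opt/max shapes for every input when collect_trt_shape is enabled. A hedged Python sketch of that setup, following the (Python)RuntimeOption.trt_option.set_shape hint in the message; the input name "image" and the shape values are placeholders:

import fastdeploy as fd

option = fd.RuntimeOption()
option.use_gpu(0)
option.use_paddle_infer_backend()
option.paddle_infer_option.enable_trt = True
option.paddle_infer_option.collect_trt_shape = True
# Define min/opt/max shapes for each model input named by the assertion message
option.trt_option.set_shape("image", [1, 3, 224, 224], [1, 3, 224, 224], [4, 3, 224, 224])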

View File

@@ -20,6 +20,7 @@ void BindLiteOption(pybind11::module& m);
void BindOpenVINOOption(pybind11::module& m);
void BindOrtOption(pybind11::module& m);
void BindTrtOption(pybind11::module& m);
void BindPaddleOption(pybind11::module& m);
void BindPorosOption(pybind11::module& m);
void BindOption(pybind11::module& m) {
@@ -27,6 +28,7 @@ void BindOption(pybind11::module& m) {
BindOpenVINOOption(m);
BindOrtOption(m);
BindTrtOption(m);
BindPaddleOption(m);
BindPorosOption(m);
pybind11::class_<RuntimeOption>(m, "RuntimeOption")
@@ -44,6 +46,7 @@ void BindOption(pybind11::module& m) {
.def_readwrite("ort_option", &RuntimeOption::ort_option)
.def_readwrite("trt_option", &RuntimeOption::trt_option)
.def_readwrite("poros_option", &RuntimeOption::poros_option)
.def_readwrite("paddle_infer_option", &RuntimeOption::paddle_infer_option)
.def("set_external_stream", &RuntimeOption::SetExternalStream)
.def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum)
.def("use_paddle_backend", &RuntimeOption::UsePaddleBackend)
@@ -52,25 +55,11 @@ void BindOption(pybind11::module& m) {
.def("use_trt_backend", &RuntimeOption::UseTrtBackend)
.def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend)
.def("use_lite_backend", &RuntimeOption::UseLiteBackend)
.def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN)
.def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo)
.def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo)
.def("set_paddle_mkldnn_cache_size",
&RuntimeOption::SetPaddleMKLDNNCacheSize)
.def("enable_paddle_to_trt", &RuntimeOption::EnablePaddleToTrt)
.def("enable_pinned_memory", &RuntimeOption::EnablePinnedMemory)
.def("disable_pinned_memory", &RuntimeOption::DisablePinnedMemory)
.def("enable_paddle_trt_collect_shape",
&RuntimeOption::EnablePaddleTrtCollectShape)
.def("disable_paddle_trt_collect_shape",
&RuntimeOption::DisablePaddleTrtCollectShape)
.def("use_ipu", &RuntimeOption::UseIpu)
.def("set_ipu_config", &RuntimeOption::SetIpuConfig)
.def("delete_paddle_backend_pass",
&RuntimeOption::DeletePaddleBackendPass)
.def("enable_profiling", &RuntimeOption::EnableProfiling)
.def("disable_profiling", &RuntimeOption::DisableProfiling)
.def("disable_paddle_trt_ops", &RuntimeOption::DisablePaddleTrtOPs)
.def_readwrite("model_file", &RuntimeOption::model_file)
.def_readwrite("params_file", &RuntimeOption::params_file)
.def_readwrite("model_format", &RuntimeOption::model_format)
@@ -79,19 +68,6 @@ void BindOption(pybind11::module& m) {
.def_readwrite("model_from_memory", &RuntimeOption::model_from_memory_)
.def_readwrite("cpu_thread_num", &RuntimeOption::cpu_thread_num)
.def_readwrite("device_id", &RuntimeOption::device_id)
.def_readwrite("device", &RuntimeOption::device)
.def_readwrite("ipu_device_num", &RuntimeOption::ipu_device_num)
.def_readwrite("ipu_micro_batch_size",
&RuntimeOption::ipu_micro_batch_size)
.def_readwrite("ipu_enable_pipelining",
&RuntimeOption::ipu_enable_pipelining)
.def_readwrite("ipu_batches_per_step",
&RuntimeOption::ipu_batches_per_step)
.def_readwrite("ipu_enable_fp16", &RuntimeOption::ipu_enable_fp16)
.def_readwrite("ipu_replica_num", &RuntimeOption::ipu_replica_num)
.def_readwrite("ipu_available_memory_proportion",
&RuntimeOption::ipu_available_memory_proportion)
.def_readwrite("ipu_enable_half_partial",
&RuntimeOption::ipu_enable_half_partial);
.def_readwrite("device", &RuntimeOption::device);
}
} // namespace fastdeploy

View File

@@ -226,52 +226,23 @@ void Runtime::CreatePaddleBackend() {
option.model_format == ModelFormat::PADDLE,
"Backend::PDINFER only supports model format of ModelFormat::PADDLE.");
#ifdef ENABLE_PADDLE_BACKEND
auto pd_option = PaddleBackendOption();
pd_option.model_file = option.model_file;
pd_option.params_file = option.params_file;
pd_option.enable_mkldnn = option.pd_enable_mkldnn;
pd_option.enable_log_info = option.pd_enable_log_info;
pd_option.mkldnn_cache_size = option.pd_mkldnn_cache_size;
pd_option.use_gpu = (option.device == Device::GPU) ? true : false;
pd_option.use_ipu = (option.device == Device::IPU) ? true : false;
pd_option.gpu_id = option.device_id;
pd_option.delete_pass_names = option.pd_delete_pass_names;
pd_option.cpu_thread_num = option.cpu_thread_num;
pd_option.enable_pinned_memory = option.enable_pinned_memory;
pd_option.external_stream_ = option.external_stream_;
pd_option.model_from_memory_ = option.model_from_memory_;
#ifdef ENABLE_TRT_BACKEND
if (pd_option.use_gpu && option.pd_enable_trt) {
pd_option.enable_trt = true;
pd_option.collect_shape = option.pd_collect_shape;
pd_option.trt_option = option.trt_option;
pd_option.trt_option.gpu_id = option.device_id;
pd_option.trt_option.enable_pinned_memory = option.enable_pinned_memory;
pd_option.trt_disabled_ops_ = option.trt_disabled_ops_;
}
#endif
#ifdef WITH_IPU
if (pd_option.use_ipu) {
auto ipu_option = IpuOption();
ipu_option.ipu_device_num = option.ipu_device_num;
ipu_option.ipu_micro_batch_size = option.ipu_micro_batch_size;
ipu_option.ipu_enable_pipelining = option.ipu_enable_pipelining;
ipu_option.ipu_batches_per_step = option.ipu_batches_per_step;
ipu_option.ipu_enable_fp16 = option.ipu_enable_fp16;
ipu_option.ipu_replica_num = option.ipu_replica_num;
ipu_option.ipu_available_memory_proportion =
option.ipu_available_memory_proportion;
ipu_option.ipu_enable_half_partial = option.ipu_enable_half_partial;
pd_option.ipu_option = ipu_option;
}
#endif
option.paddle_infer_option.model_file = option.model_file;
option.paddle_infer_option.params_file = option.params_file;
option.paddle_infer_option.model_from_memory_ = option.model_from_memory_;
option.paddle_infer_option.device = option.device;
option.paddle_infer_option.device_id = option.device_id;
option.paddle_infer_option.enable_pinned_memory = option.enable_pinned_memory;
option.paddle_infer_option.external_stream_ = option.external_stream_;
option.paddle_infer_option.trt_option = option.trt_option;
option.paddle_infer_option.trt_option.gpu_id = option.device_id;
backend_ = utils::make_unique<PaddleBackend>();
auto casted_backend = dynamic_cast<PaddleBackend*>(backend_.get());
casted_backend->benchmark_option_ = option.benchmark_option;
if (pd_option.model_from_memory_) {
FDASSERT(casted_backend->InitFromPaddle(option.model_file,
option.params_file, pd_option),
if (option.model_from_memory_) {
FDASSERT(
casted_backend->InitFromPaddle(option.model_file, option.params_file,
option.paddle_infer_option),
"Load model from Paddle failed while initliazing PaddleBackend.");
ReleaseModelMemoryBuffer();
} else {
@@ -281,8 +252,8 @@ void Runtime::CreatePaddleBackend() {
"Fail to read binary from model file");
FDASSERT(ReadBinaryFromFile(option.params_file, &params_buffer),
"Fail to read binary from parameter file");
FDASSERT(
casted_backend->InitFromPaddle(model_buffer, params_buffer, pd_option),
FDASSERT(casted_backend->InitFromPaddle(model_buffer, params_buffer,
option.paddle_infer_option),
"Load model from Paddle failed while initliazing PaddleBackend.");
}
#else

View File

@@ -99,6 +99,7 @@ void RuntimeOption::SetCpuThreadNum(int thread_num) {
paddle_lite_option.cpu_threads = thread_num;
ort_option.intra_op_num_threads = thread_num;
openvino_option.cpu_thread_num = thread_num;
paddle_infer_option.cpu_thread_num = thread_num;
}
void RuntimeOption::SetOrtGraphOptLevel(int level) {
@@ -174,25 +175,47 @@ void RuntimeOption::UseLiteBackend() {
}
void RuntimeOption::SetPaddleMKLDNN(bool pd_mkldnn) {
pd_enable_mkldnn = pd_mkldnn;
FDWARNING << "`RuntimeOption::SetPaddleMKLDNN` will be removed in v1.2.0, "
"please modify its member variable directly, e.g "
"`option.paddle_infer_option.enable_mkldnn = true`"
<< std::endl;
paddle_infer_option.enable_mkldnn = pd_mkldnn;
}
void RuntimeOption::DeletePaddleBackendPass(const std::string& pass_name) {
pd_delete_pass_names.push_back(pass_name);
FDWARNING
<< "`RuntimeOption::DeletePaddleBackendPass` will be removed in v1.2.0, "
"please use `option.paddle_infer_option.DeletePass` instead."
<< std::endl;
paddle_infer_option.DeletePass(pass_name);
}
void RuntimeOption::EnablePaddleLogInfo() {
FDWARNING << "`RuntimeOption::EnablePaddleLogInfo` will be removed in "
"v1.2.0, please modify its member variable directly, e.g "
"`option.paddle_infer_option.enable_log_info = true`"
<< std::endl;
paddle_infer_option.enable_log_info = true;
}
void RuntimeOption::EnablePaddleLogInfo() { pd_enable_log_info = true; }
void RuntimeOption::DisablePaddleLogInfo() { pd_enable_log_info = false; }
void RuntimeOption::DisablePaddleLogInfo() {
FDWARNING << "`RuntimeOption::DisablePaddleLogInfo` will be removed in "
"v1.2.0, please modify its member variable directly, e.g "
"`option.paddle_infer_option.enable_log_info = false`"
<< std::endl;
paddle_infer_option.enable_log_info = false;
}
void RuntimeOption::EnablePaddleToTrt() {
FDASSERT(backend == Backend::TRT,
"Should call UseTrtBackend() before call EnablePaddleToTrt().");
#ifdef ENABLE_PADDLE_BACKEND
FDWARNING << "`RuntimeOption::EnablePaddleToTrt` will be removed in v1.2.0, "
"please modify its member variable directly, e.g "
"`option.paddle_infer_option.enable_trt = true`"
<< std::endl;
FDINFO << "While using TrtBackend with EnablePaddleToTrt, FastDeploy will "
"change to use Paddle Inference Backend."
<< std::endl;
backend = Backend::PDINFER;
pd_enable_trt = true;
paddle_infer_option.enable_trt = true;
#else
FDASSERT(false,
"While using TrtBackend with EnablePaddleToTrt, require the "
@@ -202,8 +225,11 @@ void RuntimeOption::EnablePaddleToTrt() {
}
void RuntimeOption::SetPaddleMKLDNNCacheSize(int size) {
FDASSERT(size > 0, "Parameter size must be greater than 0.");
pd_mkldnn_cache_size = size;
FDWARNING << "`RuntimeOption::SetPaddleMKLDNNCacheSize` will be removed in "
"v1.2.0, please modify its member variable directly, e.g "
"`option.paddle_infer_option.mkldnn_cache_size = size`."
<< std::endl;
paddle_infer_option.mkldnn_cache_size = size;
}
void RuntimeOption::SetOpenVINODevice(const std::string& name) {
@@ -393,12 +419,28 @@ void RuntimeOption::SetOpenVINOStreams(int num_streams) {
openvino_option.num_streams = num_streams;
}
void RuntimeOption::EnablePaddleTrtCollectShape() { pd_collect_shape = true; }
void RuntimeOption::EnablePaddleTrtCollectShape() {
FDWARNING << "`RuntimeOption::EnablePaddleTrtCollectShape` will be removed "
"in v1.2.0, please modify its member variable directly, e.g "
"runtime_option.paddle_infer_option.collect_trt_shape = true`."
<< std::endl;
paddle_infer_option.collect_trt_shape = true;
}
void RuntimeOption::DisablePaddleTrtCollectShape() { pd_collect_shape = false; }
void RuntimeOption::DisablePaddleTrtCollectShape() {
FDWARNING << "`RuntimeOption::DisablePaddleTrtCollectShape` will be removed "
"in v1.2.0, please modify its member variable directly, e.g "
"runtime_option.paddle_infer_option.collect_trt_shape = false`."
<< std::endl;
paddle_infer_option.collect_trt_shape = false;
}
void RuntimeOption::DisablePaddleTrtOPs(const std::vector<std::string>& ops) {
trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
FDWARNING << "`RuntimeOption::DisablePaddleTrtOps` will be removed in "
"v.1.20, please use "
"`runtime_option.paddle_infer_option.DisableTrtOps` instead."
<< std::endl;
paddle_infer_option.DisableTrtOps(ops);
}
void RuntimeOption::UseIpu(int device_num, int micro_batch_size,
@@ -419,10 +461,11 @@ void RuntimeOption::UseIpu(int device_num, int micro_batch_size,
void RuntimeOption::SetIpuConfig(bool enable_fp16, int replica_num,
float available_memory_proportion,
bool enable_half_partial) {
ipu_enable_fp16 = enable_fp16;
ipu_replica_num = replica_num;
ipu_available_memory_proportion = available_memory_proportion;
ipu_enable_half_partial = enable_half_partial;
paddle_infer_option.ipu_option.ipu_enable_fp16 = enable_fp16;
paddle_infer_option.ipu_option.ipu_replica_num = replica_num;
paddle_infer_option.ipu_option.ipu_available_memory_proportion =
available_memory_proportion;
paddle_infer_option.ipu_option.ipu_enable_half_partial = enable_half_partial;
}
} // namespace fastdeploy
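All of the deprecated RuntimeOption entry points above now forward to paddle_infer_option and warn. A side-by-side sketch of the old call and the replacement named in the warnings; method names are taken from this diff, behaviour is otherwise assumed equivalent:

import fastdeploy as fd

# Deprecated path: still works, but prints the FDWARNING/FDINFO added above
option = fd.RuntimeOption()
option.use_gpu(0)
option.use_trt_backend()
option.enable_paddle_to_trt()  # switches the backend to Paddle Inference + its integrated TRT

# Replacement spelled out in the warnings
option = fd.RuntimeOption()
option.use_gpu(0)
option.use_paddle_infer_backend()
option.paddle_infer_option.enable_trt = True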

View File

@@ -378,27 +378,12 @@ struct FASTDEPLOY_DECL RuntimeOption {
/// Option to configure ONNX Runtime backend
OrtBackendOption ort_option;
// ======Only for Paddle Backend=====
bool pd_enable_mkldnn = true;
bool pd_enable_log_info = false;
bool pd_enable_trt = false;
bool pd_collect_shape = false;
int pd_mkldnn_cache_size = 1;
std::vector<std::string> pd_delete_pass_names;
// ======Only for Paddle IPU Backend =======
int ipu_device_num = 1;
int ipu_micro_batch_size = 1;
bool ipu_enable_pipelining = false;
int ipu_batches_per_step = 1;
bool ipu_enable_fp16 = false;
int ipu_replica_num = 1;
float ipu_available_memory_proportion = 1.0;
bool ipu_enable_half_partial = false;
/// Option to configure TensorRT backend
TrtBackendOption trt_option;
/// Option to configure Paddle Inference backend
PaddleBackendOption paddle_infer_option;
// ======Only for PaddleTrt Backend=======
std::vector<std::string> trt_disabled_ops_{};

View File

@@ -39,3 +39,5 @@ from . import text
from . import encryption
from .download import download, download_and_decompress, download_model, get_model_list
from . import serving
from .code_version import version, git_version
__version__ = version
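With code_version.py now generated into the Python package (see the setup.py hunk at the end of this diff), the version becomes importable; this is the "add version support" item from the commit message. A small usage sketch, assuming a built and installed wheel:

import fastdeploy

print(fastdeploy.__version__)  # populated from code_version.version at import time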

View File

@@ -364,7 +364,10 @@ class RuntimeOption:
def set_paddle_mkldnn(self, use_mkldnn=True):
"""Enable/Disable MKLDNN while using Paddle Inference backend, mkldnn is enabled by default.
"""
return self._option.set_paddle_mkldnn(use_mkldnn)
logging.warning(
"`RuntimeOption.set_paddle_mkldnn` will be derepcated in v1.2.0, please use `RuntimeOption.paddle_infer_option.enable_mkldnn = True` instead."
)
self._option.paddle_infer_option.enable_mkldnn = use_mkldnn  # honor the use_mkldnn flag instead of always enabling
def set_openvino_device(self, name="CPU"):
"""Set device name for OpenVINO, default 'CPU', can also be 'AUTO', 'GPU', 'GPU.1'....
@@ -400,17 +403,26 @@ class RuntimeOption:
def enable_paddle_log_info(self):
"""Enable print out the debug log information while using Paddle Inference backend, the log information is disabled by default.
"""
return self._option.enable_paddle_log_info()
logging.warning(
"RuntimeOption.enable_paddle_log_info` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.enable_log_info = True` instead."
)
self._option.paddle_infer_option.enable_log_info = True
def disable_paddle_log_info(self):
"""Disable print out the debug log information while using Paddle Inference backend, the log information is disabled by default.
"""
return self._option.disable_paddle_log_info()
logging.warning(
"RuntimeOption.disable_paddle_log_info` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.enable_log_info = False` instead."
)
self._option.paddle_infer_option.enable_log_info = False
def set_paddle_mkldnn_cache_size(self, cache_size):
"""Set size of shape cache while using Paddle Inference backend with MKLDNN enabled, default will cache all the dynamic shape.
"""
return self._option.set_paddle_mkldnn_cache_size(cache_size)
logging.warning(
"RuntimeOption.set_paddle_mkldnn_cache_size` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.mkldnn_cache_size = {}` instead.".
format(cache_size))
self._option.paddle_infer_option.mkldnn_cache_size = cache_size
def enable_lite_fp16(self):
"""Enable half precision inference while using Paddle Lite backend on ARM CPU, fp16 is disabled by default.
@@ -498,6 +510,16 @@ class RuntimeOption:
def enable_paddle_to_trt(self):
"""While using TensorRT backend, enable_paddle_to_trt() will change to use Paddle Inference backend, and use its integrated TensorRT instead.
"""
logging.warning(
"`RuntimeOption.enable_paddle_to_trt` will be deprecated in v1.2.l0, if you want to run tensorrt with Paddle Inference backend, please use the following method, "
)
logging.warning(" ==============================================")
logging.warning(" import fastdeploy as fd")
logging.warning(" option = fd.RuntimeOption()")
logging.warning(" option.use_gpu(0)")
logging.warning(" option.use_paddle_infer_backend()")
logging.warning(" option.paddle_infer_option.enabel_trt = True")
logging.warning(" ==============================================")
return self._option.enable_paddle_to_trt()
def set_trt_max_workspace_size(self, trt_max_workspace_size):
@@ -519,22 +541,34 @@ class RuntimeOption:
def enable_paddle_trt_collect_shape(self):
"""Enable collect subgraph shape information while using Paddle Inference with TensorRT
"""
return self._option.enable_paddle_trt_collect_shape()
logging.warning(
"`RuntimeOption.enable_paddle_trt_collect_shape` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.collect_trt_shape = True` instead."
)
self._option.paddle_infer_option.collect_trt_shape = True
def disable_paddle_trt_collect_shape(self):
"""Disable collect subgraph shape information while using Paddle Inference with TensorRT
"""
return self._option.disable_paddle_trt_collect_shape()
logging.warning(
"`RuntimeOption.disable_paddle_trt_collect_shape` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.collect_trt_shape = False` instead."
)
self._option.paddle_infer_option.collect_trt_shape = False
def delete_paddle_backend_pass(self, pass_name):
"""Delete pass by name in paddle backend
"""
return self._option.delete_paddle_backend_pass(pass_name)
logging.warning(
"`RuntimeOption.delete_paddle_backend_pass` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.delete_pass` instead."
)
self._option.paddle_infer_option.delete_pass(pass_name)
def disable_paddle_trt_ops(self, ops):
"""Disable some ops in paddle trt backend
"""
return self._option.disable_paddle_trt_ops(ops)
logging.warning(
"`RuntimeOption.disable_paddle_trt_ops` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.disable_trt_ops()` instead."
)
self._option.paddle_infer_option.disable_trt_ops(ops)
def use_ipu(self,
device_num=1,
@@ -593,6 +627,14 @@ class RuntimeOption:
"""
return self._option.trt_option
@property
def paddle_infer_option(self):
"""Get PaddleBackendOption object to configure Paddle Inference backend
:return PaddleBackendOption
"""
return self._option.paddle_infer_option
def enable_profiling(self, inclue_h2d_d2h=False, repeat=100, warmup=50):
"""Set the profile mode as 'true'.
:param inclue_h2d_d2h Whether to include time of H2D_D2H for time of runtime.

View File

@@ -64,8 +64,7 @@ setup_configs["ENABLE_OPENVINO_BACKEND"] = os.getenv("ENABLE_OPENVINO_BACKEND",
"OFF")
setup_configs["ENABLE_PADDLE_BACKEND"] = os.getenv("ENABLE_PADDLE_BACKEND",
"OFF")
setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND",
"OFF")
setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND", "OFF")
setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF")
setup_configs["ENABLE_LITE_BACKEND"] = os.getenv("ENABLE_LITE_BACKEND", "OFF")
setup_configs["PADDLELITE_URL"] = os.getenv("PADDLELITE_URL", "OFF")
@@ -80,8 +79,7 @@ setup_configs["WITH_IPU"] = os.getenv("WITH_IPU", "OFF")
setup_configs["WITH_KUNLUNXIN"] = os.getenv("WITH_KUNLUNXIN", "OFF")
setup_configs["BUILD_ON_JETSON"] = os.getenv("BUILD_ON_JETSON", "OFF")
setup_configs["TRT_DIRECTORY"] = os.getenv("TRT_DIRECTORY", "UNDEFINED")
setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY",
"/usr/local/cuda")
setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY", "/usr/local/cuda")
setup_configs["LIBRARY_NAME"] = PACKAGE_NAME
setup_configs["PY_LIBRARY_NAME"] = PACKAGE_NAME + "_main"
setup_configs["OPENCV_DIRECTORY"] = os.getenv("OPENCV_DIRECTORY", "")
@@ -104,6 +102,7 @@ if os.getenv("CMAKE_CXX_COMPILER", None) is not None:
setup_configs["CMAKE_CXX_COMPILER"] = os.getenv("CMAKE_CXX_COMPILER")
SRC_DIR = os.path.join(TOP_DIR, PACKAGE_NAME)
PYTHON_SRC_DIR = os.path.join(TOP_DIR, "python", PACKAGE_NAME)
CMAKE_BUILD_DIR = os.path.join(TOP_DIR, 'python', '.setuptools-cmake-build')
WINDOWS = (os.name == 'nt')
@@ -120,8 +119,7 @@ extras_require = {}
# Default value is set to TRUE\1 to keep the settings the same as the current ones.
# However, going forward the recommended way is to set this to False\0
USE_MSVC_STATIC_RUNTIME = bool(
os.getenv('USE_MSVC_STATIC_RUNTIME', '1') == '1')
USE_MSVC_STATIC_RUNTIME = bool(os.getenv('USE_MSVC_STATIC_RUNTIME', '1') == '1')
ONNX_NAMESPACE = os.getenv('ONNX_NAMESPACE', 'paddle2onnx')
################################################################################
# Version
@@ -151,8 +149,7 @@ assert CMAKE, 'Could not find "cmake" executable!'
@contextmanager
def cd(path):
if not os.path.isabs(path):
raise RuntimeError('Can only cd to absolute path, got: {}'.format(
path))
raise RuntimeError('Can only cd to absolute path, got: {}'.format(path))
orig_path = os.getcwd()
os.chdir(path)
try:
@@ -187,7 +184,7 @@ def get_all_files(dirname):
class create_version(ONNXCommand):
def run(self):
with open(os.path.join(SRC_DIR, 'version.py'), 'w') as f:
with open(os.path.join(PYTHON_SRC_DIR, 'code_version.py'), 'w') as f:
f.write(
dedent('''\
# This file is generated by setup.py. DO NOT EDIT!