[Other] Optimize paddle backend (#1265)

* Optimize paddle backend

* optimize paddle backend

* add version support
Author: Jason
Date:   2023-02-08 19:12:03 +08:00
Committed by: GitHub
Parent: 60ba4b06c1
Commit: a4b0565b9a

10 changed files with 265 additions and 174 deletions
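The user-visible effect of the refactor is that Paddle Inference settings move off RuntimeOption's loose pd_*/ipu_* fields and helper methods onto a dedicated `RuntimeOption.paddle_infer_option` (a `PaddleBackendOption`). A minimal sketch of the new Python usage, assuming the usual `fd.Runtime`/`set_model_path` entry points; the model paths are placeholders, not taken from this diff:

    import fastdeploy as fd

    option = fd.RuntimeOption()
    option.set_model_path("model.pdmodel", "model.pdiparams")  # placeholder paths
    option.use_paddle_infer_backend()

    # New style: tweak the backend through paddle_infer_option instead of
    # set_paddle_mkldnn()/set_paddle_mkldnn_cache_size()/enable_paddle_log_info()
    option.paddle_infer_option.enable_mkldnn = True
    option.paddle_infer_option.mkldnn_cache_size = 10
    option.paddle_infer_option.enable_log_info = False

    runtime = fd.Runtime(option)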


@@ -24,54 +24,71 @@
 namespace fastdeploy {
 
+/*! @brief Option object to configure GraphCore IPU
+ */
 struct IpuOption {
+  /// IPU device id
   int ipu_device_num;
+  /// the batch size in the graph, only works when the graph has no batch shape info
   int ipu_micro_batch_size;
+  /// enable pipelining
   bool ipu_enable_pipelining;
+  /// the number of batches per run in pipelining
   int ipu_batches_per_step;
+  /// enable fp16
   bool ipu_enable_fp16;
+  /// the number of graph replications
   int ipu_replica_num;
+  /// the available memory proportion for matmul/conv
   float ipu_available_memory_proportion;
+  /// enable fp16 partial for matmul, only works with fp16
   bool ipu_enable_half_partial;
 };
 
+/*! @brief Option object to configure Paddle Inference backend
+ */
 struct PaddleBackendOption {
+  /// Print log information while initializing Paddle Inference backend
+  bool enable_log_info = false;
+  /// Enable MKLDNN while running inference on CPU
+  bool enable_mkldnn = true;
+  /// Use Paddle Inference + TensorRT to run the model on GPU
+  bool enable_trt = false;
+
+  /*
+   * @brief IPU option, this will configure the IPU hardware when running the model on IPU
+   */
+  IpuOption ipu_option;
+
+  /// Collect shape for model while enable_trt is true
+  bool collect_trt_shape = false;
+  /// Cache input shape for mkldnn while the input data will change dynamically
+  int mkldnn_cache_size = -1;
+  /// initialize memory size(MB) for GPU
+  int gpu_mem_init_size = 100;
+
+  void DisableTrtOps(const std::vector<std::string>& ops) {
+    trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
+  }
+
+  void DeletePass(const std::string& pass_name) {
+    delete_pass_names.push_back(pass_name);
+  }
+
+  // The following parameters may be removed, please do not
+  // read or write them directly
+  TrtBackendOption trt_option;
+  bool enable_pinned_memory = false;
+  void* external_stream_ = nullptr;
+  Device device = Device::CPU;
+  int device_id = 0;
+  std::vector<std::string> trt_disabled_ops_{};
+  int cpu_thread_num = 8;
+  std::vector<std::string> delete_pass_names = {};
   std::string model_file = "";   // Path of model file
   std::string params_file = "";  // Path of parameters file, can be empty
   // load model and parameters from memory
   bool model_from_memory_ = false;
-#ifdef WITH_GPU
-  bool use_gpu = true;
-#else
-  bool use_gpu = false;
-#endif
-  bool enable_mkldnn = true;
-  bool enable_log_info = false;
-  bool enable_trt = false;
-  TrtBackendOption trt_option;
-  bool collect_shape = false;
-  std::vector<std::string> trt_disabled_ops_{};
-#ifdef WITH_IPU
-  bool use_ipu = true;
-  IpuOption ipu_option;
-#else
-  bool use_ipu = false;
-#endif
-  int mkldnn_cache_size = 1;
-  int cpu_thread_num = 8;
-  // initialize memory size(MB) for GPU
-  int gpu_mem_init_size = 100;
-  // gpu device id
-  int gpu_id = 0;
-  bool enable_pinned_memory = false;
-  void* external_stream_ = nullptr;
-  std::vector<std::string> delete_pass_names = {};
 };
 
 }  // namespace fastdeploy


@@ -0,0 +1,53 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
#include "fastdeploy/runtime/backends/paddle/option.h"
namespace fastdeploy {
void BindIpuOption(pybind11::module& m) {
  pybind11::class_<IpuOption>(m, "IpuOption")
      .def(pybind11::init())
      .def_readwrite("ipu_device_num", &IpuOption::ipu_device_num)
      .def_readwrite("ipu_micro_batch_size", &IpuOption::ipu_micro_batch_size)
      .def_readwrite("ipu_enable_pipelining", &IpuOption::ipu_enable_pipelining)
      .def_readwrite("ipu_batches_per_step", &IpuOption::ipu_batches_per_step)
      .def_readwrite("ipu_enable_fp16", &IpuOption::ipu_enable_fp16)
      .def_readwrite("ipu_replica_num", &IpuOption::ipu_replica_num)
      .def_readwrite("ipu_available_memory_proportion",
                     &IpuOption::ipu_available_memory_proportion)
      .def_readwrite("ipu_enable_half_partial",
                     &IpuOption::ipu_enable_half_partial);
}

void BindPaddleOption(pybind11::module& m) {
  BindIpuOption(m);
  pybind11::class_<PaddleBackendOption>(m, "PaddleBackendOption")
      .def(pybind11::init())
      .def_readwrite("enable_log_info", &PaddleBackendOption::enable_log_info)
      .def_readwrite("enable_mkldnn", &PaddleBackendOption::enable_mkldnn)
      .def_readwrite("enable_trt", &PaddleBackendOption::enable_trt)
      .def_readwrite("ipu_option", &PaddleBackendOption::ipu_option)
      .def_readwrite("collect_trt_shape",
                     &PaddleBackendOption::collect_trt_shape)
      .def_readwrite("mkldnn_cache_size",
                     &PaddleBackendOption::mkldnn_cache_size)
      .def_readwrite("gpu_mem_init_size",
                     &PaddleBackendOption::gpu_mem_init_size)
      .def("disable_trt_ops", &PaddleBackendOption::DisableTrtOps)
      .def("delete_pass", &PaddleBackendOption::DeletePass);
}
} // namespace fastdeploy
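With IpuOption and PaddleBackendOption exposed as their own pybind classes, IPU settings can be adjusted field by field from Python. A hedged sketch of what that could look like (the device count and memory proportion are illustrative values, not defaults taken from this diff):

    import fastdeploy as fd

    option = fd.RuntimeOption()
    option.use_ipu(device_num=1, micro_batch_size=1)

    # Fine-tune via the newly bound option objects
    option.paddle_infer_option.ipu_option.ipu_enable_fp16 = True
    option.paddle_infer_option.ipu_option.ipu_available_memory_proportion = 0.3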


@@ -22,8 +22,8 @@ namespace fastdeploy {
 void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
   option_ = option;
-  if (option.use_gpu) {
-    config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id);
+  if (option.device == Device::GPU) {
+    config_.EnableUseGpu(option.gpu_mem_init_size, option.device_id);
     if (option_.external_stream_) {
       config_.SetExecStream(option_.external_stream_);
     }
@@ -50,7 +50,7 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
                                 precision, use_static);
       SetTRTDynamicShapeToConfig(option);
     }
-  } else if (option.use_ipu) {
+  } else if (option.device == Device::IPU) {
 #ifdef WITH_IPU
     config_.EnableIpu(option.ipu_option.ipu_device_num,
                       option.ipu_option.ipu_micro_batch_size,
@@ -101,14 +101,15 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_buffer,
                         params_buffer.c_str(), params_buffer.size());
   config_.EnableMemoryOptim();
   BuildOption(option);
 
   // The input/output information get from predictor is not right, use
   // PaddleReader instead now
-  auto reader = paddle2onnx::PaddleReader(model_buffer.c_str(), model_buffer.size());
+  auto reader =
+      paddle2onnx::PaddleReader(model_buffer.c_str(), model_buffer.size());
   // If it's a quantized model, and use cpu with mkldnn, automatically switch to
   // int8 mode
   if (reader.is_quantize_model) {
-    if (option.use_gpu) {
+    if (option.device == Device::GPU) {
       FDWARNING << "The loaded model is a quantized model, while inference on "
                    "GPU, please use TensorRT backend to get better performance."
                 << std::endl;
@@ -158,7 +159,7 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_buffer,
     outputs_desc_[i].shape.assign(shape.begin(), shape.end());
     outputs_desc_[i].dtype = ReaderDataTypeToFD(reader.outputs[i].dtype);
   }
-  if (option.collect_shape) {
+  if (option.collect_trt_shape) {
     // Set the shape info file.
     std::string curr_model_dir = "./";
     if (!option.model_from_memory_) {
@@ -221,19 +222,19 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
               << inputs_desc_.size() << ")." << std::endl;
     return false;
   }
 
   RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
   for (size_t i = 0; i < inputs.size(); ++i) {
     auto handle = predictor_->GetInputHandle(inputs[i].name);
     ShareTensorFromFDTensor(handle.get(), inputs[i]);
   }
 
   RUNTIME_PROFILE_LOOP_BEGIN(1)
   predictor_->Run();
   RUNTIME_PROFILE_LOOP_END
 
   // output share backend memory only support CPU or GPU
-  if (option_.use_ipu) {
+  if (option_.device == Device::IPU) {
     copy_to_fd = true;
   }
   outputs->resize(outputs_desc_.size());
@@ -253,9 +254,10 @@ std::unique_ptr<BaseBackend> PaddleBackend::Clone(RuntimeOption& runtime_option,
   std::unique_ptr<BaseBackend> new_backend =
       utils::make_unique<PaddleBackend>();
   auto casted_backend = dynamic_cast<PaddleBackend*>(new_backend.get());
-  if (device_id > 0 && option_.use_gpu == true && device_id != option_.gpu_id) {
+  if (device_id > 0 && (option_.device == Device::GPU) &&
+      device_id != option_.device_id) {
     auto clone_option = option_;
-    clone_option.gpu_id = device_id;
+    clone_option.device_id = device_id;
     clone_option.external_stream_ = stream;
     if (runtime_option.model_from_memory_) {
       FDASSERT(
@@ -279,7 +281,7 @@ std::unique_ptr<BaseBackend> PaddleBackend::Clone(RuntimeOption& runtime_option,
   }
   FDWARNING << "The target device id:" << device_id
-            << " is different from current device id:" << option_.gpu_id
+            << " is different from current device id:" << option_.device_id
             << ", cannot share memory with current engine." << std::endl;
   return new_backend;
 }
@@ -347,10 +349,13 @@ void PaddleBackend::CollectShapeRun(
     const std::map<std::string, std::vector<int>>& shape) const {
   auto input_names = predictor->GetInputNames();
   auto input_type = predictor->GetInputTypes();
-  for (auto name : input_names) {
+  for (const auto& name : input_names) {
     FDASSERT(shape.find(name) != shape.end() &&
                  input_type.find(name) != input_type.end(),
-             "Paddle Input name [%s] is not one of the trt dynamic shape.",
+             "When collect_trt_shape is true, please define max/opt/min shape "
+             "for model's input:[\"%s\"] by "
+             "(C++)RuntimeOption.trt_option.SetShape/"
+             "(Python)RuntimeOption.trt_option.set_shape.",
             name.c_str());
     auto tensor = predictor->GetInputHandle(name);
     auto shape_value = shape.at(name);
@@ -385,4 +390,4 @@ void PaddleBackend::CollectShapeRun(
   predictor->Run();
 }
 }  // namespace fastdeploy
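The reworded assertion above points users at the TensorRT shape API whenever collect_trt_shape is enabled. A hedged Python sketch of that setup (the input name "x" and the shapes are illustrative, and the min/opt/max argument order of trt_option.set_shape is assumed rather than shown in this diff):

    import fastdeploy as fd

    option = fd.RuntimeOption()
    option.use_gpu(0)
    option.use_paddle_infer_backend()
    option.paddle_infer_option.enable_trt = True
    option.paddle_infer_option.collect_trt_shape = True

    # Register min/opt/max shapes for every model input before the engine is built
    option.trt_option.set_shape("x", [1, 3, 224, 224], [1, 3, 224, 224], [8, 3, 224, 224])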


@@ -20,6 +20,7 @@ void BindLiteOption(pybind11::module& m);
 void BindOpenVINOOption(pybind11::module& m);
 void BindOrtOption(pybind11::module& m);
 void BindTrtOption(pybind11::module& m);
+void BindPaddleOption(pybind11::module& m);
 void BindPorosOption(pybind11::module& m);
 
 void BindOption(pybind11::module& m) {
@@ -27,6 +28,7 @@ void BindOption(pybind11::module& m) {
   BindOpenVINOOption(m);
   BindOrtOption(m);
   BindTrtOption(m);
+  BindPaddleOption(m);
   BindPorosOption(m);
 
   pybind11::class_<RuntimeOption>(m, "RuntimeOption")
@@ -44,6 +46,7 @@ void BindOption(pybind11::module& m) {
       .def_readwrite("ort_option", &RuntimeOption::ort_option)
       .def_readwrite("trt_option", &RuntimeOption::trt_option)
       .def_readwrite("poros_option", &RuntimeOption::poros_option)
+      .def_readwrite("paddle_infer_option", &RuntimeOption::paddle_infer_option)
       .def("set_external_stream", &RuntimeOption::SetExternalStream)
       .def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum)
       .def("use_paddle_backend", &RuntimeOption::UsePaddleBackend)
@@ -52,25 +55,11 @@ void BindOption(pybind11::module& m) {
       .def("use_trt_backend", &RuntimeOption::UseTrtBackend)
       .def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend)
       .def("use_lite_backend", &RuntimeOption::UseLiteBackend)
-      .def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN)
-      .def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo)
-      .def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo)
-      .def("set_paddle_mkldnn_cache_size",
-           &RuntimeOption::SetPaddleMKLDNNCacheSize)
-      .def("enable_paddle_to_trt", &RuntimeOption::EnablePaddleToTrt)
       .def("enable_pinned_memory", &RuntimeOption::EnablePinnedMemory)
      .def("disable_pinned_memory", &RuntimeOption::DisablePinnedMemory)
-      .def("enable_paddle_trt_collect_shape",
-           &RuntimeOption::EnablePaddleTrtCollectShape)
-      .def("disable_paddle_trt_collect_shape",
-           &RuntimeOption::DisablePaddleTrtCollectShape)
       .def("use_ipu", &RuntimeOption::UseIpu)
-      .def("set_ipu_config", &RuntimeOption::SetIpuConfig)
-      .def("delete_paddle_backend_pass",
-           &RuntimeOption::DeletePaddleBackendPass)
       .def("enable_profiling", &RuntimeOption::EnableProfiling)
       .def("disable_profiling", &RuntimeOption::DisableProfiling)
-      .def("disable_paddle_trt_ops", &RuntimeOption::DisablePaddleTrtOPs)
       .def_readwrite("model_file", &RuntimeOption::model_file)
       .def_readwrite("params_file", &RuntimeOption::params_file)
       .def_readwrite("model_format", &RuntimeOption::model_format)
@@ -79,19 +68,6 @@ void BindOption(pybind11::module& m) {
       .def_readwrite("model_from_memory", &RuntimeOption::model_from_memory_)
       .def_readwrite("cpu_thread_num", &RuntimeOption::cpu_thread_num)
       .def_readwrite("device_id", &RuntimeOption::device_id)
-      .def_readwrite("device", &RuntimeOption::device)
-      .def_readwrite("ipu_device_num", &RuntimeOption::ipu_device_num)
-      .def_readwrite("ipu_micro_batch_size",
-                     &RuntimeOption::ipu_micro_batch_size)
-      .def_readwrite("ipu_enable_pipelining",
-                     &RuntimeOption::ipu_enable_pipelining)
-      .def_readwrite("ipu_batches_per_step",
-                     &RuntimeOption::ipu_batches_per_step)
-      .def_readwrite("ipu_enable_fp16", &RuntimeOption::ipu_enable_fp16)
-      .def_readwrite("ipu_replica_num", &RuntimeOption::ipu_replica_num)
-      .def_readwrite("ipu_available_memory_proportion",
-                     &RuntimeOption::ipu_available_memory_proportion)
-      .def_readwrite("ipu_enable_half_partial",
-                     &RuntimeOption::ipu_enable_half_partial);
+      .def_readwrite("device", &RuntimeOption::device);
 }
 
 }  // namespace fastdeploy


@@ -226,53 +226,24 @@ void Runtime::CreatePaddleBackend() {
              option.model_format == ModelFormat::PADDLE,
              "Backend::PDINFER only supports model format of ModelFormat::PADDLE.");
 #ifdef ENABLE_PADDLE_BACKEND
-  auto pd_option = PaddleBackendOption();
-  pd_option.model_file = option.model_file;
-  pd_option.params_file = option.params_file;
-  pd_option.enable_mkldnn = option.pd_enable_mkldnn;
-  pd_option.enable_log_info = option.pd_enable_log_info;
-  pd_option.mkldnn_cache_size = option.pd_mkldnn_cache_size;
-  pd_option.use_gpu = (option.device == Device::GPU) ? true : false;
-  pd_option.use_ipu = (option.device == Device::IPU) ? true : false;
-  pd_option.gpu_id = option.device_id;
-  pd_option.delete_pass_names = option.pd_delete_pass_names;
-  pd_option.cpu_thread_num = option.cpu_thread_num;
-  pd_option.enable_pinned_memory = option.enable_pinned_memory;
-  pd_option.external_stream_ = option.external_stream_;
-  pd_option.model_from_memory_ = option.model_from_memory_;
-#ifdef ENABLE_TRT_BACKEND
-  if (pd_option.use_gpu && option.pd_enable_trt) {
-    pd_option.enable_trt = true;
-    pd_option.collect_shape = option.pd_collect_shape;
-    pd_option.trt_option = option.trt_option;
-    pd_option.trt_option.gpu_id = option.device_id;
-    pd_option.trt_option.enable_pinned_memory = option.enable_pinned_memory;
-    pd_option.trt_disabled_ops_ = option.trt_disabled_ops_;
-  }
-#endif
-#ifdef WITH_IPU
-  if (pd_option.use_ipu) {
-    auto ipu_option = IpuOption();
-    ipu_option.ipu_device_num = option.ipu_device_num;
-    ipu_option.ipu_micro_batch_size = option.ipu_micro_batch_size;
-    ipu_option.ipu_enable_pipelining = option.ipu_enable_pipelining;
-    ipu_option.ipu_batches_per_step = option.ipu_batches_per_step;
-    ipu_option.ipu_enable_fp16 = option.ipu_enable_fp16;
-    ipu_option.ipu_replica_num = option.ipu_replica_num;
-    ipu_option.ipu_available_memory_proportion =
-        option.ipu_available_memory_proportion;
-    ipu_option.ipu_enable_half_partial = option.ipu_enable_half_partial;
-    pd_option.ipu_option = ipu_option;
-  }
-#endif
+  option.paddle_infer_option.model_file = option.model_file;
+  option.paddle_infer_option.params_file = option.params_file;
+  option.paddle_infer_option.model_from_memory_ = option.model_from_memory_;
+  option.paddle_infer_option.device = option.device;
+  option.paddle_infer_option.device_id = option.device_id;
+  option.paddle_infer_option.enable_pinned_memory = option.enable_pinned_memory;
+  option.paddle_infer_option.external_stream_ = option.external_stream_;
+  option.paddle_infer_option.trt_option = option.trt_option;
+  option.paddle_infer_option.trt_option.gpu_id = option.device_id;
   backend_ = utils::make_unique<PaddleBackend>();
   auto casted_backend = dynamic_cast<PaddleBackend*>(backend_.get());
   casted_backend->benchmark_option_ = option.benchmark_option;
-  if (pd_option.model_from_memory_) {
-    FDASSERT(casted_backend->InitFromPaddle(option.model_file,
-                                            option.params_file, pd_option),
-             "Load model from Paddle failed while initializing PaddleBackend.");
+  if (option.model_from_memory_) {
+    FDASSERT(
+        casted_backend->InitFromPaddle(option.model_file, option.params_file,
+                                       option.paddle_infer_option),
+        "Load model from Paddle failed while initializing PaddleBackend.");
     ReleaseModelMemoryBuffer();
   } else {
     std::string model_buffer = "";
@@ -281,9 +252,9 @@ void Runtime::CreatePaddleBackend() {
              "Fail to read binary from model file");
     FDASSERT(ReadBinaryFromFile(option.params_file, &params_buffer),
              "Fail to read binary from parameter file");
-    FDASSERT(
-        casted_backend->InitFromPaddle(model_buffer, params_buffer, pd_option),
-        "Load model from Paddle failed while initializing PaddleBackend.");
+    FDASSERT(casted_backend->InitFromPaddle(model_buffer, params_buffer,
+                                            option.paddle_infer_option),
+             "Load model from Paddle failed while initializing PaddleBackend.");
   }
 #else
   FDASSERT(false,


@@ -99,6 +99,7 @@ void RuntimeOption::SetCpuThreadNum(int thread_num) {
   paddle_lite_option.cpu_threads = thread_num;
   ort_option.intra_op_num_threads = thread_num;
   openvino_option.cpu_thread_num = thread_num;
+  paddle_infer_option.cpu_thread_num = thread_num;
 }
 
 void RuntimeOption::SetOrtGraphOptLevel(int level) {
@@ -174,25 +175,47 @@ void RuntimeOption::UseLiteBackend() {
 }
 
 void RuntimeOption::SetPaddleMKLDNN(bool pd_mkldnn) {
-  pd_enable_mkldnn = pd_mkldnn;
+  FDWARNING << "`RuntimeOption::SetPaddleMKLDNN` will be removed in v1.2.0, "
+               "please modify its member variable directly, e.g "
+               "`option.paddle_infer_option.enable_mkldnn = true`"
+            << std::endl;
+  paddle_infer_option.enable_mkldnn = pd_mkldnn;
 }
 
 void RuntimeOption::DeletePaddleBackendPass(const std::string& pass_name) {
-  pd_delete_pass_names.push_back(pass_name);
+  FDWARNING
+      << "`RuntimeOption::DeletePaddleBackendPass` will be removed in v1.2.0, "
+         "please use `option.paddle_infer_option.DeletePass` instead."
+      << std::endl;
+  paddle_infer_option.DeletePass(pass_name);
+}
+
+void RuntimeOption::EnablePaddleLogInfo() {
+  FDWARNING << "`RuntimeOption::EnablePaddleLogInfo` will be removed in "
+               "v1.2.0, please modify its member variable directly, e.g "
+               "`option.paddle_infer_option.enable_log_info = true`"
+            << std::endl;
+  paddle_infer_option.enable_log_info = true;
 }
-void RuntimeOption::EnablePaddleLogInfo() { pd_enable_log_info = true; }
 
-void RuntimeOption::DisablePaddleLogInfo() { pd_enable_log_info = false; }
+void RuntimeOption::DisablePaddleLogInfo() {
+  FDWARNING << "`RuntimeOption::DisablePaddleLogInfo` will be removed in "
+               "v1.2.0, please modify its member variable directly, e.g "
+               "`option.paddle_infer_option.enable_log_info = false`"
+            << std::endl;
+  paddle_infer_option.enable_log_info = false;
+}
 
 void RuntimeOption::EnablePaddleToTrt() {
-  FDASSERT(backend == Backend::TRT,
-           "Should call UseTrtBackend() before call EnablePaddleToTrt().");
 #ifdef ENABLE_PADDLE_BACKEND
+  FDWARNING << "`RuntimeOption::EnablePaddleToTrt` will be removed in v1.2.0, "
+               "please modify its member variable directly, e.g "
+               "`option.paddle_infer_option.enable_trt = true`"
+            << std::endl;
   FDINFO << "While using TrtBackend with EnablePaddleToTrt, FastDeploy will "
             "change to use Paddle Inference Backend."
          << std::endl;
   backend = Backend::PDINFER;
-  pd_enable_trt = true;
+  paddle_infer_option.enable_trt = true;
 #else
   FDASSERT(false,
            "While using TrtBackend with EnablePaddleToTrt, require the "
@@ -202,8 +225,11 @@ void RuntimeOption::EnablePaddleToTrt() {
 }
 
 void RuntimeOption::SetPaddleMKLDNNCacheSize(int size) {
-  FDASSERT(size > 0, "Parameter size must greater than 0.");
-  pd_mkldnn_cache_size = size;
+  FDWARNING << "`RuntimeOption::SetPaddleMKLDNNCacheSize` will be removed in "
+               "v1.2.0, please modify its member variable directly, e.g "
+               "`option.paddle_infer_option.mkldnn_cache_size = size`."
+            << std::endl;
+  paddle_infer_option.mkldnn_cache_size = size;
 }
 
 void RuntimeOption::SetOpenVINODevice(const std::string& name) {
@@ -393,12 +419,28 @@ void RuntimeOption::SetOpenVINOStreams(int num_streams) {
   openvino_option.num_streams = num_streams;
 }
 
-void RuntimeOption::EnablePaddleTrtCollectShape() { pd_collect_shape = true; }
+void RuntimeOption::EnablePaddleTrtCollectShape() {
+  FDWARNING << "`RuntimeOption::EnablePaddleTrtCollectShape` will be removed "
+               "in v1.2.0, please modify its member variable directly, e.g "
+               "`runtime_option.paddle_infer_option.collect_trt_shape = true`."
+            << std::endl;
+  paddle_infer_option.collect_trt_shape = true;
+}
 
-void RuntimeOption::DisablePaddleTrtCollectShape() { pd_collect_shape = false; }
+void RuntimeOption::DisablePaddleTrtCollectShape() {
+  FDWARNING << "`RuntimeOption::DisablePaddleTrtCollectShape` will be removed "
+               "in v1.2.0, please modify its member variable directly, e.g "
+               "`runtime_option.paddle_infer_option.collect_trt_shape = false`."
+            << std::endl;
+  paddle_infer_option.collect_trt_shape = false;
+}
 
 void RuntimeOption::DisablePaddleTrtOPs(const std::vector<std::string>& ops) {
-  trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
+  FDWARNING << "`RuntimeOption::DisablePaddleTrtOps` will be removed in "
+               "v1.2.0, please use "
+               "`runtime_option.paddle_infer_option.DisableTrtOps` instead."
+            << std::endl;
+  paddle_infer_option.DisableTrtOps(ops);
 }
 
 void RuntimeOption::UseIpu(int device_num, int micro_batch_size,
@@ -419,10 +461,11 @@ void RuntimeOption::UseIpu(int device_num, int micro_batch_size,
 void RuntimeOption::SetIpuConfig(bool enable_fp16, int replica_num,
                                  float available_memory_proportion,
                                  bool enable_half_partial) {
-  ipu_enable_fp16 = enable_fp16;
-  ipu_replica_num = replica_num;
-  ipu_available_memory_proportion = available_memory_proportion;
-  ipu_enable_half_partial = enable_half_partial;
+  paddle_infer_option.ipu_option.ipu_enable_fp16 = enable_fp16;
+  paddle_infer_option.ipu_option.ipu_replica_num = replica_num;
+  paddle_infer_option.ipu_option.ipu_available_memory_proportion =
+      available_memory_proportion;
+  paddle_infer_option.ipu_option.ipu_enable_half_partial = enable_half_partial;
 }
 
 }  // namespace fastdeploy


@@ -378,27 +378,12 @@ struct FASTDEPLOY_DECL RuntimeOption {
   /// Option to configure ONNX Runtime backend
   OrtBackendOption ort_option;
 
-  // ======Only for Paddle Backend=====
-  bool pd_enable_mkldnn = true;
-  bool pd_enable_log_info = false;
-  bool pd_enable_trt = false;
-  bool pd_collect_shape = false;
-  int pd_mkldnn_cache_size = 1;
-  std::vector<std::string> pd_delete_pass_names;
-
-  // ======Only for Paddle IPU Backend =======
-  int ipu_device_num = 1;
-  int ipu_micro_batch_size = 1;
-  bool ipu_enable_pipelining = false;
-  int ipu_batches_per_step = 1;
-  bool ipu_enable_fp16 = false;
-  int ipu_replica_num = 1;
-  float ipu_available_memory_proportion = 1.0;
-  bool ipu_enable_half_partial = false;
-
   /// Option to configure TensorRT backend
   TrtBackendOption trt_option;
+  /// Option to configure Paddle Inference backend
+  PaddleBackendOption paddle_infer_option;
 
   // ======Only for PaddleTrt Backend=======
   std::vector<std::string> trt_disabled_ops_{};


@@ -39,3 +39,5 @@ from . import text
 from . import encryption
 from .download import download, download_and_decompress, download_model, get_model_list
 from . import serving
+from .code_version import version, git_version
+__version__ = version
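Together with the setup.py change further below that writes code_version.py, this lets the installed package report its own version. A small sketch of what that enables (output values depend on the build):

    import fastdeploy

    print(fastdeploy.__version__)  # taken from code_version.version
    print(fastdeploy.git_version)  # presumably the git revision recorded at build time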


@@ -364,7 +364,10 @@ class RuntimeOption:
     def set_paddle_mkldnn(self, use_mkldnn=True):
         """Enable/Disable MKLDNN while using Paddle Inference backend, mkldnn is enabled by default.
         """
-        return self._option.set_paddle_mkldnn(use_mkldnn)
+        logging.warning(
+            "`RuntimeOption.set_paddle_mkldnn` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.enable_mkldnn = True` instead."
+        )
+        self._option.paddle_infer_option.enable_mkldnn = use_mkldnn
 
     def set_openvino_device(self, name="CPU"):
         """Set device name for OpenVINO, default 'CPU', can also be 'AUTO', 'GPU', 'GPU.1'....
@@ -400,17 +403,26 @@ class RuntimeOption:
     def enable_paddle_log_info(self):
         """Enable print out the debug log information while using Paddle Inference backend, the log information is disabled by default.
         """
-        return self._option.enable_paddle_log_info()
+        logging.warning(
+            "`RuntimeOption.enable_paddle_log_info` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.enable_log_info = True` instead."
+        )
+        self._option.paddle_infer_option.enable_log_info = True
 
     def disable_paddle_log_info(self):
         """Disable print out the debug log information while using Paddle Inference backend, the log information is disabled by default.
         """
-        return self._option.disable_paddle_log_info()
+        logging.warning(
+            "`RuntimeOption.disable_paddle_log_info` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.enable_log_info = False` instead."
+        )
+        self._option.paddle_infer_option.enable_log_info = False
 
     def set_paddle_mkldnn_cache_size(self, cache_size):
         """Set size of shape cache while using Paddle Inference backend with MKLDNN enabled, default will cache all the dynamic shape.
         """
-        return self._option.set_paddle_mkldnn_cache_size(cache_size)
+        logging.warning(
+            "`RuntimeOption.set_paddle_mkldnn_cache_size` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.mkldnn_cache_size = {}` instead.".
+            format(cache_size))
+        self._option.paddle_infer_option.mkldnn_cache_size = cache_size
 
     def enable_lite_fp16(self):
         """Enable half precision inference while using Paddle Lite backend on ARM CPU, fp16 is disabled by default.
@@ -498,6 +510,16 @@ class RuntimeOption:
     def enable_paddle_to_trt(self):
         """While using TensorRT backend, enable_paddle_to_trt() will change to use Paddle Inference backend, and use its integrated TensorRT instead.
         """
+        logging.warning(
+            "`RuntimeOption.enable_paddle_to_trt` will be deprecated in v1.2.0, if you want to run TensorRT with Paddle Inference backend, please use the following method, "
+        )
+        logging.warning(" ==============================================")
+        logging.warning(" import fastdeploy as fd")
+        logging.warning(" option = fd.RuntimeOption()")
+        logging.warning(" option.use_gpu(0)")
+        logging.warning(" option.use_paddle_infer_backend()")
+        logging.warning(" option.paddle_infer_option.enable_trt = True")
+        logging.warning(" ==============================================")
         return self._option.enable_paddle_to_trt()
 
     def set_trt_max_workspace_size(self, trt_max_workspace_size):
@@ -519,22 +541,34 @@ class RuntimeOption:
     def enable_paddle_trt_collect_shape(self):
         """Enable collect subgraph shape information while using Paddle Inference with TensorRT
         """
-        return self._option.enable_paddle_trt_collect_shape()
+        logging.warning(
+            "`RuntimeOption.enable_paddle_trt_collect_shape` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.collect_trt_shape = True` instead."
+        )
+        self._option.paddle_infer_option.collect_trt_shape = True
 
     def disable_paddle_trt_collect_shape(self):
         """Disable collect subgraph shape information while using Paddle Inference with TensorRT
         """
-        return self._option.disable_paddle_trt_collect_shape()
+        logging.warning(
+            "`RuntimeOption.disable_paddle_trt_collect_shape` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.collect_trt_shape = False` instead."
+        )
+        self._option.paddle_infer_option.collect_trt_shape = False
 
     def delete_paddle_backend_pass(self, pass_name):
         """Delete pass by name in paddle backend
         """
-        return self._option.delete_paddle_backend_pass(pass_name)
+        logging.warning(
+            "`RuntimeOption.delete_paddle_backend_pass` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.delete_pass` instead."
+        )
+        self._option.paddle_infer_option.delete_pass(pass_name)
 
     def disable_paddle_trt_ops(self, ops):
         """Disable some ops in paddle trt backend
         """
-        return self._option.disable_paddle_trt_ops(ops)
+        logging.warning(
+            "`RuntimeOption.disable_paddle_trt_ops` will be deprecated in v1.2.0, please use `RuntimeOption.paddle_infer_option.disable_trt_ops()` instead."
+        )
+        self._option.paddle_infer_option.disable_trt_ops(ops)
 
     def use_ipu(self,
                 device_num=1,
@@ -593,6 +627,14 @@ class RuntimeOption:
         """
         return self._option.trt_option
 
+    @property
+    def paddle_infer_option(self):
+        """Get PaddleBackendOption object to configure Paddle Inference backend
+
+        :return PaddleBackendOption
+        """
+        return self._option.paddle_infer_option
+
     def enable_profiling(self, inclue_h2d_d2h=False, repeat=100, warmup=50):
         """Set the profile mode as 'true'.
 
         :param inclue_h2d_d2h Whether to include time of H2D_D2H for time of runtime.


@@ -64,8 +64,7 @@ setup_configs["ENABLE_OPENVINO_BACKEND"] = os.getenv("ENABLE_OPENVINO_BACKEND",
"OFF") "OFF")
setup_configs["ENABLE_PADDLE_BACKEND"] = os.getenv("ENABLE_PADDLE_BACKEND", setup_configs["ENABLE_PADDLE_BACKEND"] = os.getenv("ENABLE_PADDLE_BACKEND",
"OFF") "OFF")
setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND", setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND", "OFF")
"OFF")
setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF") setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF")
setup_configs["ENABLE_LITE_BACKEND"] = os.getenv("ENABLE_LITE_BACKEND", "OFF") setup_configs["ENABLE_LITE_BACKEND"] = os.getenv("ENABLE_LITE_BACKEND", "OFF")
setup_configs["PADDLELITE_URL"] = os.getenv("PADDLELITE_URL", "OFF") setup_configs["PADDLELITE_URL"] = os.getenv("PADDLELITE_URL", "OFF")
@@ -80,8 +79,7 @@ setup_configs["WITH_IPU"] = os.getenv("WITH_IPU", "OFF")
setup_configs["WITH_KUNLUNXIN"] = os.getenv("WITH_KUNLUNXIN", "OFF") setup_configs["WITH_KUNLUNXIN"] = os.getenv("WITH_KUNLUNXIN", "OFF")
setup_configs["BUILD_ON_JETSON"] = os.getenv("BUILD_ON_JETSON", "OFF") setup_configs["BUILD_ON_JETSON"] = os.getenv("BUILD_ON_JETSON", "OFF")
setup_configs["TRT_DIRECTORY"] = os.getenv("TRT_DIRECTORY", "UNDEFINED") setup_configs["TRT_DIRECTORY"] = os.getenv("TRT_DIRECTORY", "UNDEFINED")
setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY", setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY", "/usr/local/cuda")
"/usr/local/cuda")
setup_configs["LIBRARY_NAME"] = PACKAGE_NAME setup_configs["LIBRARY_NAME"] = PACKAGE_NAME
setup_configs["PY_LIBRARY_NAME"] = PACKAGE_NAME + "_main" setup_configs["PY_LIBRARY_NAME"] = PACKAGE_NAME + "_main"
setup_configs["OPENCV_DIRECTORY"] = os.getenv("OPENCV_DIRECTORY", "") setup_configs["OPENCV_DIRECTORY"] = os.getenv("OPENCV_DIRECTORY", "")
@@ -104,6 +102,7 @@ if os.getenv("CMAKE_CXX_COMPILER", None) is not None:
setup_configs["CMAKE_CXX_COMPILER"] = os.getenv("CMAKE_CXX_COMPILER") setup_configs["CMAKE_CXX_COMPILER"] = os.getenv("CMAKE_CXX_COMPILER")
SRC_DIR = os.path.join(TOP_DIR, PACKAGE_NAME) SRC_DIR = os.path.join(TOP_DIR, PACKAGE_NAME)
PYTHON_SRC_DIR = os.path.join(TOP_DIR, "python", PACKAGE_NAME)
CMAKE_BUILD_DIR = os.path.join(TOP_DIR, 'python', '.setuptools-cmake-build') CMAKE_BUILD_DIR = os.path.join(TOP_DIR, 'python', '.setuptools-cmake-build')
WINDOWS = (os.name == 'nt') WINDOWS = (os.name == 'nt')
@@ -120,8 +119,7 @@ extras_require = {}
# Default value is set to TRUE\1 to keep the settings same as the current ones. # Default value is set to TRUE\1 to keep the settings same as the current ones.
# However going forward the recomemded way to is to set this to False\0 # However going forward the recomemded way to is to set this to False\0
USE_MSVC_STATIC_RUNTIME = bool( USE_MSVC_STATIC_RUNTIME = bool(os.getenv('USE_MSVC_STATIC_RUNTIME', '1') == '1')
os.getenv('USE_MSVC_STATIC_RUNTIME', '1') == '1')
ONNX_NAMESPACE = os.getenv('ONNX_NAMESPACE', 'paddle2onnx') ONNX_NAMESPACE = os.getenv('ONNX_NAMESPACE', 'paddle2onnx')
################################################################################ ################################################################################
# Version # Version
@@ -151,8 +149,7 @@ assert CMAKE, 'Could not find "cmake" executable!'
@contextmanager @contextmanager
def cd(path): def cd(path):
if not os.path.isabs(path): if not os.path.isabs(path):
raise RuntimeError('Can only cd to absolute path, got: {}'.format( raise RuntimeError('Can only cd to absolute path, got: {}'.format(path))
path))
orig_path = os.getcwd() orig_path = os.getcwd()
os.chdir(path) os.chdir(path)
try: try:
@@ -187,7 +184,7 @@ def get_all_files(dirname):
class create_version(ONNXCommand): class create_version(ONNXCommand):
def run(self): def run(self):
with open(os.path.join(SRC_DIR, 'version.py'), 'w') as f: with open(os.path.join(PYTHON_SRC_DIR, 'code_version.py'), 'w') as f:
f.write( f.write(
dedent('''\ dedent('''\
# This file is generated by setup.py. DO NOT EDIT! # This file is generated by setup.py. DO NOT EDIT!