mirror of
https://github.com/PaddlePaddle/FastDeploy.git
[Other] Optimize paddle backend (#1265)
* Optimize paddle backend
* optimize paddle backend
* add version support
fastdeploy/runtime/backends/paddle/option.h

@@ -24,54 +24,71 @@
namespace fastdeploy {

/*! @brief Option object to configure GraphCore IPU
 */
struct IpuOption {
  /// IPU device id
  int ipu_device_num;
  /// the batch size in the graph, only works when the graph has no batch shape info
  int ipu_micro_batch_size;
  /// enable pipelining
  bool ipu_enable_pipelining;
  /// the number of batches per run in pipelining
  int ipu_batches_per_step;
  /// enable fp16
  bool ipu_enable_fp16;
  /// the number of graph replications
  int ipu_replica_num;
  /// the available memory proportion for matmul/conv
  float ipu_available_memory_proportion;
  /// enable fp16 partial for matmul, only works with fp16
  bool ipu_enable_half_partial;
};

/*! @brief Option object to configure Paddle Inference backend
 */
struct PaddleBackendOption {
+  /// Print log information while initializing Paddle Inference backend
+  bool enable_log_info = false;
+  /// Enable MKLDNN while running inference on CPU
+  bool enable_mkldnn = true;
+  /// Use Paddle Inference + TensorRT to run the model on GPU
+  bool enable_trt = false;

+  /*
+   * @brief IPU option; this configures the IPU hardware when running the model on IPU
+   */
+  IpuOption ipu_option;

+  /// Collect shape info for the model while enable_trt is true
+  bool collect_trt_shape = false;
+  /// Cache input shapes for MKLDNN while the input data changes dynamically
+  int mkldnn_cache_size = -1;
+  /// Initial memory size (MB) for GPU
+  int gpu_mem_init_size = 100;

+  void DisableTrtOps(const std::vector<std::string>& ops) {
+    trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
+  }

+  void DeletePass(const std::string& pass_name) {
+    delete_pass_names.push_back(pass_name);
+  }

+  // The following parameters may be removed; please do not
+  // read or write them directly
+  TrtBackendOption trt_option;
+  bool enable_pinned_memory = false;
+  void* external_stream_ = nullptr;
+  Device device = Device::CPU;
+  int device_id = 0;
+  std::vector<std::string> trt_disabled_ops_{};
+  int cpu_thread_num = 8;
+  std::vector<std::string> delete_pass_names = {};
+  std::string model_file = "";   // Path of model file
+  std::string params_file = "";  // Path of parameters file, can be empty

+  // load model and parameters from memory
+  bool model_from_memory_ = false;
-#ifdef WITH_GPU
-  bool use_gpu = true;
-#else
-  bool use_gpu = false;
-#endif
-  bool enable_mkldnn = true;
-
-  bool enable_log_info = false;
-
-  bool enable_trt = false;
-  TrtBackendOption trt_option;
-  bool collect_shape = false;
-  std::vector<std::string> trt_disabled_ops_{};
-
-#ifdef WITH_IPU
-  bool use_ipu = true;
-  IpuOption ipu_option;
-#else
-  bool use_ipu = false;
-#endif
-
-  int mkldnn_cache_size = 1;
-  int cpu_thread_num = 8;
-  // initialize memory size(MB) for GPU
-  int gpu_mem_init_size = 100;
-  // gpu device id
-  int gpu_id = 0;
-  bool enable_pinned_memory = false;
-  void* external_stream_ = nullptr;
-
-  std::vector<std::string> delete_pass_names = {};
};
} // namespace fastdeploy
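As a quick illustration of the reorganized options, a minimal sketch of filling in the documented fields by hand follows; constructing PaddleBackendOption directly (rather than through RuntimeOption) is purely for demonstration, and the op/pass names are hypothetical placeholders, not part of this commit.

#include "fastdeploy/runtime/backends/paddle/option.h"

int main() {
  fastdeploy::PaddleBackendOption option;
  option.enable_log_info = true;    // print Paddle Inference logs during initialization
  option.enable_mkldnn = true;      // run with MKLDNN on CPU
  option.mkldnn_cache_size = 10;    // cache input shapes when they change dynamically
  option.gpu_mem_init_size = 200;   // initial GPU memory pool in MB

  // Helper methods added in this change; the op/pass names are placeholders.
  option.DisableTrtOps({"some_unsupported_op"});
  option.DeletePass("some_fusion_pass");
  return 0;
}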
fastdeploy/runtime/backends/paddle/option_pybind.cc (new file, 53 lines)
@@ -0,0 +1,53 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/pybind/main.h"
#include "fastdeploy/runtime/backends/paddle/option.h"

namespace fastdeploy {

void BindIpuOption(pybind11::module& m) {
  pybind11::class_<IpuOption>(m, "IpuOption")
      .def(pybind11::init())
      .def_readwrite("ipu_device_num", &IpuOption::ipu_device_num)
      .def_readwrite("ipu_micro_batch_size", &IpuOption::ipu_micro_batch_size)
      .def_readwrite("ipu_enable_pipelining", &IpuOption::ipu_enable_pipelining)
      .def_readwrite("ipu_batches_per_step", &IpuOption::ipu_batches_per_step)
      .def_readwrite("ipu_enable_fp16", &IpuOption::ipu_enable_fp16)
      .def_readwrite("ipu_replica_num", &IpuOption::ipu_replica_num)
      .def_readwrite("ipu_available_memory_proportion",
                     &IpuOption::ipu_available_memory_proportion)
      .def_readwrite("ipu_enable_half_partial",
                     &IpuOption::ipu_enable_half_partial);
}

void BindPaddleOption(pybind11::module& m) {
  BindIpuOption(m);
  pybind11::class_<PaddleBackendOption>(m, "PaddleBackendOption")
      .def(pybind11::init())
      .def_readwrite("enable_log_info", &PaddleBackendOption::enable_log_info)
      .def_readwrite("enable_mkldnn", &PaddleBackendOption::enable_mkldnn)
      .def_readwrite("enable_trt", &PaddleBackendOption::enable_trt)
      .def_readwrite("ipu_option", &PaddleBackendOption::ipu_option)
      .def_readwrite("collect_trt_shape",
                     &PaddleBackendOption::collect_trt_shape)
      .def_readwrite("mkldnn_cache_size",
                     &PaddleBackendOption::mkldnn_cache_size)
      .def_readwrite("gpu_mem_init_size",
                     &PaddleBackendOption::gpu_mem_init_size)
      .def("disable_trt_ops", &PaddleBackendOption::DisableTrtOps)
      .def("delete_pass", &PaddleBackendOption::DeletePass);
}

} // namespace fastdeploy
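For context, a minimal sketch of where a binder like BindPaddleOption would typically be invoked; the module name and entry point below are hypothetical (FastDeploy's real pybind entry point is not part of this diff) and only illustrate how IpuOption and PaddleBackendOption become visible from Python.

#include <pybind11/pybind11.h>

namespace fastdeploy {
// Declared here for the sketch; defined in option_pybind.cc above.
void BindPaddleOption(pybind11::module& m);
}  // namespace fastdeploy

// Hypothetical module entry point; calling the binder registers both option classes.
PYBIND11_MODULE(example_bindings, m) {
  fastdeploy::BindPaddleOption(m);
}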
fastdeploy/runtime/backends/paddle/paddle_backend.cc

@@ -22,8 +22,8 @@ namespace fastdeploy {

void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
  option_ = option;
-  if (option.use_gpu) {
-    config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id);
+  if (option.device == Device::GPU) {
+    config_.EnableUseGpu(option.gpu_mem_init_size, option.device_id);
    if (option_.external_stream_) {
      config_.SetExecStream(option_.external_stream_);
    }
@@ -50,7 +50,7 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
                                  precision, use_static);
      SetTRTDynamicShapeToConfig(option);
    }
-  } else if (option.use_ipu) {
+  } else if (option.device == Device::IPU) {
#ifdef WITH_IPU
    config_.EnableIpu(option.ipu_option.ipu_device_num,
                      option.ipu_option.ipu_micro_batch_size,
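A minimal sketch of the IpuOption values that feed this EnableIpu call; the numbers below are illustrative placeholders, not defaults prescribed by the commit.

#include "fastdeploy/runtime/backends/paddle/option.h"

int main() {
  fastdeploy::IpuOption ipu;
  ipu.ipu_device_num = 1;                      // number of IPU devices to use
  ipu.ipu_micro_batch_size = 1;                // batch size when the graph has no batch info
  ipu.ipu_enable_pipelining = false;
  ipu.ipu_batches_per_step = 1;
  ipu.ipu_enable_fp16 = false;
  ipu.ipu_replica_num = 1;
  ipu.ipu_available_memory_proportion = 0.3f;  // matmul/conv memory proportion (illustrative)
  ipu.ipu_enable_half_partial = false;         // only meaningful together with fp16

  fastdeploy::PaddleBackendOption option;
  option.ipu_option = ipu;                     // consumed by BuildOption above when device is IPU
  return 0;
}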
@@ -101,14 +101,15 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_buffer,
                             params_buffer.c_str(), params_buffer.size());
  config_.EnableMemoryOptim();
  BuildOption(option);

  // The input/output information obtained from the predictor is not right; use
  // PaddleReader instead for now
-  auto reader = paddle2onnx::PaddleReader(model_buffer.c_str(), model_buffer.size());
+  auto reader =
+      paddle2onnx::PaddleReader(model_buffer.c_str(), model_buffer.size());
  // If it's a quantized model running on CPU with MKLDNN, automatically switch to
  // int8 mode
  if (reader.is_quantize_model) {
-    if (option.use_gpu) {
+    if (option.device == Device::GPU) {
      FDWARNING << "The loaded model is a quantized model, while inference on "
                   "GPU, please use TensorRT backend to get better performance."
                << std::endl;
@@ -158,7 +159,7 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_buffer,
    outputs_desc_[i].shape.assign(shape.begin(), shape.end());
    outputs_desc_[i].dtype = ReaderDataTypeToFD(reader.outputs[i].dtype);
  }
-  if (option.collect_shape) {
+  if (option.collect_trt_shape) {
    // Set the shape info file.
    std::string curr_model_dir = "./";
    if (!option.model_from_memory_) {
@@ -221,19 +222,19 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
              << inputs_desc_.size() << ")." << std::endl;
    return false;
  }

  RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto handle = predictor_->GetInputHandle(inputs[i].name);
    ShareTensorFromFDTensor(handle.get(), inputs[i]);
  }

  RUNTIME_PROFILE_LOOP_BEGIN(1)
  predictor_->Run();
  RUNTIME_PROFILE_LOOP_END

  // sharing backend memory for outputs only supports CPU or GPU
-  if (option_.use_ipu) {
+  if (option_.device == Device::IPU) {
    copy_to_fd = true;
  }
  outputs->resize(outputs_desc_.size());
@@ -253,9 +254,10 @@ std::unique_ptr<BaseBackend> PaddleBackend::Clone(RuntimeOption& runtime_option,
  std::unique_ptr<BaseBackend> new_backend =
      utils::make_unique<PaddleBackend>();
  auto casted_backend = dynamic_cast<PaddleBackend*>(new_backend.get());
-  if (device_id > 0 && option_.use_gpu == true && device_id != option_.gpu_id) {
+  if (device_id > 0 && (option_.device == Device::GPU) &&
+      device_id != option_.device_id) {
    auto clone_option = option_;
-    clone_option.gpu_id = device_id;
+    clone_option.device_id = device_id;
    clone_option.external_stream_ = stream;
    if (runtime_option.model_from_memory_) {
      FDASSERT(
@@ -279,7 +281,7 @@ std::unique_ptr<BaseBackend> PaddleBackend::Clone(RuntimeOption& runtime_option,
  }

  FDWARNING << "The target device id:" << device_id
-            << " is different from current device id:" << option_.gpu_id
+            << " is different from current device id:" << option_.device_id
            << ", cannot share memory with current engine." << std::endl;
  return new_backend;
}
@@ -347,10 +349,13 @@ void PaddleBackend::CollectShapeRun(
    const std::map<std::string, std::vector<int>>& shape) const {
  auto input_names = predictor->GetInputNames();
  auto input_type = predictor->GetInputTypes();
-  for (auto name : input_names) {
+  for (const auto& name : input_names) {
    FDASSERT(shape.find(name) != shape.end() &&
                 input_type.find(name) != input_type.end(),
-             "Paddle Input name [%s] is not one of the trt dynamic shape.",
+             "When collect_trt_shape is true, please define max/opt/min shape "
+             "for model's input:[\"%s\"] by "
+             "(C++)RuntimeOption.trt_option.SetShape/"
+             "(Python)RuntimeOption.trt_option.set_shape.",
             name.c_str());
    auto tensor = predictor->GetInputHandle(name);
    auto shape_value = shape.at(name);
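The reworked assertion spells out the requirement: when collect_trt_shape is true, every input needs min/opt/max shapes registered up front. A hedged sketch of that setup follows; the SetShape(name, min, opt, max) signature is assumed from the assertion text rather than shown in this diff, and the input name and shapes are hypothetical.

#include "fastdeploy/runtime/backends/paddle/option.h"

int main() {
  fastdeploy::PaddleBackendOption option;
  option.enable_trt = true;          // Paddle Inference + TensorRT
  option.collect_trt_shape = true;   // triggers the shape-collection path above

  // Assumed API per the assertion message: register min/opt/max shapes for a
  // hypothetical input named "x" before shape collection runs.
  option.trt_option.SetShape("x", {1, 3, 224, 224},   // min shape
                             {4, 3, 224, 224},        // opt shape
                             {8, 3, 224, 224});       // max shape
  return 0;
}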
@@ -385,4 +390,4 @@ void PaddleBackend::CollectShapeRun(
  predictor->Run();
}

-} // namespace fastdeploy
+} // namespace fastdeploy