[Backend]Add DisablePaddleTrtOPs (#788)

* Add DisablePaddleTrtOPs

* Add delete_paddle_backend_pass disable_paddle_trt_ops pybind
This commit is contained in:
Jack Zhou
2022-12-05 10:03:52 +08:00
committed by GitHub
parent 6c31198342
commit 8c2d582925
6 changed files with 267 additions and 220 deletions

View File

@@ -22,24 +22,34 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
option_ = option;
if (option.use_gpu) {
config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id);
if(option_.external_stream_) {
if (option_.external_stream_) {
config_.SetExecStream(option_.external_stream_);
}
if (option.enable_trt) {
#ifdef ENABLE_TRT_BACKEND
config_.Exp_DisableTensorRtOPs(option.trt_disabled_ops_);
auto precision = paddle_infer::PrecisionType::kFloat32;
if (option.trt_option.enable_fp16) {
precision = paddle_infer::PrecisionType::kHalf;
}
bool use_static = false;
if (option.trt_option.serialize_file != "") {
FDWARNING << "Detect that tensorrt cache file has been set to " << option.trt_option.serialize_file << ", but while enable paddle2trt, please notice that the cache file will save to the directory where paddle model saved." << std::endl;
FDWARNING
<< "Detect that tensorrt cache file has been set to "
<< option.trt_option.serialize_file
<< ", but while enable paddle2trt, please notice that the cache "
"file will save to the directory where paddle model saved."
<< std::endl;
use_static = true;
}
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size, option.trt_option.max_batch_size, 3, precision, use_static);
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size,
option.trt_option.max_batch_size, 3,
precision, use_static);
SetTRTDynamicShapeToConfig(option);
#else
FDWARNING << "The FastDeploy is not compiled with TensorRT backend, so will fallback to GPU with Paddle Inference Backend." << std::endl;
FDWARNING << "The FastDeploy is not compiled with TensorRT backend, so "
"will fallback to GPU with Paddle Inference Backend."
<< std::endl;
#endif
}
} else if (option.use_ipu) {
@@ -98,39 +108,48 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
if (!ReadBinaryFromFile(model_file, &contents)) {
return false;
}
auto reader =
paddle2onnx::PaddleReader(contents.c_str(), contents.size());
auto reader = paddle2onnx::PaddleReader(contents.c_str(), contents.size());
// If it's a quantized model, and use cpu with mkldnn, automaticaly switch to int8 mode
if (reader.is_quantize_model) {
if (option.use_gpu) {
FDWARNING << "The loaded model is a quantized model, while inference on GPU, please use TensorRT backend to get better performance." << std::endl;
FDWARNING << "The loaded model is a quantized model, while inference on "
"GPU, please use TensorRT backend to get better performance."
<< std::endl;
if (option.enable_trt) {
#ifdef ENABLE_TRT_BACKEND
bool use_static = false;
if (option.trt_option.serialize_file != "") {
FDWARNING << "Detect that tensorrt cache file has been set to " << option.trt_option.serialize_file << ", but while enable paddle2trt, please notice that the cache file will save to the directory where paddle model saved." << std::endl;
FDWARNING
<< "Detect that tensorrt cache file has been set to "
<< option.trt_option.serialize_file
<< ", but while enable paddle2trt, please notice that the cache "
"file will save to the directory where paddle model saved."
<< std::endl;
use_static = true;
}
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size, option.trt_option.max_batch_size, 3, paddle_infer::PrecisionType::kInt8, use_static, false);
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size,
option.trt_option.max_batch_size, 3,
paddle_infer::PrecisionType::kInt8,
use_static, false);
SetTRTDynamicShapeToConfig(option);
#endif
}
}
if (option.enable_mkldnn) {
config_.EnableMkldnnInt8();
} else {
FDWARNING << "The loaded model is a quantized model, while inference on CPU, please enable MKLDNN to get better performance." << std::endl;
FDWARNING << "The loaded model is a quantized model, while inference on "
"CPU, please enable MKLDNN to get better performance."
<< std::endl;
}
}
inputs_desc_.resize(reader.num_inputs);
for (int i = 0; i < reader.num_inputs; ++i) {
std::string name(reader.inputs[i].name);
std::vector<int64_t> shape(
reader.inputs[i].shape,
reader.inputs[i].shape + reader.inputs[i].rank);
std::vector<int64_t> shape(reader.inputs[i].shape,
reader.inputs[i].shape + reader.inputs[i].rank);
inputs_desc_[i].name = name;
inputs_desc_[i].shape.assign(shape.begin(), shape.end());
inputs_desc_[i].dtype = ReaderDataTypeToFD(reader.inputs[i].dtype);
@@ -138,7 +157,9 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
outputs_desc_.resize(reader.num_outputs);
for (int i = 0; i < reader.num_outputs; ++i) {
std::string name(reader.outputs[i].name);
std::vector<int64_t> shape(reader.outputs[i].shape, reader.outputs[i].shape + reader.outputs[i].rank);
std::vector<int64_t> shape(reader.outputs[i].shape,
reader.outputs[i].shape +
reader.outputs[i].rank);
outputs_desc_[i].name = name;
outputs_desc_[i].shape.assign(shape.begin(), shape.end());
outputs_desc_[i].dtype = ReaderDataTypeToFD(reader.outputs[i].dtype);
@@ -147,7 +168,8 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
if (option.collect_shape) {
// Set the shape info file.
auto curr_model_dir = GetDirFromPath(model_file);
std::string shape_range_info = PathJoin(curr_model_dir, "shape_range_info.pbtxt");
std::string shape_range_info =
PathJoin(curr_model_dir, "shape_range_info.pbtxt");
if (!CheckFileExists(shape_range_info)) {
FDINFO << "Start generating shape range info file." << std::endl;
paddle_infer::Config analysis_config;
@@ -164,7 +186,8 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
CollectShapeRun(predictor_tmp.get(), opt_shape);
FDINFO << "Finish generating shape range info file." << std::endl;
}
FDINFO << "Start loading shape range info file "<< shape_range_info << " to set TensorRT dynamic shape." << std::endl;
FDINFO << "Start loading shape range info file " << shape_range_info
<< " to set TensorRT dynamic shape." << std::endl;
config_.EnableTunedTensorRtDynamicShape(shape_range_info, false);
}
#endif
@@ -194,8 +217,7 @@ std::vector<TensorInfo> PaddleBackend::GetOutputInfos() {
}
bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs,
bool copy_to_fd) {
std::vector<FDTensor>* outputs, bool copy_to_fd) {
if (inputs.size() != inputs_desc_.size()) {
FDERROR << "[PaddleBackend] Size of inputs(" << inputs.size()
<< ") should keep same with the inputs of this model("
@@ -211,13 +233,13 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
predictor_->Run();
// output share backend memory only support CPU or GPU
if(option_.use_ipu) {
if (option_.use_ipu) {
copy_to_fd = true;
}
outputs->resize(outputs_desc_.size());
for (size_t i = 0; i < outputs_desc_.size(); ++i) {
auto handle = predictor_->GetOutputHandle(outputs_desc_[i].name);
if(copy_to_fd) {
if (copy_to_fd) {
(*outputs)[i].is_pinned_memory = option_.enable_pinned_memory;
}
PaddleTensorToFDTensor(handle, &((*outputs)[i]), copy_to_fd);
@@ -225,47 +247,47 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
return true;
}
std::unique_ptr<BaseBackend> PaddleBackend::Clone(void *stream, int device_id) {
std::unique_ptr<BaseBackend> new_backend = utils::make_unique<PaddleBackend>();
std::unique_ptr<BaseBackend> PaddleBackend::Clone(void* stream, int device_id) {
std::unique_ptr<BaseBackend> new_backend =
utils::make_unique<PaddleBackend>();
auto casted_backend = dynamic_cast<PaddleBackend*>(new_backend.get());
if(device_id > 0 && option_.use_gpu == true && device_id != option_.gpu_id) {
if (device_id > 0 && option_.use_gpu == true && device_id != option_.gpu_id) {
auto clone_option = option_;
clone_option.gpu_id = device_id;
clone_option.external_stream_ = stream;
casted_backend->InitFromPaddle(clone_option.model_file,
clone_option.params_file,
clone_option);
FDWARNING << "The target device id:"
<< device_id
<< " is different from current device id:"
<< option_.gpu_id
<< ", cannot share memory with current engine."
<< std::endl;
clone_option.params_file, clone_option);
FDWARNING << "The target device id:" << device_id
<< " is different from current device id:" << option_.gpu_id
<< ", cannot share memory with current engine." << std::endl;
return new_backend;
}
casted_backend->inputs_desc_.assign(inputs_desc_.begin(), inputs_desc_.end());
casted_backend->outputs_desc_.assign(outputs_desc_.begin(), outputs_desc_.end());
casted_backend->outputs_desc_.assign(outputs_desc_.begin(),
outputs_desc_.end());
casted_backend->predictor_ = std::move(predictor_->Clone(stream));
return new_backend;
}
#ifdef ENABLE_TRT_BACKEND
void PaddleBackend::SetTRTDynamicShapeToConfig(const PaddleBackendOption& option) {
std::map<std::string, std::vector<int>> max_shape;
std::map<std::string, std::vector<int>> min_shape;
std::map<std::string, std::vector<int>> opt_shape;
GetDynamicShapeFromOption(option, &max_shape, &min_shape, &opt_shape);
void PaddleBackend::SetTRTDynamicShapeToConfig(
const PaddleBackendOption& option) {
std::map<std::string, std::vector<int>> max_shape;
std::map<std::string, std::vector<int>> min_shape;
std::map<std::string, std::vector<int>> opt_shape;
GetDynamicShapeFromOption(option, &max_shape, &min_shape, &opt_shape);
if (min_shape.size() > 0) {
FDINFO << "Start setting trt dynamic shape." << std::endl;
if (min_shape.size() > 0) {
config_.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
}
config_.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
FDINFO << "Finish setting trt dynamic shape." << std::endl;
}
}
void PaddleBackend::GetDynamicShapeFromOption(const PaddleBackendOption& option,
std::map<std::string, std::vector<int>>* max_shape,
std::map<std::string, std::vector<int>>* min_shape,
std::map<std::string, std::vector<int>>* opt_shape) const {
void PaddleBackend::GetDynamicShapeFromOption(
const PaddleBackendOption& option,
std::map<std::string, std::vector<int>>* max_shape,
std::map<std::string, std::vector<int>>* min_shape,
std::map<std::string, std::vector<int>>* opt_shape) const {
auto print_shape = [](const std::vector<int>& shape) -> std::string {
std::ostringstream oss;
oss << "[";
@@ -281,24 +303,35 @@ void PaddleBackend::GetDynamicShapeFromOption(const PaddleBackendOption& option,
for (const auto& item : option.trt_option.min_shape) {
auto max_iter = option.trt_option.max_shape.find(item.first);
auto opt_iter = option.trt_option.opt_shape.find(item.first);
FDASSERT(max_iter != option.trt_option.max_shape.end(), "Cannot find %s in TrtBackendOption::min_shape.", item.first.c_str());
FDASSERT(opt_iter != option.trt_option.opt_shape.end(), "Cannot find %s in TrtBackendOption::opt_shape.", item.first.c_str());
(*max_shape)[item.first].assign(max_iter->second.begin(), max_iter->second.end());
(*opt_shape)[item.first].assign(opt_iter->second.begin(), opt_iter->second.end());
FDASSERT(max_iter != option.trt_option.max_shape.end(),
"Cannot find %s in TrtBackendOption::min_shape.",
item.first.c_str());
FDASSERT(opt_iter != option.trt_option.opt_shape.end(),
"Cannot find %s in TrtBackendOption::opt_shape.",
item.first.c_str());
(*max_shape)[item.first].assign(max_iter->second.begin(),
max_iter->second.end());
(*opt_shape)[item.first].assign(opt_iter->second.begin(),
opt_iter->second.end());
(*min_shape)[item.first].assign(item.second.begin(), item.second.end());
FDINFO << item.first << ": the max shape = " << print_shape(max_iter->second)
FDINFO << item.first
<< ": the max shape = " << print_shape(max_iter->second)
<< ", the min shape = " << print_shape(item.second)
<< ", the opt shape = " << print_shape(opt_iter->second) << std::endl;
<< ", the opt shape = " << print_shape(opt_iter->second)
<< std::endl;
}
}
void PaddleBackend::CollectShapeRun(paddle_infer::Predictor* predictor,
void PaddleBackend::CollectShapeRun(
paddle_infer::Predictor* predictor,
const std::map<std::string, std::vector<int>>& shape) const {
auto input_names = predictor->GetInputNames();
auto input_type = predictor->GetInputTypes();
for(auto name : input_names) {
FDASSERT(shape.find(name) != shape.end() && input_type.find(name) != input_type.end(),
"Paddle Input name [%s] is not one of the trt dynamic shape.", name.c_str());
for (auto name : input_names) {
FDASSERT(shape.find(name) != shape.end() &&
input_type.find(name) != input_type.end(),
"Paddle Input name [%s] is not one of the trt dynamic shape.",
name.c_str());
auto tensor = predictor->GetInputHandle(name);
auto shape_value = shape.at(name);
int shape_num = std::accumulate(shape_value.begin(), shape_value.end(), 1,
@@ -306,30 +339,30 @@ void PaddleBackend::CollectShapeRun(paddle_infer::Predictor* predictor,
tensor->Reshape(shape_value);
auto dtype = input_type[name];
switch (dtype) {
case paddle_infer::DataType::FLOAT32: {
std::vector<float> input_data(shape_num, 1.0);
tensor->CopyFromCpu(input_data.data());
break;
}
case paddle_infer::DataType::INT32: {
std::vector<int> input_data(shape_num, 1);
tensor->CopyFromCpu(input_data.data());
break;
}
case paddle_infer::DataType::INT64: {
std::vector<int64_t> input_data(shape_num, 1);
tensor->CopyFromCpu(input_data.data());
break;
}
default: {
FDASSERT(false, "Input data Paddle backend only supports FP32/INT32/INT64 currently.");
break;
}
case paddle_infer::DataType::FLOAT32: {
std::vector<float> input_data(shape_num, 1.0);
tensor->CopyFromCpu(input_data.data());
break;
}
case paddle_infer::DataType::INT32: {
std::vector<int> input_data(shape_num, 1);
tensor->CopyFromCpu(input_data.data());
break;
}
case paddle_infer::DataType::INT64: {
std::vector<int64_t> input_data(shape_num, 1);
tensor->CopyFromCpu(input_data.data());
break;
}
default: {
FDASSERT(false, "Input data Paddle backend only supports "
"FP32/INT32/INT64 currently.");
break;
}
}
}
predictor->Run();
}
#endif
} // namespace fastdeploy