[Backend] Add DisablePaddleTrtOPs (#788)

* Add DisablePaddleTrtOPs

* Add delete_paddle_backend_pass and disable_paddle_trt_ops pybind bindings
Author: Jack Zhou
Date: 2022-12-05 10:03:52 +08:00
Committed by: GitHub
Parent: 6c31198342
Commit: 8c2d582925
6 changed files with 267 additions and 220 deletions
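Before the per-file diffs, here is a minimal Python sketch of how the two newly bound options would be used. It assumes the pre-existing RuntimeOption helpers (set_model_path, use_gpu, use_paddle_backend, enable_paddle_to_trt), which are not part of this change; the model paths, op names, and pass name are placeholders.

import fastdeploy as fd

option = fd.RuntimeOption()
option.set_model_path("model.pdmodel", "model.pdiparams")  # placeholder paths
option.use_gpu(0)
option.use_paddle_backend()    # Paddle Inference backend
option.enable_paddle_to_trt()  # run TensorRT through Paddle Inference
# New in this PR: keep the listed ops out of the Paddle-TRT engine (op names illustrative).
option.disable_paddle_trt_ops(["pool2d", "elementwise_add"])
# Also new: remove an IR pass from the Paddle backend by name (pass name illustrative).
option.delete_paddle_backend_pass("conv_elementwise_add_fuse_pass")
runtime = fd.Runtime(option)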


@@ -22,24 +22,34 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
  option_ = option;
  if (option.use_gpu) {
    config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id);
    if (option_.external_stream_) {
      config_.SetExecStream(option_.external_stream_);
    }
    if (option.enable_trt) {
#ifdef ENABLE_TRT_BACKEND
      config_.Exp_DisableTensorRtOPs(option.trt_disabled_ops_);
      auto precision = paddle_infer::PrecisionType::kFloat32;
      if (option.trt_option.enable_fp16) {
        precision = paddle_infer::PrecisionType::kHalf;
      }
      bool use_static = false;
      if (option.trt_option.serialize_file != "") {
        FDWARNING
            << "Detect that tensorrt cache file has been set to "
            << option.trt_option.serialize_file
            << ", but while enable paddle2trt, please notice that the cache "
               "file will save to the directory where paddle model saved."
            << std::endl;
        use_static = true;
      }
      config_.EnableTensorRtEngine(option.trt_option.max_workspace_size,
                                   option.trt_option.max_batch_size, 3,
                                   precision, use_static);
      SetTRTDynamicShapeToConfig(option);
#else
      FDWARNING << "The FastDeploy is not compiled with TensorRT backend, so "
                   "will fallback to GPU with Paddle Inference Backend."
                << std::endl;
#endif
    }
  } else if (option.use_ipu) {

@@ -98,39 +108,48 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
  if (!ReadBinaryFromFile(model_file, &contents)) {
    return false;
  }
  auto reader = paddle2onnx::PaddleReader(contents.c_str(), contents.size());
  // If it's a quantized model, and use cpu with mkldnn, automaticaly switch to int8 mode
  if (reader.is_quantize_model) {
    if (option.use_gpu) {
      FDWARNING << "The loaded model is a quantized model, while inference on "
                   "GPU, please use TensorRT backend to get better performance."
                << std::endl;
      if (option.enable_trt) {
#ifdef ENABLE_TRT_BACKEND
        bool use_static = false;
        if (option.trt_option.serialize_file != "") {
          FDWARNING
              << "Detect that tensorrt cache file has been set to "
              << option.trt_option.serialize_file
              << ", but while enable paddle2trt, please notice that the cache "
                 "file will save to the directory where paddle model saved."
              << std::endl;
          use_static = true;
        }
        config_.EnableTensorRtEngine(option.trt_option.max_workspace_size,
                                     option.trt_option.max_batch_size, 3,
                                     paddle_infer::PrecisionType::kInt8,
                                     use_static, false);
        SetTRTDynamicShapeToConfig(option);
#endif
      }
    }
    if (option.enable_mkldnn) {
      config_.EnableMkldnnInt8();
    } else {
      FDWARNING << "The loaded model is a quantized model, while inference on "
                   "CPU, please enable MKLDNN to get better performance."
                << std::endl;
    }
  }
  inputs_desc_.resize(reader.num_inputs);
  for (int i = 0; i < reader.num_inputs; ++i) {
    std::string name(reader.inputs[i].name);
    std::vector<int64_t> shape(reader.inputs[i].shape,
                               reader.inputs[i].shape + reader.inputs[i].rank);
    inputs_desc_[i].name = name;
    inputs_desc_[i].shape.assign(shape.begin(), shape.end());
    inputs_desc_[i].dtype = ReaderDataTypeToFD(reader.inputs[i].dtype);

@@ -138,7 +157,9 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
  outputs_desc_.resize(reader.num_outputs);
  for (int i = 0; i < reader.num_outputs; ++i) {
    std::string name(reader.outputs[i].name);
    std::vector<int64_t> shape(reader.outputs[i].shape,
                               reader.outputs[i].shape +
                                   reader.outputs[i].rank);
    outputs_desc_[i].name = name;
    outputs_desc_[i].shape.assign(shape.begin(), shape.end());
    outputs_desc_[i].dtype = ReaderDataTypeToFD(reader.outputs[i].dtype);

@@ -147,7 +168,8 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
  if (option.collect_shape) {
    // Set the shape info file.
    auto curr_model_dir = GetDirFromPath(model_file);
    std::string shape_range_info =
        PathJoin(curr_model_dir, "shape_range_info.pbtxt");
    if (!CheckFileExists(shape_range_info)) {
      FDINFO << "Start generating shape range info file." << std::endl;
      paddle_infer::Config analysis_config;

@@ -164,7 +186,8 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
      CollectShapeRun(predictor_tmp.get(), opt_shape);
      FDINFO << "Finish generating shape range info file." << std::endl;
    }
    FDINFO << "Start loading shape range info file " << shape_range_info
           << " to set TensorRT dynamic shape." << std::endl;
    config_.EnableTunedTensorRtDynamicShape(shape_range_info, false);
  }
#endif
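The collect_shape branch above generates shape_range_info.pbtxt next to the Paddle model and then feeds it to EnableTunedTensorRtDynamicShape. A hedged Python sketch of driving that path through the options bound later in this diff; it assumes the pre-existing set_trt_input_shape(name, min_shape, opt_shape, max_shape) helper, and the paths, tensor name, and shapes are placeholders.

import fastdeploy as fd

option = fd.RuntimeOption()
option.set_model_path("model.pdmodel", "model.pdiparams")  # placeholder paths
option.use_gpu(0)
option.use_paddle_backend()
option.enable_paddle_to_trt()
# Register dynamic shape ranges for an input (tensor name and shapes illustrative).
option.set_trt_input_shape("x", [1, 3, 224, 224], [1, 3, 320, 320], [4, 3, 640, 640])
# Collect shapes on the first run and write shape_range_info.pbtxt into the model directory.
option.enable_paddle_trt_collect_shape()
runtime = fd.Runtime(option)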
@@ -194,8 +217,7 @@ std::vector<TensorInfo> PaddleBackend::GetOutputInfos() {
}

bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
                          std::vector<FDTensor>* outputs, bool copy_to_fd) {
  if (inputs.size() != inputs_desc_.size()) {
    FDERROR << "[PaddleBackend] Size of inputs(" << inputs.size()
            << ") should keep same with the inputs of this model("

@@ -211,13 +233,13 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
  predictor_->Run();

  // output share backend memory only support CPU or GPU
  if (option_.use_ipu) {
    copy_to_fd = true;
  }
  outputs->resize(outputs_desc_.size());
  for (size_t i = 0; i < outputs_desc_.size(); ++i) {
    auto handle = predictor_->GetOutputHandle(outputs_desc_[i].name);
    if (copy_to_fd) {
      (*outputs)[i].is_pinned_memory = option_.enable_pinned_memory;
    }
    PaddleTensorToFDTensor(handle, &((*outputs)[i]), copy_to_fd);

@@ -225,47 +247,47 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
  return true;
}

std::unique_ptr<BaseBackend> PaddleBackend::Clone(void* stream, int device_id) {
  std::unique_ptr<BaseBackend> new_backend =
      utils::make_unique<PaddleBackend>();
  auto casted_backend = dynamic_cast<PaddleBackend*>(new_backend.get());
  if (device_id > 0 && option_.use_gpu == true && device_id != option_.gpu_id) {
    auto clone_option = option_;
    clone_option.gpu_id = device_id;
    clone_option.external_stream_ = stream;
    casted_backend->InitFromPaddle(clone_option.model_file,
                                   clone_option.params_file, clone_option);
    FDWARNING << "The target device id:" << device_id
              << " is different from current device id:" << option_.gpu_id
              << ", cannot share memory with current engine." << std::endl;
    return new_backend;
  }
  casted_backend->inputs_desc_.assign(inputs_desc_.begin(), inputs_desc_.end());
  casted_backend->outputs_desc_.assign(outputs_desc_.begin(),
                                       outputs_desc_.end());
  casted_backend->predictor_ = std::move(predictor_->Clone(stream));
  return new_backend;
}

#ifdef ENABLE_TRT_BACKEND
void PaddleBackend::SetTRTDynamicShapeToConfig(
    const PaddleBackendOption& option) {
  std::map<std::string, std::vector<int>> max_shape;
  std::map<std::string, std::vector<int>> min_shape;
  std::map<std::string, std::vector<int>> opt_shape;
  GetDynamicShapeFromOption(option, &max_shape, &min_shape, &opt_shape);
  if (min_shape.size() > 0) {
    FDINFO << "Start setting trt dynamic shape." << std::endl;
    config_.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
    FDINFO << "Finish setting trt dynamic shape." << std::endl;
  }
}

void PaddleBackend::GetDynamicShapeFromOption(
    const PaddleBackendOption& option,
    std::map<std::string, std::vector<int>>* max_shape,
    std::map<std::string, std::vector<int>>* min_shape,
    std::map<std::string, std::vector<int>>* opt_shape) const {
  auto print_shape = [](const std::vector<int>& shape) -> std::string {
    std::ostringstream oss;
    oss << "[";

@@ -281,24 +303,35 @@ void PaddleBackend::GetDynamicShapeFromOption(const PaddleBackendOption& option,
  for (const auto& item : option.trt_option.min_shape) {
    auto max_iter = option.trt_option.max_shape.find(item.first);
    auto opt_iter = option.trt_option.opt_shape.find(item.first);
    FDASSERT(max_iter != option.trt_option.max_shape.end(),
             "Cannot find %s in TrtBackendOption::min_shape.",
             item.first.c_str());
    FDASSERT(opt_iter != option.trt_option.opt_shape.end(),
             "Cannot find %s in TrtBackendOption::opt_shape.",
             item.first.c_str());
    (*max_shape)[item.first].assign(max_iter->second.begin(),
                                    max_iter->second.end());
    (*opt_shape)[item.first].assign(opt_iter->second.begin(),
                                    opt_iter->second.end());
    (*min_shape)[item.first].assign(item.second.begin(), item.second.end());
    FDINFO << item.first
           << ": the max shape = " << print_shape(max_iter->second)
           << ", the min shape = " << print_shape(item.second)
           << ", the opt shape = " << print_shape(opt_iter->second)
           << std::endl;
  }
}

void PaddleBackend::CollectShapeRun(
    paddle_infer::Predictor* predictor,
    const std::map<std::string, std::vector<int>>& shape) const {
  auto input_names = predictor->GetInputNames();
  auto input_type = predictor->GetInputTypes();
  for (auto name : input_names) {
    FDASSERT(shape.find(name) != shape.end() &&
                 input_type.find(name) != input_type.end(),
             "Paddle Input name [%s] is not one of the trt dynamic shape.",
             name.c_str());
    auto tensor = predictor->GetInputHandle(name);
    auto shape_value = shape.at(name);
    int shape_num = std::accumulate(shape_value.begin(), shape_value.end(), 1,

@@ -306,30 +339,30 @@ void PaddleBackend::CollectShapeRun(paddle_infer::Predictor* predictor,
    tensor->Reshape(shape_value);
    auto dtype = input_type[name];
    switch (dtype) {
      case paddle_infer::DataType::FLOAT32: {
        std::vector<float> input_data(shape_num, 1.0);
        tensor->CopyFromCpu(input_data.data());
        break;
      }
      case paddle_infer::DataType::INT32: {
        std::vector<int> input_data(shape_num, 1);
        tensor->CopyFromCpu(input_data.data());
        break;
      }
      case paddle_infer::DataType::INT64: {
        std::vector<int64_t> input_data(shape_num, 1);
        tensor->CopyFromCpu(input_data.data());
        break;
      }
      default: {
        FDASSERT(false, "Input data Paddle backend only supports "
                        "FP32/INT32/INT64 currently.");
        break;
      }
    }
  }
  predictor->Run();
}
#endif
}  // namespace fastdeploy

fastdeploy/backends/paddle/paddle_backend.h (26 changed lines, Executable file → Normal file)

@@ -23,8 +23,8 @@
#ifdef ENABLE_PADDLE_FRONTEND
#include "paddle2onnx/converter.h"
#endif
#include "fastdeploy/utils/unique_ptr.h"
#include "paddle_inference_api.h"  // NOLINT

#ifdef ENABLE_TRT_BACKEND
#include "fastdeploy/backends/tensorrt/trt_backend.h"

@@ -60,6 +60,7 @@ struct PaddleBackendOption {
#ifdef ENABLE_TRT_BACKEND
  TrtBackendOption trt_option;
  bool collect_shape = false;
  std::vector<std::string> trt_disabled_ops_{};
#endif

#ifdef WITH_IPU

@@ -91,8 +92,7 @@ void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);
// if copy_to_fd is true, copy memory data to FDTensor
/// else share memory to FDTensor
void PaddleTensorToFDTensor(std::unique_ptr<paddle_infer::Tensor>& tensor,
                            FDTensor* fd_tensor, bool copy_to_fd);

// Convert data type from paddle inference to fastdeploy
FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype);

@@ -106,20 +106,18 @@ class PaddleBackend : public BaseBackend {
  virtual ~PaddleBackend() = default;
  void BuildOption(const PaddleBackendOption& option);

  bool
  InitFromPaddle(const std::string& model_file, const std::string& params_file,
                 const PaddleBackendOption& option = PaddleBackendOption());

  bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs,
             bool copy_to_fd = true) override;

  int NumInputs() const override { return inputs_desc_.size(); }

  int NumOutputs() const override { return outputs_desc_.size(); }

  std::unique_ptr<BaseBackend> Clone(void* stream = nullptr,
                                     int device_id = -1) override;

  TensorInfo GetInputInfo(int index) override;

@@ -129,9 +127,11 @@ class PaddleBackend : public BaseBackend {
 private:
#ifdef ENABLE_TRT_BACKEND
  void
  CollectShapeRun(paddle_infer::Predictor* predictor,
                  const std::map<std::string, std::vector<int>>& shape) const;
  void GetDynamicShapeFromOption(
      const PaddleBackendOption& option,
      std::map<std::string, std::vector<int>>* max_shape,
      std::map<std::string, std::vector<int>>* min_shape,
      std::map<std::string, std::vector<int>>* opt_shape) const;


@@ -35,7 +35,8 @@ void BindRuntime(pybind11::module& m) {
      .def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN)
      .def("set_openvino_device", &RuntimeOption::SetOpenVINODevice)
      .def("set_openvino_shape_info", &RuntimeOption::SetOpenVINOShapeInfo)
      .def("set_openvino_cpu_operators",
           &RuntimeOption::SetOpenVINOCpuOperators)
      .def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo)
      .def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo)
      .def("set_paddle_mkldnn_cache_size",

@@ -52,10 +53,15 @@ void BindRuntime(pybind11::module& m) {
      .def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile)
      .def("enable_pinned_memory", &RuntimeOption::EnablePinnedMemory)
      .def("disable_pinned_memory", &RuntimeOption::DisablePinnedMemory)
      .def("enable_paddle_trt_collect_shape",
           &RuntimeOption::EnablePaddleTrtCollectShape)
      .def("disable_paddle_trt_collect_shape",
           &RuntimeOption::DisablePaddleTrtCollectShape)
      .def("use_ipu", &RuntimeOption::UseIpu)
      .def("set_ipu_config", &RuntimeOption::SetIpuConfig)
      .def("delete_paddle_backend_pass",
           &RuntimeOption::DeletePaddleBackendPass)
      .def("disable_paddle_trt_ops", &RuntimeOption::DisablePaddleTrtOPs)
      .def_readwrite("model_file", &RuntimeOption::model_file)
      .def_readwrite("params_file", &RuntimeOption::params_file)
      .def_readwrite("model_format", &RuntimeOption::model_format)

@@ -117,9 +123,9 @@ void BindRuntime(pybind11::module& m) {
              auto dtype =
                  NumpyDataTypeToFDDataType(warm_datas[i][j].dtype());
              std::vector<int64_t> data_shape;
              data_shape.insert(data_shape.begin(), warm_datas[i][j].shape(),
                                warm_datas[i][j].shape() +
                                    warm_datas[i][j].ndim());
              warm_tensors[i][j].Resize(data_shape, dtype);
              memcpy(warm_tensors[i][j].MutableData(),
                     warm_datas[i][j].mutable_data(),

@@ -160,36 +166,39 @@ void BindRuntime(pybind11::module& m) {
             }
             return results;
           })
      .def("infer",
           [](Runtime& self, std::map<std::string, FDTensor>& data) {
             std::vector<FDTensor> inputs;
             inputs.reserve(data.size());
             for (auto iter = data.begin(); iter != data.end(); ++iter) {
               FDTensor tensor;
               tensor.SetExternalData(iter->second.Shape(),
                                      iter->second.Dtype(), iter->second.Data(),
                                      iter->second.device);
               tensor.name = iter->first;
               inputs.push_back(tensor);
             }
             std::vector<FDTensor> outputs;
             if (!self.Infer(inputs, &outputs)) {
               throw std::runtime_error("Failed to inference with Runtime.");
             }
             return outputs;
           })
      .def("infer",
           [](Runtime& self, std::vector<FDTensor>& inputs) {
             std::vector<FDTensor> outputs;
             return self.Infer(inputs, &outputs);
           })
      .def("bind_input_tensor", &Runtime::BindInputTensor)
      .def("infer", [](Runtime& self) { self.Infer(); })
      .def("get_output_tensor",
           [](Runtime& self, const std::string& name) {
             FDTensor* output = self.GetOutputTensor(name);
             if (output == nullptr) {
               return pybind11::cast(nullptr);
             }
             return pybind11::cast(*output);
           })
      .def("num_inputs", &Runtime::NumInputs)
      .def("num_outputs", &Runtime::NumOutputs)
      .def("get_input_info", &Runtime::GetInputInfo)

fastdeploy/runtime.cc (137 changed lines, Executable file → Normal file)

@@ -94,7 +94,7 @@ std::string Str(const Backend& b) {
    return "Backend::POROS";
  } else if (b == Backend::RKNPU2) {
    return "Backend::RKNPU2";
  } else if (b == Backend::OPENVINO) {
    return "Backend::OPENVINO";
  } else if (b == Backend::LITE) {
    return "Backend::PDLITE";

@@ -113,7 +113,7 @@ std::ostream& operator<<(std::ostream& out, const Backend& backend) {
    out << "Backend::OPENVINO";
  } else if (backend == Backend::RKNPU2) {
    out << "Backend::RKNPU2";
  } else if (backend == Backend::POROS) {
    out << "Backend::POROS";
  } else if (backend == Backend::LITE) {
    out << "Backend::PDLITE";

@@ -152,15 +152,17 @@ bool CheckModelFormat(const std::string& model_file,
  } else if (model_format == ModelFormat::TORCHSCRIPT) {
    if (model_file.size() < 3 ||
        model_file.substr(model_file.size() - 3, 3) != ".pt") {
      FDERROR
          << "With model format of ModelFormat::TORCHSCRIPT, the model file "
             "should ends with `.pt`, but now it's "
          << model_file << std::endl;
      return false;
    }
  } else {
    FDERROR
        << "Only support model format with frontend ModelFormat::PADDLE / "
           "ModelFormat::ONNX / ModelFormat::RKNN / ModelFormat::TORCHSCRIPT."
        << std::endl;
    return false;
  }
  return true;

@@ -205,9 +207,9 @@ void RuntimeOption::SetModelPath(const std::string& model_path,
    model_file = model_path;
    model_format = ModelFormat::TORCHSCRIPT;
  } else {
    FDASSERT(false,
             "The model format only can be "
             "ModelFormat::PADDLE/ModelFormat::ONNX/ModelFormat::TORCHSCRIPT.");
  }
}

@@ -317,13 +319,18 @@ void RuntimeOption::EnablePaddleLogInfo() { pd_enable_log_info = true; }
void RuntimeOption::DisablePaddleLogInfo() { pd_enable_log_info = false; }

void RuntimeOption::EnablePaddleToTrt() {
  FDASSERT(backend == Backend::TRT,
           "Should call UseTrtBackend() before call EnablePaddleToTrt().");
#ifdef ENABLE_PADDLE_BACKEND
  FDINFO << "While using TrtBackend with EnablePaddleToTrt, FastDeploy will "
            "change to use Paddle Inference Backend."
         << std::endl;
  backend = Backend::PDINFER;
  pd_enable_trt = true;
#else
  FDASSERT(false, "While using TrtBackend with EnablePaddleToTrt, require the "
                  "FastDeploy is compiled with Paddle Inference Backend, "
                  "please rebuild your FastDeploy.");
#endif
}

@@ -336,20 +343,12 @@ void RuntimeOption::SetOpenVINODevice(const std::string& name) {
  openvino_device = name;
}

void RuntimeOption::EnableLiteFP16() { lite_enable_fp16 = true; }

void RuntimeOption::DisableLiteFP16() { lite_enable_fp16 = false; }

void RuntimeOption::EnableLiteInt8() { lite_enable_int8 = true; }

void RuntimeOption::DisableLiteInt8() { lite_enable_int8 = false; }

void RuntimeOption::SetLitePowerMode(LitePowerMode mode) {
  lite_power_mode = mode;
}

@@ -361,7 +360,8 @@ void RuntimeOption::SetLiteOptimizedModelDir(
void RuntimeOption::SetLiteSubgraphPartitionPath(
    const std::string& nnadapter_subgraph_partition_config_path) {
  lite_nnadapter_subgraph_partition_config_path =
      nnadapter_subgraph_partition_config_path;
}

void RuntimeOption::SetTrtInputShape(const std::string& input_name,

@@ -387,8 +387,8 @@ void RuntimeOption::SetTrtInputShape(const std::string& input_name,
void RuntimeOption::SetTrtMaxWorkspaceSize(size_t max_workspace_size) {
  trt_max_workspace_size = max_workspace_size;
}
void RuntimeOption::SetTrtMaxBatchSize(size_t max_batch_size) {
  trt_max_batch_size = max_batch_size;
}

void RuntimeOption::EnableTrtFP16() { trt_enable_fp16 = true; }

@@ -422,27 +422,27 @@ bool Runtime::Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors,
  poros_option.enable_fp16 = option.trt_enable_fp16;
  poros_option.max_batch_size = option.trt_max_batch_size;
  poros_option.max_workspace_size = option.trt_max_workspace_size;
  FDASSERT(
      option.model_format == ModelFormat::TORCHSCRIPT,
      "PorosBackend only support model format of ModelFormat::TORCHSCRIPT.");
  backend_ = utils::make_unique<PorosBackend>();
  auto casted_backend = dynamic_cast<PorosBackend*>(backend_.get());
  FDASSERT(
      casted_backend->Compile(option.model_file, prewarm_tensors, poros_option),
      "Load model from Torchscript failed while initliazing PorosBackend.");
#else
  FDASSERT(false, "PorosBackend is not available, please compiled with "
                  "ENABLE_POROS_BACKEND=ON.");
#endif
  return true;
}

void RuntimeOption::EnablePaddleTrtCollectShape() { pd_collect_shape = true; }

void RuntimeOption::DisablePaddleTrtCollectShape() { pd_collect_shape = false; }

void RuntimeOption::DisablePaddleTrtOPs(const std::vector<std::string>& ops) {
  trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
}

void RuntimeOption::UseIpu(int device_num, int micro_batch_size,
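Because the new RuntimeOption::DisablePaddleTrtOPs appends to trt_disabled_ops_ rather than replacing it, repeated calls accumulate. A short Python sketch (op names illustrative):

option.disable_paddle_trt_ops(["pool2d"])
option.disable_paddle_trt_ops(["elementwise_add", "concat"])
# All three ops are now in trt_disabled_ops_ and will be handed to
# config_.Exp_DisableTensorRtOPs() when the Paddle backend is built.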
@@ -519,9 +519,9 @@ bool Runtime::Init(const RuntimeOption& _option) {
  } else if (option.backend == Backend::POROS) {
    FDASSERT(option.device == Device::CPU || option.device == Device::GPU,
             "Backend::POROS only supports Device::CPU/Device::GPU.");
    FDASSERT(option.model_format == ModelFormat::TORCHSCRIPT,
             "Backend::POROS only supports model format of "
             "ModelFormat::TORCHSCRIPT.");
    FDINFO << "Runtime initialized with Backend::POROS in "
           << Str(option.device) << "." << std::endl;
    return true;

@@ -572,7 +572,7 @@ std::vector<TensorInfo> Runtime::GetOutputInfos() {
bool Runtime::Infer(std::vector<FDTensor>& input_tensors,
                    std::vector<FDTensor>* output_tensors) {
  for (auto& tensor : input_tensors) {
    FDASSERT(tensor.device_id < 0 || tensor.device_id == option.device_id,
             "Device id of input tensor(%d) and runtime(%d) are not same.",
             tensor.device_id, option.device_id);

@@ -589,17 +589,15 @@ void Runtime::BindInputTensor(const std::string& name, FDTensor& input) {
  for (auto& t : input_tensors_) {
    if (t.name == name) {
      is_exist = true;
      t.SetExternalData(input.shape, input.dtype, input.MutableData(),
                        input.device, input.device_id);
      break;
    }
  }
  if (!is_exist) {
    FDTensor new_tensor(name);
    new_tensor.SetExternalData(input.shape, input.dtype, input.MutableData(),
                               input.device, input.device_id);
    input_tensors_.emplace_back(std::move(new_tensor));
  }
}

@@ -644,6 +642,7 @@ void Runtime::CreatePaddleBackend() {
    trt_option.serialize_file = option.trt_serialize_file;
    trt_option.enable_pinned_memory = option.enable_pinned_memory;
    pd_option.trt_option = trt_option;
    pd_option.trt_disabled_ops_ = option.trt_disabled_ops_;
  }
#endif
#ifdef WITH_IPU

@@ -669,9 +668,8 @@ void Runtime::CreatePaddleBackend() {
                                           pd_option),
           "Load model from Paddle failed while initliazing PaddleBackend.");
#else
  FDASSERT(false, "PaddleBackend is not available, please compiled with "
                  "ENABLE_PADDLE_BACKEND=ON.");
#endif
}

@@ -701,9 +699,8 @@ void Runtime::CreateOpenVINOBackend() {
        "Load model from Paddle failed while initliazing OrtBackend.");
  }
#else
  FDASSERT(false, "OpenVINOBackend is not available, please compiled with "
                  "ENABLE_OPENVINO_BACKEND=ON.");
#endif
}

@@ -733,9 +730,8 @@ void Runtime::CreateOrtBackend() {
        "Load model from Paddle failed while initliazing OrtBackend.");
  }
#else
  FDASSERT(false, "OrtBackend is not available, please compiled with "
                  "ENABLE_ORT_BACKEND=ON.");
#endif
}

@@ -772,9 +768,8 @@ void Runtime::CreateTrtBackend() {
        "Load model from Paddle failed while initliazing TrtBackend.");
  }
#else
  FDASSERT(false, "TrtBackend is not available, please compiled with "
                  "ENABLE_TRT_BACKEND=ON.");
#endif
}

@@ -786,7 +781,8 @@ void Runtime::CreateLiteBackend() {
  lite_option.enable_fp16 = option.lite_enable_fp16;
  lite_option.power_mode = static_cast<int>(option.lite_power_mode);
  lite_option.optimized_model_dir = option.lite_optimized_model_dir;
  lite_option.nnadapter_subgraph_partition_config_path =
      option.lite_nnadapter_subgraph_partition_config_path;
  lite_option.enable_timvx = option.enable_timvx;
  FDASSERT(option.model_format == ModelFormat::PADDLE,
           "LiteBackend only support model format of ModelFormat::PADDLE");

@@ -796,9 +792,8 @@ void Runtime::CreateLiteBackend() {
                                        lite_option),
           "Load model from nb file failed while initializing LiteBackend.");
#else
  FDASSERT(false, "LiteBackend is not available, please compiled with "
                  "ENABLE_LITE_BACKEND=ON.");
#endif
}

@@ -821,10 +816,8 @@ void Runtime::CreateRKNPU2Backend() {
Runtime* Runtime::Clone(void* stream, int device_id) {
  Runtime* runtime = new Runtime();
  if (option.backend != Backend::OPENVINO &&
      option.backend != Backend::PDINFER && option.backend != Backend::TRT) {
    runtime->Init(option);
    FDWARNING << "Only OpenVINO/Paddle Inference/TensorRT support \
                  clone engine to reduce CPU/GPU memory usage now. For "

@@ -834,8 +827,8 @@ Runtime* Runtime::Clone(void* stream, int device_id) {
              << std::endl;
    return runtime;
  }
  FDINFO << "Runtime Clone with Backend:: " << Str(option.backend) << " in "
         << Str(option.device) << "." << std::endl;
  runtime->option = option;
  runtime->backend_ = backend_->Clone(stream, device_id);
  return runtime;


@@ -24,9 +24,9 @@
#include <map>
#include <vector>

#include "backends/rknpu/rknpu2/rknpu2_config.h"
#include "fastdeploy/backends/backend.h"
#include "fastdeploy/utils/perf.h"

/** \brief All C++ FastDeploy APIs are defined inside this namespace
 *

@@ -35,14 +35,14 @@ namespace fastdeploy {
/*! Inference backend supported in FastDeploy */
enum Backend {
  UNKNOWN,   ///< Unknown inference backend
  ORT,       ///< ONNX Runtime, support Paddle/ONNX format model, CPU / Nvidia GPU
  TRT,       ///< TensorRT, support Paddle/ONNX format model, Nvidia GPU only
  PDINFER,   ///< Paddle Inference, support Paddle format model, CPU / Nvidia GPU
  POROS,     ///< Poros, support TorchScript format model, CPU / Nvidia GPU
  OPENVINO,  ///< Intel OpenVINO, support Paddle/ONNX format, CPU only
  LITE,      ///< Paddle Lite, support Paddle format model, ARM CPU only
  RKNPU2,    ///< RKNPU2, support RKNN format model, Rockchip NPU only
};

FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out,

@@ -94,10 +94,10 @@ struct FASTDEPLOY_DECL RuntimeOption {
  /// Use Nvidia GPU to inference
  void UseGpu(int gpu_id = 0);

  void UseRKNPU2(fastdeploy::rknpu2::CpuName rknpu2_name =
                     fastdeploy::rknpu2::CpuName::RK3588,
                 fastdeploy::rknpu2::CoreMask rknpu2_core =
                     fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_0);

  /// Use TimVX to inference
  void UseTimVX();

@@ -116,9 +116,7 @@ struct FASTDEPLOY_DECL RuntimeOption {
  void UsePaddleBackend();

  /// Wrapper function of UsePaddleBackend()
  void UsePaddleInferBackend() { return UsePaddleBackend(); }

  /// Set ONNX Runtime as inference backend, support CPU/GPU
  void UseOrtBackend();

@@ -136,9 +134,7 @@ struct FASTDEPLOY_DECL RuntimeOption {
  void UseLiteBackend();

  /// Wrapper function of UseLiteBackend()
  void UsePaddleLiteBackend() { return UseLiteBackend(); }

  /// Set mkldnn switch while using Paddle Inference as inference backend
  void SetPaddleMKLDNN(bool pd_mkldnn = true);

@@ -177,7 +173,7 @@ struct FASTDEPLOY_DECL RuntimeOption {
   * @brief Set shape info for OpenVINO
   */
  void SetOpenVINOShapeInfo(
      const std::map<std::string, std::vector<int64_t>>& shape_info) {
    ov_shape_infos = shape_info;
  }

@@ -197,7 +193,7 @@ struct FASTDEPLOY_DECL RuntimeOption {
   * @brief Set nnadapter subgraph partition path for Paddle Lite backend.
   */
  void SetLiteSubgraphPartitionPath(
      const std::string& nnadapter_subgraph_partition_config_path);

  /**
   * @brief enable half precision while use paddle lite backend

@@ -275,6 +271,11 @@ struct FASTDEPLOY_DECL RuntimeOption {
   */
  void DisablePaddleTrtCollectShape();

  /**
   * @brief Prevent ops running in paddle trt backend
   */
  void DisablePaddleTrtOPs(const std::vector<std::string>& ops);

  /*
   * @brief Set number of streams by the OpenVINO backends
   */

@@ -363,6 +364,8 @@ struct FASTDEPLOY_DECL RuntimeOption {
  bool trt_enable_int8 = false;
  size_t trt_max_batch_size = 32;
  size_t trt_max_workspace_size = 1 << 30;
  // ======Only for PaddleTrt Backend=======
  std::vector<std::string> trt_disabled_ops_{};

  // ======Only for Poros Backend=======
  bool is_dynamic = false;

@@ -378,12 +381,12 @@ struct FASTDEPLOY_DECL RuntimeOption {
  std::vector<std::string> ov_cpu_operators;

  // ======Only for RKNPU2 Backend=======
  fastdeploy::rknpu2::CpuName rknpu2_cpu_name_ =
      fastdeploy::rknpu2::CpuName::RK3588;
  fastdeploy::rknpu2::CoreMask rknpu2_core_mask_ =
      fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_AUTO;

  std::string model_file = "";   // Path of model file
  std::string params_file = "";  // Path of parameters file, can be empty
  // format of input model
  ModelFormat model_format = ModelFormat::AUTOREC;

@@ -450,8 +453,7 @@ struct FASTDEPLOY_DECL Runtime {
   * \param[in] stream CUDA Stream, defualt param is nullptr
   * \return new Runtime* by this clone
   */
  Runtime* Clone(void* stream = nullptr, int device_id = -1);

  RuntimeOption option;


@@ -435,6 +435,16 @@ class RuntimeOption:
        """
        return self._option.disable_paddle_trt_collect_shape()

    def delete_paddle_backend_pass(self, pass_name):
        """Delete pass by name in paddle backend
        """
        return self._option.delete_paddle_backend_pass(pass_name)

    def disable_paddle_trt_ops(self, ops):
        """Disable some ops in paddle trt backend
        """
        return self._option.disable_paddle_trt_ops(ops)

    def use_ipu(self,
                device_num=1,
                micro_batch_size=1,