Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-06 00:57:33 +08:00)
[Backend] Add DisablePaddleTrtOPs (#788)
* Add DisablePaddleTrtOPs
* Add delete_paddle_backend_pass and disable_paddle_trt_ops pybind bindings
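For orientation, a minimal usage sketch of the new option follows. It assumes the usual FastDeploy C++ flow; the model paths and op names below are placeholders, not taken from this commit.

```cpp
// Hypothetical usage sketch of the new RuntimeOption::DisablePaddleTrtOPs.
// Paths and op names are placeholders for illustration only.
#include "fastdeploy/runtime.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.SetModelPath("model.pdmodel", "model.pdiparams");
  option.UseGpu(0);
  option.UseTrtBackend();
  option.EnablePaddleToTrt();  // run TensorRT through the Paddle Inference backend
  // Keep these operators out of the TensorRT subgraphs; they stay on Paddle.
  option.DisablePaddleTrtOPs({"elementwise_add", "concat"});

  fastdeploy::Runtime runtime;
  if (!runtime.Init(option)) {
    return -1;
  }
  return 0;
}
```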
@@ -27,19 +27,29 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
}
if (option.enable_trt) {
#ifdef ENABLE_TRT_BACKEND
config_.Exp_DisableTensorRtOPs(option.trt_disabled_ops_);
auto precision = paddle_infer::PrecisionType::kFloat32;
if (option.trt_option.enable_fp16) {
precision = paddle_infer::PrecisionType::kHalf;
}
bool use_static = false;
if (option.trt_option.serialize_file != "") {
FDWARNING << "Detect that tensorrt cache file has been set to " << option.trt_option.serialize_file << ", but while enable paddle2trt, please notice that the cache file will save to the directory where paddle model saved." << std::endl;
FDWARNING
<< "Detect that tensorrt cache file has been set to "
<< option.trt_option.serialize_file
<< ", but while enable paddle2trt, please notice that the cache "
"file will save to the directory where paddle model saved."
<< std::endl;
use_static = true;
}
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size, option.trt_option.max_batch_size, 3, precision, use_static);
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size,
option.trt_option.max_batch_size, 3,
precision, use_static);
SetTRTDynamicShapeToConfig(option);
#else
FDWARNING << "The FastDeploy is not compiled with TensorRT backend, so will fallback to GPU with Paddle Inference Backend." << std::endl;
FDWARNING << "The FastDeploy is not compiled with TensorRT backend, so "
"will fallback to GPU with Paddle Inference Backend."
<< std::endl;
#endif
}
} else if (option.use_ipu) {
@@ -98,38 +108,47 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
if (!ReadBinaryFromFile(model_file, &contents)) {
return false;
}
auto reader =
paddle2onnx::PaddleReader(contents.c_str(), contents.size());
auto reader = paddle2onnx::PaddleReader(contents.c_str(), contents.size());

// If it's a quantized model, and use cpu with mkldnn, automaticaly switch to int8 mode
if (reader.is_quantize_model) {
if (option.use_gpu) {
FDWARNING << "The loaded model is a quantized model, while inference on GPU, please use TensorRT backend to get better performance." << std::endl;
FDWARNING << "The loaded model is a quantized model, while inference on "
"GPU, please use TensorRT backend to get better performance."
<< std::endl;
if (option.enable_trt) {
#ifdef ENABLE_TRT_BACKEND
bool use_static = false;
if (option.trt_option.serialize_file != "") {
FDWARNING << "Detect that tensorrt cache file has been set to " << option.trt_option.serialize_file << ", but while enable paddle2trt, please notice that the cache file will save to the directory where paddle model saved." << std::endl;
FDWARNING
<< "Detect that tensorrt cache file has been set to "
<< option.trt_option.serialize_file
<< ", but while enable paddle2trt, please notice that the cache "
"file will save to the directory where paddle model saved."
<< std::endl;
use_static = true;
}
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size, option.trt_option.max_batch_size, 3, paddle_infer::PrecisionType::kInt8, use_static, false);
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size,
option.trt_option.max_batch_size, 3,
paddle_infer::PrecisionType::kInt8,
use_static, false);
SetTRTDynamicShapeToConfig(option);

#endif
}
}
if (option.enable_mkldnn) {
config_.EnableMkldnnInt8();
} else {
FDWARNING << "The loaded model is a quantized model, while inference on CPU, please enable MKLDNN to get better performance." << std::endl;
FDWARNING << "The loaded model is a quantized model, while inference on "
"CPU, please enable MKLDNN to get better performance."
<< std::endl;
}
}
inputs_desc_.resize(reader.num_inputs);
for (int i = 0; i < reader.num_inputs; ++i) {
std::string name(reader.inputs[i].name);
std::vector<int64_t> shape(
reader.inputs[i].shape,
std::vector<int64_t> shape(reader.inputs[i].shape,
reader.inputs[i].shape + reader.inputs[i].rank);
inputs_desc_[i].name = name;
inputs_desc_[i].shape.assign(shape.begin(), shape.end());
@@ -138,7 +157,9 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
outputs_desc_.resize(reader.num_outputs);
for (int i = 0; i < reader.num_outputs; ++i) {
std::string name(reader.outputs[i].name);
std::vector<int64_t> shape(reader.outputs[i].shape, reader.outputs[i].shape + reader.outputs[i].rank);
std::vector<int64_t> shape(reader.outputs[i].shape,
reader.outputs[i].shape +
reader.outputs[i].rank);
outputs_desc_[i].name = name;
outputs_desc_[i].shape.assign(shape.begin(), shape.end());
outputs_desc_[i].dtype = ReaderDataTypeToFD(reader.outputs[i].dtype);
@@ -147,7 +168,8 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
if (option.collect_shape) {
// Set the shape info file.
auto curr_model_dir = GetDirFromPath(model_file);
std::string shape_range_info = PathJoin(curr_model_dir, "shape_range_info.pbtxt");
std::string shape_range_info =
PathJoin(curr_model_dir, "shape_range_info.pbtxt");
if (!CheckFileExists(shape_range_info)) {
FDINFO << "Start generating shape range info file." << std::endl;
paddle_infer::Config analysis_config;
@@ -164,7 +186,8 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
CollectShapeRun(predictor_tmp.get(), opt_shape);
FDINFO << "Finish generating shape range info file." << std::endl;
}
FDINFO << "Start loading shape range info file "<< shape_range_info << " to set TensorRT dynamic shape." << std::endl;
FDINFO << "Start loading shape range info file " << shape_range_info
<< " to set TensorRT dynamic shape." << std::endl;
config_.EnableTunedTensorRtDynamicShape(shape_range_info, false);
}
#endif
@@ -194,8 +217,7 @@ std::vector<TensorInfo> PaddleBackend::GetOutputInfos() {
}

bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs,
bool copy_to_fd) {
std::vector<FDTensor>* outputs, bool copy_to_fd) {
if (inputs.size() != inputs_desc_.size()) {
FDERROR << "[PaddleBackend] Size of inputs(" << inputs.size()
<< ") should keep same with the inputs of this model("
@@ -226,43 +248,43 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
}

std::unique_ptr<BaseBackend> PaddleBackend::Clone(void* stream, int device_id) {
std::unique_ptr<BaseBackend> new_backend = utils::make_unique<PaddleBackend>();
std::unique_ptr<BaseBackend> new_backend =
utils::make_unique<PaddleBackend>();
auto casted_backend = dynamic_cast<PaddleBackend*>(new_backend.get());
if (device_id > 0 && option_.use_gpu == true && device_id != option_.gpu_id) {
auto clone_option = option_;
clone_option.gpu_id = device_id;
clone_option.external_stream_ = stream;
casted_backend->InitFromPaddle(clone_option.model_file,
clone_option.params_file,
clone_option);
FDWARNING << "The target device id:"
<< device_id
<< " is different from current device id:"
<< option_.gpu_id
<< ", cannot share memory with current engine."
<< std::endl;
clone_option.params_file, clone_option);
FDWARNING << "The target device id:" << device_id
<< " is different from current device id:" << option_.gpu_id
<< ", cannot share memory with current engine." << std::endl;
return new_backend;
}
casted_backend->inputs_desc_.assign(inputs_desc_.begin(), inputs_desc_.end());
casted_backend->outputs_desc_.assign(outputs_desc_.begin(), outputs_desc_.end());
casted_backend->outputs_desc_.assign(outputs_desc_.begin(),
outputs_desc_.end());
casted_backend->predictor_ = std::move(predictor_->Clone(stream));
return new_backend;
}
#ifdef ENABLE_TRT_BACKEND
void PaddleBackend::SetTRTDynamicShapeToConfig(const PaddleBackendOption& option) {
void PaddleBackend::SetTRTDynamicShapeToConfig(
const PaddleBackendOption& option) {
std::map<std::string, std::vector<int>> max_shape;
std::map<std::string, std::vector<int>> min_shape;
std::map<std::string, std::vector<int>> opt_shape;
GetDynamicShapeFromOption(option, &max_shape, &min_shape, &opt_shape);
FDINFO << "Start setting trt dynamic shape." << std::endl;
if (min_shape.size() > 0) {
FDINFO << "Start setting trt dynamic shape." << std::endl;
config_.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
}
FDINFO << "Finish setting trt dynamic shape." << std::endl;
}
}

void PaddleBackend::GetDynamicShapeFromOption(const PaddleBackendOption& option,
void PaddleBackend::GetDynamicShapeFromOption(
const PaddleBackendOption& option,
std::map<std::string, std::vector<int>>* max_shape,
std::map<std::string, std::vector<int>>* min_shape,
std::map<std::string, std::vector<int>>* opt_shape) const {
@@ -281,24 +303,35 @@ void PaddleBackend::GetDynamicShapeFromOption(const PaddleBackendOption& option,
for (const auto& item : option.trt_option.min_shape) {
auto max_iter = option.trt_option.max_shape.find(item.first);
auto opt_iter = option.trt_option.opt_shape.find(item.first);
FDASSERT(max_iter != option.trt_option.max_shape.end(), "Cannot find %s in TrtBackendOption::min_shape.", item.first.c_str());
FDASSERT(opt_iter != option.trt_option.opt_shape.end(), "Cannot find %s in TrtBackendOption::opt_shape.", item.first.c_str());
(*max_shape)[item.first].assign(max_iter->second.begin(), max_iter->second.end());
(*opt_shape)[item.first].assign(opt_iter->second.begin(), opt_iter->second.end());
FDASSERT(max_iter != option.trt_option.max_shape.end(),
"Cannot find %s in TrtBackendOption::min_shape.",
item.first.c_str());
FDASSERT(opt_iter != option.trt_option.opt_shape.end(),
"Cannot find %s in TrtBackendOption::opt_shape.",
item.first.c_str());
(*max_shape)[item.first].assign(max_iter->second.begin(),
max_iter->second.end());
(*opt_shape)[item.first].assign(opt_iter->second.begin(),
opt_iter->second.end());
(*min_shape)[item.first].assign(item.second.begin(), item.second.end());
FDINFO << item.first << ": the max shape = " << print_shape(max_iter->second)
FDINFO << item.first
<< ": the max shape = " << print_shape(max_iter->second)
<< ", the min shape = " << print_shape(item.second)
<< ", the opt shape = " << print_shape(opt_iter->second) << std::endl;
<< ", the opt shape = " << print_shape(opt_iter->second)
<< std::endl;
}
}
void PaddleBackend::CollectShapeRun(paddle_infer::Predictor* predictor,
void PaddleBackend::CollectShapeRun(
paddle_infer::Predictor* predictor,
const std::map<std::string, std::vector<int>>& shape) const {
auto input_names = predictor->GetInputNames();
auto input_type = predictor->GetInputTypes();
for (auto name : input_names) {
FDASSERT(shape.find(name) != shape.end() && input_type.find(name) != input_type.end(),
"Paddle Input name [%s] is not one of the trt dynamic shape.", name.c_str());
FDASSERT(shape.find(name) != shape.end() &&
input_type.find(name) != input_type.end(),
"Paddle Input name [%s] is not one of the trt dynamic shape.",
name.c_str());
auto tensor = predictor->GetInputHandle(name);
auto shape_value = shape.at(name);
int shape_num = std::accumulate(shape_value.begin(), shape_value.end(), 1,
@@ -322,7 +355,8 @@ void PaddleBackend::CollectShapeRun(paddle_infer::Predictor* predictor,
break;
}
default: {
FDASSERT(false, "Input data Paddle backend only supports FP32/INT32/INT64 currently.");
FDASSERT(false, "Input data Paddle backend only supports "
"FP32/INT32/INT64 currently.");
break;
}
}
@@ -331,5 +365,4 @@ void PaddleBackend::CollectShapeRun(paddle_infer::Predictor* predictor,
}
#endif

} // namespace fastdeploy
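Read together with the runtime.cc and runtime.h hunks further down, the new op list travels from RuntimeOption through PaddleBackendOption into Paddle Inference. The condensed sketch below retraces that flow; the struct and function names are simplified stand-ins, only the member names match the diff.

```cpp
// Condensed sketch of the data flow added in this commit; the real classes
// carry many more members than shown here.
#include <string>
#include <vector>

struct RuntimeOptionSketch {
  std::vector<std::string> trt_disabled_ops_;
  // Mirrors RuntimeOption::DisablePaddleTrtOPs: appends, so repeated calls accumulate.
  void DisablePaddleTrtOPs(const std::vector<std::string>& ops) {
    trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
  }
};

struct PaddleBackendOptionSketch {
  std::vector<std::string> trt_disabled_ops_;  // new field, guarded by ENABLE_TRT_BACKEND
};

// Runtime::CreatePaddleBackend() copies the list across ...
PaddleBackendOptionSketch ToBackendOption(const RuntimeOptionSketch& option) {
  PaddleBackendOptionSketch pd_option;
  pd_option.trt_disabled_ops_ = option.trt_disabled_ops_;
  return pd_option;
}

// ... and PaddleBackend::BuildOption() finally hands it to Paddle Inference:
//   config_.Exp_DisableTensorRtOPs(option.trt_disabled_ops_);
```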
fastdeploy/backends/paddle/paddle_backend.h (20) Executable file → Normal file
@@ -23,8 +23,8 @@
#ifdef ENABLE_PADDLE_FRONTEND
#include "paddle2onnx/converter.h"
#endif
#include "paddle_inference_api.h" // NOLINT
#include "fastdeploy/utils/unique_ptr.h"
#include "paddle_inference_api.h" // NOLINT

#ifdef ENABLE_TRT_BACKEND
#include "fastdeploy/backends/tensorrt/trt_backend.h"
@@ -60,6 +60,7 @@ struct PaddleBackendOption {
#ifdef ENABLE_TRT_BACKEND
TrtBackendOption trt_option;
bool collect_shape = false;
std::vector<std::string> trt_disabled_ops_{};
#endif

#ifdef WITH_IPU
@@ -91,8 +92,7 @@ void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);
// if copy_to_fd is true, copy memory data to FDTensor
/// else share memory to FDTensor
void PaddleTensorToFDTensor(std::unique_ptr<paddle_infer::Tensor>& tensor,
FDTensor* fd_tensor,
bool copy_to_fd);
FDTensor* fd_tensor, bool copy_to_fd);

// Convert data type from paddle inference to fastdeploy
FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype);
@@ -106,15 +106,13 @@ class PaddleBackend : public BaseBackend {
virtual ~PaddleBackend() = default;
void BuildOption(const PaddleBackendOption& option);

bool InitFromPaddle(
const std::string& model_file, const std::string& params_file,
bool
InitFromPaddle(const std::string& model_file, const std::string& params_file,
const PaddleBackendOption& option = PaddleBackendOption());

bool Infer(std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs,
bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs,
bool copy_to_fd = true) override;

int NumInputs() const override { return inputs_desc_.size(); }

int NumOutputs() const override { return outputs_desc_.size(); }
@@ -129,9 +127,11 @@ class PaddleBackend : public BaseBackend {

private:
#ifdef ENABLE_TRT_BACKEND
void CollectShapeRun(paddle_infer::Predictor* predictor,
void
CollectShapeRun(paddle_infer::Predictor* predictor,
const std::map<std::string, std::vector<int>>& shape) const;
void GetDynamicShapeFromOption(const PaddleBackendOption& option,
void GetDynamicShapeFromOption(
const PaddleBackendOption& option,
std::map<std::string, std::vector<int>>* max_shape,
std::map<std::string, std::vector<int>>* min_shape,
std::map<std::string, std::vector<int>>* opt_shape) const;
@@ -35,7 +35,8 @@ void BindRuntime(pybind11::module& m) {
.def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN)
.def("set_openvino_device", &RuntimeOption::SetOpenVINODevice)
.def("set_openvino_shape_info", &RuntimeOption::SetOpenVINOShapeInfo)
.def("set_openvino_cpu_operators", &RuntimeOption::SetOpenVINOCpuOperators)
.def("set_openvino_cpu_operators",
&RuntimeOption::SetOpenVINOCpuOperators)
.def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo)
.def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo)
.def("set_paddle_mkldnn_cache_size",
@@ -52,10 +53,15 @@ void BindRuntime(pybind11::module& m) {
.def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile)
.def("enable_pinned_memory", &RuntimeOption::EnablePinnedMemory)
.def("disable_pinned_memory", &RuntimeOption::DisablePinnedMemory)
.def("enable_paddle_trt_collect_shape", &RuntimeOption::EnablePaddleTrtCollectShape)
.def("disable_paddle_trt_collect_shape", &RuntimeOption::DisablePaddleTrtCollectShape)
.def("enable_paddle_trt_collect_shape",
&RuntimeOption::EnablePaddleTrtCollectShape)
.def("disable_paddle_trt_collect_shape",
&RuntimeOption::DisablePaddleTrtCollectShape)
.def("use_ipu", &RuntimeOption::UseIpu)
.def("set_ipu_config", &RuntimeOption::SetIpuConfig)
.def("delete_paddle_backend_pass",
&RuntimeOption::DeletePaddleBackendPass)
.def("disable_paddle_trt_ops", &RuntimeOption::DisablePaddleTrtOPs)
.def_readwrite("model_file", &RuntimeOption::model_file)
.def_readwrite("params_file", &RuntimeOption::params_file)
.def_readwrite("model_format", &RuntimeOption::model_format)
@@ -117,9 +123,9 @@ void BindRuntime(pybind11::module& m) {
auto dtype =
NumpyDataTypeToFDDataType(warm_datas[i][j].dtype());
std::vector<int64_t> data_shape;
data_shape.insert(
data_shape.begin(), warm_datas[i][j].shape(),
warm_datas[i][j].shape() + warm_datas[i][j].ndim());
data_shape.insert(data_shape.begin(), warm_datas[i][j].shape(),
warm_datas[i][j].shape() +
warm_datas[i][j].ndim());
warm_tensors[i][j].Resize(data_shape, dtype);
memcpy(warm_tensors[i][j].MutableData(),
warm_datas[i][j].mutable_data(),
@@ -160,12 +166,15 @@ void BindRuntime(pybind11::module& m) {
}
return results;
})
.def("infer", [](Runtime& self, std::map<std::string, FDTensor>& data) {
.def("infer",
[](Runtime& self, std::map<std::string, FDTensor>& data) {
std::vector<FDTensor> inputs;
inputs.reserve(data.size());
for (auto iter = data.begin(); iter != data.end(); ++iter) {
FDTensor tensor;
tensor.SetExternalData(iter->second.Shape(), iter->second.Dtype(), iter->second.Data(), iter->second.device);
tensor.SetExternalData(iter->second.Shape(),
iter->second.Dtype(), iter->second.Data(),
iter->second.device);
tensor.name = iter->first;
inputs.push_back(tensor);
}
@@ -175,15 +184,15 @@ void BindRuntime(pybind11::module& m) {
}
return outputs;
})
.def("infer", [](Runtime& self, std::vector<FDTensor>& inputs) {
.def("infer",
[](Runtime& self, std::vector<FDTensor>& inputs) {
std::vector<FDTensor> outputs;
return self.Infer(inputs, &outputs);
})
.def("bind_input_tensor", &Runtime::BindInputTensor)
.def("infer", [](Runtime& self) {
self.Infer();
})
.def("get_output_tensor", [](Runtime& self, const std::string& name) {
.def("infer", [](Runtime& self) { self.Infer(); })
.def("get_output_tensor",
[](Runtime& self, const std::string& name) {
FDTensor* output = self.GetOutputTensor(name);
if (output == nullptr) {
return pybind11::cast(nullptr);
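The bindings above lean on pybind11's container conversions: a std::vector<std::string> argument such as the one taken by DisablePaddleTrtOPs typically converts to and from a Python list via pybind11's STL support, which is what makes disable_paddle_trt_ops(["op_a", "op_b"]) callable from Python. A self-contained sketch of that pattern, with made-up module and struct names:

```cpp
// Minimal pybind11 sketch of the binding style used in BindRuntime above.
// "option_sketch" and OptionSketch are illustrative names, not FastDeploy code.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>  // list <-> std::vector<std::string> conversion
#include <string>
#include <vector>

struct OptionSketch {
  std::vector<std::string> trt_disabled_ops_;
  void DisablePaddleTrtOPs(const std::vector<std::string>& ops) {
    trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
  }
};

PYBIND11_MODULE(option_sketch, m) {
  pybind11::class_<OptionSketch>(m, "RuntimeOption")
      .def(pybind11::init<>())
      .def("disable_paddle_trt_ops", &OptionSketch::DisablePaddleTrtOPs);
}
```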
fastdeploy/runtime.cc (103) Executable file → Normal file
@@ -152,13 +152,15 @@ bool CheckModelFormat(const std::string& model_file,
} else if (model_format == ModelFormat::TORCHSCRIPT) {
if (model_file.size() < 3 ||
model_file.substr(model_file.size() - 3, 3) != ".pt") {
FDERROR << "With model format of ModelFormat::TORCHSCRIPT, the model file "
FDERROR
<< "With model format of ModelFormat::TORCHSCRIPT, the model file "
"should ends with `.pt`, but now it's "
<< model_file << std::endl;
return false;
}
} else {
FDERROR << "Only support model format with frontend ModelFormat::PADDLE / "
FDERROR
<< "Only support model format with frontend ModelFormat::PADDLE / "
"ModelFormat::ONNX / ModelFormat::RKNN / ModelFormat::TORCHSCRIPT."
<< std::endl;
return false;
@@ -205,9 +207,9 @@ void RuntimeOption::SetModelPath(const std::string& model_path,
model_file = model_path;
model_format = ModelFormat::TORCHSCRIPT;
} else {
FDASSERT(
false,
"The model format only can be ModelFormat::PADDLE/ModelFormat::ONNX/ModelFormat::TORCHSCRIPT.");
FDASSERT(false,
"The model format only can be "
"ModelFormat::PADDLE/ModelFormat::ONNX/ModelFormat::TORCHSCRIPT.");
}
}

@@ -317,13 +319,18 @@ void RuntimeOption::EnablePaddleLogInfo() { pd_enable_log_info = true; }
void RuntimeOption::DisablePaddleLogInfo() { pd_enable_log_info = false; }

void RuntimeOption::EnablePaddleToTrt() {
FDASSERT(backend == Backend::TRT, "Should call UseTrtBackend() before call EnablePaddleToTrt().");
FDASSERT(backend == Backend::TRT,
"Should call UseTrtBackend() before call EnablePaddleToTrt().");
#ifdef ENABLE_PADDLE_BACKEND
FDINFO << "While using TrtBackend with EnablePaddleToTrt, FastDeploy will change to use Paddle Inference Backend." << std::endl;
FDINFO << "While using TrtBackend with EnablePaddleToTrt, FastDeploy will "
"change to use Paddle Inference Backend."
<< std::endl;
backend = Backend::PDINFER;
pd_enable_trt = true;
#else
FDASSERT(false, "While using TrtBackend with EnablePaddleToTrt, require the FastDeploy is compiled with Paddle Inference Backend, please rebuild your FastDeploy.");
FDASSERT(false, "While using TrtBackend with EnablePaddleToTrt, require the "
"FastDeploy is compiled with Paddle Inference Backend, "
"please rebuild your FastDeploy.");
#endif
}
@@ -336,20 +343,12 @@ void RuntimeOption::SetOpenVINODevice(const std::string& name) {
openvino_device = name;
}

void RuntimeOption::EnableLiteFP16() {
lite_enable_fp16 = true;
}
void RuntimeOption::EnableLiteFP16() { lite_enable_fp16 = true; }

void RuntimeOption::DisableLiteFP16() {
lite_enable_fp16 = false;
}
void RuntimeOption::EnableLiteInt8() {
lite_enable_int8 = true;
}
void RuntimeOption::DisableLiteFP16() { lite_enable_fp16 = false; }
void RuntimeOption::EnableLiteInt8() { lite_enable_int8 = true; }

void RuntimeOption::DisableLiteInt8() {
lite_enable_int8 = false;
}
void RuntimeOption::DisableLiteInt8() { lite_enable_int8 = false; }
void RuntimeOption::SetLitePowerMode(LitePowerMode mode) {
lite_power_mode = mode;
}
@@ -361,7 +360,8 @@ void RuntimeOption::SetLiteOptimizedModelDir(

void RuntimeOption::SetLiteSubgraphPartitionPath(
const std::string& nnadapter_subgraph_partition_config_path) {
lite_nnadapter_subgraph_partition_config_path = nnadapter_subgraph_partition_config_path;
lite_nnadapter_subgraph_partition_config_path =
nnadapter_subgraph_partition_config_path;
}

void RuntimeOption::SetTrtInputShape(const std::string& input_name,
@@ -422,7 +422,8 @@ bool Runtime::Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors,
poros_option.enable_fp16 = option.trt_enable_fp16;
poros_option.max_batch_size = option.trt_max_batch_size;
poros_option.max_workspace_size = option.trt_max_workspace_size;
FDASSERT(option.model_format == ModelFormat::TORCHSCRIPT,
FDASSERT(
option.model_format == ModelFormat::TORCHSCRIPT,
"PorosBackend only support model format of ModelFormat::TORCHSCRIPT.");
backend_ = utils::make_unique<PorosBackend>();
auto casted_backend = dynamic_cast<PorosBackend*>(backend_.get());
@@ -430,19 +431,18 @@ bool Runtime::Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors,
casted_backend->Compile(option.model_file, prewarm_tensors, poros_option),
"Load model from Torchscript failed while initliazing PorosBackend.");
#else
FDASSERT(false,
"PorosBackend is not available, please compiled with "
FDASSERT(false, "PorosBackend is not available, please compiled with "
"ENABLE_POROS_BACKEND=ON.");
#endif
return true;
}

void RuntimeOption::EnablePaddleTrtCollectShape() {
pd_collect_shape = true;
}
void RuntimeOption::EnablePaddleTrtCollectShape() { pd_collect_shape = true; }

void RuntimeOption::DisablePaddleTrtCollectShape() {
pd_collect_shape = false;
void RuntimeOption::DisablePaddleTrtCollectShape() { pd_collect_shape = false; }

void RuntimeOption::DisablePaddleTrtOPs(const std::vector<std::string>& ops) {
trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
}

void RuntimeOption::UseIpu(int device_num, int micro_batch_size,
@@ -519,9 +519,9 @@ bool Runtime::Init(const RuntimeOption& _option) {
} else if (option.backend == Backend::POROS) {
FDASSERT(option.device == Device::CPU || option.device == Device::GPU,
"Backend::POROS only supports Device::CPU/Device::GPU.");
FDASSERT(
option.model_format == ModelFormat::TORCHSCRIPT,
"Backend::POROS only supports model format of ModelFormat::TORCHSCRIPT.");
FDASSERT(option.model_format == ModelFormat::TORCHSCRIPT,
"Backend::POROS only supports model format of "
"ModelFormat::TORCHSCRIPT.");
FDINFO << "Runtime initialized with Backend::POROS in "
<< Str(option.device) << "." << std::endl;
return true;
@@ -589,17 +589,15 @@ void Runtime::BindInputTensor(const std::string& name, FDTensor& input) {
for (auto& t : input_tensors_) {
if (t.name == name) {
is_exist = true;
t.SetExternalData(input.shape, input.dtype,
input.MutableData(), input.device,
input.device_id);
t.SetExternalData(input.shape, input.dtype, input.MutableData(),
input.device, input.device_id);
break;
}
}
if (!is_exist) {
FDTensor new_tensor(name);
new_tensor.SetExternalData(input.shape, input.dtype,
input.MutableData(), input.device,
input.device_id);
new_tensor.SetExternalData(input.shape, input.dtype, input.MutableData(),
input.device, input.device_id);
input_tensors_.emplace_back(std::move(new_tensor));
}
}
@@ -644,6 +642,7 @@ void Runtime::CreatePaddleBackend() {
trt_option.serialize_file = option.trt_serialize_file;
trt_option.enable_pinned_memory = option.enable_pinned_memory;
pd_option.trt_option = trt_option;
pd_option.trt_disabled_ops_ = option.trt_disabled_ops_;
}
#endif
#ifdef WITH_IPU
@@ -669,8 +668,7 @@ void Runtime::CreatePaddleBackend() {
pd_option),
"Load model from Paddle failed while initliazing PaddleBackend.");
#else
FDASSERT(false,
"PaddleBackend is not available, please compiled with "
FDASSERT(false, "PaddleBackend is not available, please compiled with "
"ENABLE_PADDLE_BACKEND=ON.");
#endif
}
@@ -701,8 +699,7 @@ void Runtime::CreateOpenVINOBackend() {
"Load model from Paddle failed while initliazing OrtBackend.");
}
#else
FDASSERT(false,
"OpenVINOBackend is not available, please compiled with "
FDASSERT(false, "OpenVINOBackend is not available, please compiled with "
"ENABLE_OPENVINO_BACKEND=ON.");
#endif
}
@@ -733,8 +730,7 @@ void Runtime::CreateOrtBackend() {
"Load model from Paddle failed while initliazing OrtBackend.");
}
#else
FDASSERT(false,
"OrtBackend is not available, please compiled with "
FDASSERT(false, "OrtBackend is not available, please compiled with "
"ENABLE_ORT_BACKEND=ON.");
#endif
}
@@ -772,8 +768,7 @@ void Runtime::CreateTrtBackend() {
"Load model from Paddle failed while initliazing TrtBackend.");
}
#else
FDASSERT(false,
"TrtBackend is not available, please compiled with "
FDASSERT(false, "TrtBackend is not available, please compiled with "
"ENABLE_TRT_BACKEND=ON.");
#endif
}
@@ -786,7 +781,8 @@ void Runtime::CreateLiteBackend() {
lite_option.enable_fp16 = option.lite_enable_fp16;
lite_option.power_mode = static_cast<int>(option.lite_power_mode);
lite_option.optimized_model_dir = option.lite_optimized_model_dir;
lite_option.nnadapter_subgraph_partition_config_path = option.lite_nnadapter_subgraph_partition_config_path;
lite_option.nnadapter_subgraph_partition_config_path =
option.lite_nnadapter_subgraph_partition_config_path;
lite_option.enable_timvx = option.enable_timvx;
FDASSERT(option.model_format == ModelFormat::PADDLE,
"LiteBackend only support model format of ModelFormat::PADDLE");
@@ -796,8 +792,7 @@ void Runtime::CreateLiteBackend() {
lite_option),
"Load model from nb file failed while initializing LiteBackend.");
#else
FDASSERT(false,
"LiteBackend is not available, please compiled with "
FDASSERT(false, "LiteBackend is not available, please compiled with "
"ENABLE_LITE_BACKEND=ON.");
#endif
}
@@ -821,10 +816,8 @@ void Runtime::CreateRKNPU2Backend() {

Runtime* Runtime::Clone(void* stream, int device_id) {
Runtime* runtime = new Runtime();
if (option.backend != Backend::OPENVINO
&& option.backend != Backend::PDINFER
&& option.backend != Backend::TRT
) {
if (option.backend != Backend::OPENVINO &&
option.backend != Backend::PDINFER && option.backend != Backend::TRT) {
runtime->Init(option);
FDWARNING << "Only OpenVINO/Paddle Inference/TensorRT support \
clone engine to reduce CPU/GPU memory usage now. For "
@@ -834,8 +827,8 @@ Runtime* Runtime::Clone(void* stream, int device_id) {
<< std::endl;
return runtime;
}
FDINFO << "Runtime Clone with Backend:: " << Str(option.backend) << " in " << Str(option.device)
<< "." << std::endl;
FDINFO << "Runtime Clone with Backend:: " << Str(option.backend) << " in "
<< Str(option.device) << "." << std::endl;
runtime->option = option;
runtime->backend_ = backend_->Clone(stream, device_id);
return runtime;
@@ -24,9 +24,9 @@
#include <map>
#include <vector>

#include "backends/rknpu/rknpu2/rknpu2_config.h"
#include "fastdeploy/backends/backend.h"
#include "fastdeploy/utils/perf.h"
#include "backends/rknpu/rknpu2/rknpu2_config.h"

/** \brief All C++ FastDeploy APIs are defined inside this namespace
*
@@ -94,10 +94,10 @@ struct FASTDEPLOY_DECL RuntimeOption {
/// Use Nvidia GPU to inference
void UseGpu(int gpu_id = 0);

void UseRKNPU2(fastdeploy::rknpu2::CpuName rknpu2_name
= fastdeploy::rknpu2::CpuName::RK3588,
fastdeploy::rknpu2::CoreMask rknpu2_core
= fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_0);
void UseRKNPU2(fastdeploy::rknpu2::CpuName rknpu2_name =
fastdeploy::rknpu2::CpuName::RK3588,
fastdeploy::rknpu2::CoreMask rknpu2_core =
fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_0);

/// Use TimVX to inference
void UseTimVX();
@@ -116,9 +116,7 @@ struct FASTDEPLOY_DECL RuntimeOption {
void UsePaddleBackend();

/// Wrapper function of UsePaddleBackend()
void UsePaddleInferBackend() {
return UsePaddleBackend();
}
void UsePaddleInferBackend() { return UsePaddleBackend(); }

/// Set ONNX Runtime as inference backend, support CPU/GPU
void UseOrtBackend();
@@ -136,9 +134,7 @@ struct FASTDEPLOY_DECL RuntimeOption {
void UseLiteBackend();

/// Wrapper function of UseLiteBackend()
void UsePaddleLiteBackend() {
return UseLiteBackend();
}
void UsePaddleLiteBackend() { return UseLiteBackend(); }

/// Set mkldnn switch while using Paddle Inference as inference backend
void SetPaddleMKLDNN(bool pd_mkldnn = true);
@@ -275,6 +271,11 @@ struct FASTDEPLOY_DECL RuntimeOption {
*/
void DisablePaddleTrtCollectShape();

/**
* @brief Prevent ops running in paddle trt backend
*/
void DisablePaddleTrtOPs(const std::vector<std::string>& ops);

/*
* @brief Set number of streams by the OpenVINO backends
*/
@@ -363,6 +364,8 @@ struct FASTDEPLOY_DECL RuntimeOption {
bool trt_enable_int8 = false;
size_t trt_max_batch_size = 32;
size_t trt_max_workspace_size = 1 << 30;
// ======Only for PaddleTrt Backend=======
std::vector<std::string> trt_disabled_ops_{};

// ======Only for Poros Backend=======
bool is_dynamic = false;
@@ -378,10 +381,10 @@ struct FASTDEPLOY_DECL RuntimeOption {
std::vector<std::string> ov_cpu_operators;

// ======Only for RKNPU2 Backend=======
fastdeploy::rknpu2::CpuName rknpu2_cpu_name_
= fastdeploy::rknpu2::CpuName::RK3588;
fastdeploy::rknpu2::CoreMask rknpu2_core_mask_
= fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_AUTO;
fastdeploy::rknpu2::CpuName rknpu2_cpu_name_ =
fastdeploy::rknpu2::CpuName::RK3588;
fastdeploy::rknpu2::CoreMask rknpu2_core_mask_ =
fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_AUTO;

std::string model_file = ""; // Path of model file
std::string params_file = ""; // Path of parameters file, can be empty
@@ -450,8 +453,7 @@ struct FASTDEPLOY_DECL Runtime {
* \param[in] stream CUDA Stream, defualt param is nullptr
* \return new Runtime* by this clone
*/
Runtime* Clone(void* stream = nullptr,
int device_id = -1);
Runtime* Clone(void* stream = nullptr, int device_id = -1);

RuntimeOption option;
@@ -435,6 +435,16 @@ class RuntimeOption:
"""
return self._option.disable_paddle_trt_collect_shape()

def delete_paddle_backend_pass(self, pass_name):
"""Delete pass by name in paddle backend
"""
return self._option.delete_paddle_backend_pass(pass_name)

def disable_paddle_trt_ops(self, ops):
"""Disable some ops in paddle trt backend
"""
return self._option.disable_paddle_trt_ops(ops)

def use_ipu(self,
device_num=1,
micro_batch_size=1,