Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Other] Optimize runtime module (#1356)
* Optimize runtime
* Fix error
* [Backend] Add option to print tensorrt conversion log (#1386)

Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com>
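For context, a hedged usage sketch of the new flag from the C++ side (not part of this commit): the diff below shows that RuntimeOption exposes trt_option as a public member, while UseGpu(), UseTrtBackend(), SetModelPath() and Runtime::Init() are assumed from FastDeploy's public runtime API rather than from this change.

#include "fastdeploy/runtime.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.UseGpu(0);        // TrtBackend::Init() below requires Device::GPU
  option.UseTrtBackend();
  option.trt_option.enable_log_info = true;  // new flag: print the TensorRT conversion log
  option.SetModelPath("model.pdmodel", "model.pdiparams");

  fastdeploy::Runtime runtime;
  if (!runtime.Init(option)) {
    return -1;
  }
  return 0;
}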
@@ -30,6 +30,9 @@ struct TrtBackendOption {
   /// `max_workspace_size` for TensorRT
   size_t max_workspace_size = 1 << 30;
 
+  /// Enable log while converting onnx model to tensorrt
+  bool enable_log_info = false;
+
   /*
    * @brief Enable half precision inference; on devices that do not support half precision, it will fall back to float32 mode
    */
@@ -21,6 +21,7 @@ void BindTrtOption(pybind11::module& m) {
   pybind11::class_<TrtBackendOption>(m, "TrtBackendOption")
       .def(pybind11::init())
       .def_readwrite("enable_fp16", &TrtBackendOption::enable_fp16)
+      .def_readwrite("enable_log_info", &TrtBackendOption::enable_log_info)
       .def_readwrite("max_batch_size", &TrtBackendOption::max_batch_size)
       .def_readwrite("max_workspace_size",
                      &TrtBackendOption::max_workspace_size)
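The binding above is the standard pybind11 def_readwrite pattern. As a self-contained sketch with a hypothetical DemoOption module (not FastDeploy's actual build target), this is how a plain C++ struct member becomes a mutable Python attribute such as trt_option.enable_log_info:

#include <pybind11/pybind11.h>

// Stand-in for an option struct; each def_readwrite below exposes one member
// as a read/write attribute on the Python class.
struct DemoOption {
  bool enable_log_info = false;
  int max_batch_size = 1;
};

PYBIND11_MODULE(demo_option, m) {
  pybind11::class_<DemoOption>(m, "DemoOption")
      .def(pybind11::init())
      .def_readwrite("enable_log_info", &DemoOption::enable_log_info)
      .def_readwrite("max_batch_size", &DemoOption::max_batch_size);
}

From Python the attribute is then set directly, e.g. opt = demo_option.DemoOption(); opt.enable_log_info = True.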
@@ -114,6 +114,13 @@ bool TrtBackend::LoadTrtCache(const std::string& trt_engine_file) {
 }
 
 bool TrtBackend::Init(const RuntimeOption& runtime_option) {
+  auto trt_option = runtime_option.trt_option;
+  trt_option.model_file = runtime_option.model_file;
+  trt_option.params_file = runtime_option.params_file;
+  trt_option.model_format = runtime_option.model_format;
+  trt_option.gpu_id = runtime_option.device_id;
+  trt_option.enable_pinned_memory = runtime_option.enable_pinned_memory;
+  trt_option.external_stream_ = runtime_option.external_stream_;
   if (runtime_option.device != Device::GPU) {
     FDERROR << "TrtBackend only supports Device::GPU, but now it's "
             << runtime_option.device << "." << std::endl;
@@ -130,7 +137,7 @@ bool TrtBackend::Init(const RuntimeOption& runtime_option) {
     if (runtime_option.model_from_memory_) {
       return InitFromPaddle(runtime_option.model_file,
                             runtime_option.params_file,
-                            runtime_option.trt_option);
+                            trt_option);
     } else {
       std::string model_buffer;
       std::string params_buffer;
@@ -141,17 +148,17 @@ bool TrtBackend::Init(const RuntimeOption& runtime_option) {
                "Failed to read parameters file %s.",
                runtime_option.params_file.c_str());
       return InitFromPaddle(model_buffer, params_buffer,
-                            runtime_option.trt_option);
+                            trt_option);
     }
   } else {
     if (runtime_option.model_from_memory_) {
-      return InitFromOnnx(runtime_option.model_file, runtime_option.trt_option);
+      return InitFromOnnx(runtime_option.model_file, trt_option);
     } else {
       std::string model_buffer;
       FDASSERT(ReadBinaryFromFile(runtime_option.model_file, &model_buffer),
                "Failed to read model file %s.",
                runtime_option.model_file.c_str());
-      return InitFromOnnx(model_buffer, runtime_option.trt_option);
+      return InitFromOnnx(model_buffer, trt_option);
     }
   }
   return true;
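The three replacements above all serve one purpose: the model, parameter, device and stream fields were just written into the local trt_option, so forwarding runtime_option.trt_option would discard them. A simplified illustration of that pitfall, with hypothetical names (not FastDeploy code):

#include <string>

struct Option {
  std::string model_file;
  int gpu_id = 0;
};

struct Runtime {
  Option trt_option;       // user-provided backend settings
  std::string model_file;  // runtime-level settings that must be merged in
  int device_id = 0;
};

bool InitBackend(const Option& opt) { return !opt.model_file.empty(); }

bool Init(const Runtime& runtime) {
  Option opt = runtime.trt_option;      // local, augmented copy
  opt.model_file = runtime.model_file;
  opt.gpu_id = runtime.device_id;
  // Passing runtime.trt_option here instead of opt would lose the two
  // assignments above; the augmented copy is what the backend must receive.
  return InitBackend(opt);
}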
@@ -525,6 +532,9 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs,
 }
 
 bool TrtBackend::BuildTrtEngine() {
+  if (option_.enable_log_info) {
+    FDTrtLogger::Get()->SetLog(true, true);
+  }
   auto config =
       FDUniquePtr<nvinfer1::IBuilderConfig>(builder_->createBuilderConfig());
   if (!config) {
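For reference, a minimal sketch (assumed TensorRT 8.4+ API, not FastDeploy code) of what a builder-config step like the one above typically does with max_workspace_size from TrtBackendOption: cap the scratch memory TensorRT may allocate while building the engine.

#include <NvInfer.h>

#include <cstddef>

void ConfigureWorkspace(nvinfer1::IBuilderConfig* config,
                        std::size_t max_workspace_size) {
  // TensorRT 8.4+ expresses the workspace cap as a memory-pool limit.
  config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE,
                             max_workspace_size);
}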
@@ -220,20 +220,30 @@ class FDTrtLogger : public nvinfer1::ILogger {
     logger = new FDTrtLogger();
     return logger;
   }
+  void SetLog(bool enable_info = false, bool enable_warning = false) {
+    enable_info_ = enable_info;
+    enable_warning_ = enable_warning;
+  }
+
   void log(nvinfer1::ILogger::Severity severity,
            const char* msg) noexcept override {
     if (severity == nvinfer1::ILogger::Severity::kINFO) {
-      // Disable this log
-      // FDINFO << msg << std::endl;
+      if (enable_info_) {
+        FDINFO << msg << std::endl;
+      }
     } else if (severity == nvinfer1::ILogger::Severity::kWARNING) {
-      // Disable this log
-      // FDWARNING << msg << std::endl;
+      if (enable_warning_) {
+        FDWARNING << msg << std::endl;
+      }
     } else if (severity == nvinfer1::ILogger::Severity::kERROR) {
       FDERROR << msg << std::endl;
     } else if (severity == nvinfer1::ILogger::Severity::kINTERNAL_ERROR) {
       FDASSERT(false, "%s", msg);
     }
   }
+ private:
+  bool enable_info_ = false;
+  bool enable_warning_ = false;
 };
 
 struct ShapeRangeInfo {
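To show how a logger like FDTrtLogger is actually consumed, here is a standalone sketch against plain TensorRT (assumed 8.x, where the builder objects can be held in std::unique_ptr); the severity filtering mirrors what SetLog() toggles above.

#include <NvInfer.h>

#include <iostream>
#include <memory>

// Route TensorRT build-time messages to stderr, filtering like FDTrtLogger:
// only warnings and errors unless verbose output is explicitly enabled.
class ConsoleLogger : public nvinfer1::ILogger {
 public:
  void log(Severity severity, const char* msg) noexcept override {
    if (severity <= Severity::kWARNING) {  // kINTERNAL_ERROR, kERROR, kWARNING
      std::cerr << msg << std::endl;
    }
  }
};

int main() {
  ConsoleLogger logger;
  // The builder takes the logger by reference; every message emitted while
  // parsing the network and building the engine goes through log().
  auto builder = std::unique_ptr<nvinfer1::IBuilder>(
      nvinfer1::createInferBuilder(logger));
  auto config = std::unique_ptr<nvinfer1::IBuilderConfig>(
      builder->createBuilderConfig());
  // ... define the network, set precision/workspace options, build the engine.
  return 0;
}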