[Other] Optimize runtime module (#1356)

* Optimize runtime

* fix error

* [Backend] Add option to print tensorrt conversion log (#1386)

Add option to print tensorrt conversion log

Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com>

---------

Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com>
This commit is contained in:
Jason
2023-02-21 17:01:32 +08:00
committed by GitHub
parent 42817ddc18
commit 18e33bae5c
7 changed files with 70 additions and 67 deletions

View File

@@ -30,6 +30,9 @@ struct TrtBackendOption {
/// `max_workspace_size` for TensorRT
size_t max_workspace_size = 1 << 30;
/// Enable log while converting onnx model to tensorrt
bool enable_log_info = false;
/*
* @brief Enable half precision inference; on devices that do not support half precision, it will fall back to float32 mode
*/

View File

@@ -21,6 +21,7 @@ void BindTrtOption(pybind11::module& m) {
pybind11::class_<TrtBackendOption>(m, "TrtBackendOption")
.def(pybind11::init())
.def_readwrite("enable_fp16", &TrtBackendOption::enable_fp16)
.def_readwrite("enable_log_info", &TrtBackendOption::enable_log_info)
.def_readwrite("max_batch_size", &TrtBackendOption::max_batch_size)
.def_readwrite("max_workspace_size",
&TrtBackendOption::max_workspace_size)

View File

@@ -114,6 +114,13 @@ bool TrtBackend::LoadTrtCache(const std::string& trt_engine_file) {
}
bool TrtBackend::Init(const RuntimeOption& runtime_option) {
auto trt_option = runtime_option.trt_option;
trt_option.model_file = runtime_option.model_file;
trt_option.params_file = runtime_option.params_file;
trt_option.model_format = runtime_option.model_format;
trt_option.gpu_id = runtime_option.device_id;
trt_option.enable_pinned_memory = runtime_option.enable_pinned_memory;
trt_option.external_stream_ = runtime_option.external_stream_;
if (runtime_option.device != Device::GPU) {
FDERROR << "TrtBackend only supports Device::GPU, but now it's "
<< runtime_option.device << "." << std::endl;
@@ -130,7 +137,7 @@ bool TrtBackend::Init(const RuntimeOption& runtime_option) {
if (runtime_option.model_from_memory_) {
return InitFromPaddle(runtime_option.model_file,
runtime_option.params_file,
runtime_option.trt_option);
trt_option);
} else {
std::string model_buffer;
std::string params_buffer;
@@ -141,17 +148,17 @@ bool TrtBackend::Init(const RuntimeOption& runtime_option) {
"Failed to read parameters file %s.",
runtime_option.params_file.c_str());
return InitFromPaddle(model_buffer, params_buffer,
runtime_option.trt_option);
trt_option);
}
} else {
if (runtime_option.model_from_memory_) {
return InitFromOnnx(runtime_option.model_file, runtime_option.trt_option);
return InitFromOnnx(runtime_option.model_file, trt_option);
} else {
std::string model_buffer;
FDASSERT(ReadBinaryFromFile(runtime_option.model_file, &model_buffer),
"Failed to read model file %s.",
runtime_option.model_file.c_str());
return InitFromOnnx(model_buffer, runtime_option.trt_option);
return InitFromOnnx(model_buffer, trt_option);
}
}
return true;
@@ -525,6 +532,9 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs,
}
bool TrtBackend::BuildTrtEngine() {
if (option_.enable_log_info) {
FDTrtLogger::Get()->SetLog(true, true);
}
auto config =
FDUniquePtr<nvinfer1::IBuilderConfig>(builder_->createBuilderConfig());
if (!config) {

View File

@@ -220,20 +220,30 @@ class FDTrtLogger : public nvinfer1::ILogger {
logger = new FDTrtLogger();
return logger;
}
/// @brief Select which of the lower-severity messages log() forwards.
/// @param enable_info    when true, kINFO messages are written via FDINFO
/// @param enable_warning when true, kWARNING messages are written via FDWARNING
/// Both flags default to false, so calling SetLog() with no arguments
/// turns both kinds of output off.
void SetLog(bool enable_info = false, bool enable_warning = false) {
  enable_warning_ = enable_warning;
  enable_info_ = enable_info;
}
/// @brief Logging callback overriding nvinfer1::ILogger::log.
/// @param severity severity level attached to the message
/// @param msg      message text to report
///
/// kINFO and kWARNING messages are forwarded only when the matching flag
/// (enable_info_ / enable_warning_) is set; kERROR is always forwarded;
/// kINTERNAL_ERROR is passed to FDASSERT with a false condition. Any other
/// severity is ignored.
void log(nvinfer1::ILogger::Severity severity,
         const char* msg) noexcept override {
  using Severity = nvinfer1::ILogger::Severity;
  switch (severity) {
    case Severity::kINFO:
      // Informational output is opt-in.
      if (enable_info_) {
        FDINFO << msg << std::endl;
      }
      break;
    case Severity::kWARNING:
      // Warning output is opt-in as well.
      if (enable_warning_) {
        FDWARNING << msg << std::endl;
      }
      break;
    case Severity::kERROR:
      FDERROR << msg << std::endl;
      break;
    case Severity::kINTERNAL_ERROR:
      FDASSERT(false, "%s", msg);
      break;
    default:
      // Remaining severities are intentionally dropped.
      break;
  }
}
private:
bool enable_info_ = false;
bool enable_warning_ = false;
};
struct ShapeRangeInfo {