Optimize TensorRT backend to support rebuilding the engine (#189)

* optimize tensorrt usage

* format code

* fix input shape error for onnx model

Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com>
Jason committed on 2022-09-06 10:53:05 +08:00 (committed by GitHub)
parent 4bf0d3847a
commit 969531dcc8
6 changed files with 526 additions and 266 deletions


@@ -13,9 +13,9 @@
// limitations under the License.
#include "fastdeploy/backends/tensorrt/trt_backend.h"
#include <cstring>
#include "NvInferSafeRuntime.h"
#include "fastdeploy/utils/utils.h"
#include <cstring>
#ifdef ENABLE_PADDLE_FRONTEND
#include "paddle2onnx/converter.h"
#endif
@@ -24,117 +24,46 @@ namespace fastdeploy {
FDTrtLogger* FDTrtLogger::logger = nullptr;
size_t TrtDataTypeSize(const nvinfer1::DataType& dtype) {
if (dtype == nvinfer1::DataType::kFLOAT) {
return sizeof(float);
} else if (dtype == nvinfer1::DataType::kHALF) {
return sizeof(float) / 2;
} else if (dtype == nvinfer1::DataType::kINT8) {
return sizeof(int8_t);
} else if (dtype == nvinfer1::DataType::kINT32) {
return sizeof(int32_t);
}
// kBOOL
return sizeof(bool);
}
FDDataType GetFDDataType(const nvinfer1::DataType& dtype) {
if (dtype == nvinfer1::DataType::kFLOAT) {
return FDDataType::FP32;
} else if (dtype == nvinfer1::DataType::kHALF) {
return FDDataType::FP16;
} else if (dtype == nvinfer1::DataType::kINT8) {
return FDDataType::INT8;
} else if (dtype == nvinfer1::DataType::kINT32) {
return FDDataType::INT32;
}
// kBOOL
return FDDataType::BOOL;
}
std::vector<int> toVec(const nvinfer1::Dims& dim) {
std::vector<int> out(dim.d, dim.d + dim.nbDims);
return out;
}
bool CheckDynamicShapeConfig(const paddle2onnx::OnnxReader& reader,
const TrtBackendOption& option) {
// paddle2onnx::ModelTensorInfo inputs[reader.NumInputs()];
// std::string input_shapes[reader.NumInputs()];
std::vector<paddle2onnx::ModelTensorInfo> inputs(reader.NumInputs());
std::vector<std::string> input_shapes(reader.NumInputs());
for (int i = 0; i < reader.NumInputs(); ++i) {
reader.GetInputInfo(i, &inputs[i]);
// Change 0 to -1: when an input dim is a symbolic (string) dimension, ONNX records it as zero
for (int j = 0; j < inputs[i].rank; ++j) {
if (inputs[i].shape[j] <= 0) {
inputs[i].shape[j] = -1;
// Check whether the model can build a TensorRT engine now.
// If the model has dynamic input shapes, defined shape range information is
// required. The shape range can be set with SetTrtInputShape(); if it is not
// defined, the engine cannot be built yet. In that case the engine will be
// built once data is fed, and the shape range will be updated accordingly.
bool CanBuildEngine(
const std::map<std::string, ShapeRangeInfo>& shape_range_info) {
for (auto iter = shape_range_info.begin(); iter != shape_range_info.end();
++iter) {
bool is_full_static = true;
for (size_t i = 0; i < iter->second.shape.size(); ++i) {
if (iter->second.shape[i] < 0) {
is_full_static = false;
break;
}
}
input_shapes[i] = "";
for (int j = 0; j < inputs[i].rank; ++j) {
if (j != inputs[i].rank - 1) {
input_shapes[i] += (std::to_string(inputs[i].shape[j]) + ", ");
} else {
input_shapes[i] += std::to_string(inputs[i].shape[j]);
if (is_full_static) {
continue;
}
for (size_t i = 0; i < iter->second.shape.size(); ++i) {
if (iter->second.min[i] < 0 || iter->second.max[i] < 0) {
return false;
}
}
}
bool all_check_passed = true;
for (int i = 0; i < reader.NumInputs(); ++i) {
bool contain_unknown_dim = false;
for (int j = 0; j < inputs[i].rank; ++j) {
if (inputs[i].shape[j] < 0) {
contain_unknown_dim = true;
}
}
std::string name(inputs[i].name, strlen(inputs[i].name));
FDINFO << "The loaded model's input tensor:" << name
<< " has shape [" + input_shapes[i] << "]." << std::endl;
if (contain_unknown_dim) {
auto iter1 = option.min_shape.find(name);
auto iter2 = option.max_shape.find(name);
auto iter3 = option.opt_shape.find(name);
if (iter1 == option.min_shape.end() || iter2 == option.max_shape.end() ||
iter3 == option.opt_shape.end()) {
FDERROR << "The loaded model's input tensor:" << name
<< " has dynamic shape [" + input_shapes[i] +
"], but didn't configure it's shape for tensorrt with "
"SetTrtInputShape correctly."
<< std::endl;
all_check_passed = false;
}
}
}
return all_check_passed;
return true;
}
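// Editor's sketch (not part of this commit): shows how a caller could
// pre-define the shape range so that CanBuildEngine() succeeds and the engine
// is built at load time instead of lazily at the first Infer(). The helper
// name, the input name "image" and the shape values are illustrative
// assumptions; the option maps are filled directly here only to mirror the
// fields read above, while SetTrtInputShape() is the usual user-facing entry
// point mentioned in the comment on CanBuildEngine().
static TrtBackendOption ExampleShapeRangeOption() {
  TrtBackendOption option;
  option.min_shape["image"] = {1, 3, 224, 224};  // smallest expected input
  option.opt_shape["image"] = {1, 3, 320, 320};  // most common input
  option.max_shape["image"] = {4, 3, 640, 640};  // largest expected input
  // With min/max defined for every dynamic input, CanBuildEngine() returns
  // true and the engine can be built before any data is fed.
  return option;
}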
bool TrtBackend::InitFromTrt(const std::string& trt_engine_file,
const TrtBackendOption& option) {
if (initialized_) {
FDERROR << "TrtBackend is already initlized, cannot initialize again."
<< std::endl;
return false;
}
cudaSetDevice(option.gpu_id);
bool TrtBackend::LoadTrtCache(const std::string& trt_engine_file) {
cudaSetDevice(option_.gpu_id);
std::ifstream fin(trt_engine_file, std::ios::binary | std::ios::in);
if (!fin) {
FDERROR << "Failed to open TensorRT Engine file " << trt_engine_file
<< std::endl;
return false;
}
fin.seekg(0, std::ios::end);
std::string engine_buffer;
engine_buffer.resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(engine_buffer.at(0)), engine_buffer.size());
fin.close();
if (!ReadBinaryFromFile(trt_engine_file, &engine_buffer)) {
FDERROR << "Failed to load TensorRT Engine from " << trt_engine_file << "."
<< std::endl;
return false;
}
FDUniquePtr<nvinfer1::IRuntime> runtime{
nvinfer1::createInferRuntime(*FDTrtLogger::Get())};
if (!runtime) {
@@ -152,10 +81,31 @@ bool TrtBackend::InitFromTrt(const std::string& trt_engine_file,
context_ = std::shared_ptr<nvinfer1::IExecutionContext>(
engine_->createExecutionContext());
FDASSERT(cudaStreamCreate(&stream_) == 0,
"[ERROR] Error occurs while calling cudaStreamCreate().");
GetInputOutputInfo();
initialized_ = true;
for (int32_t i = 0; i < engine_->getNbBindings(); ++i) {
if (!engine_->bindingIsInput(i)) {
continue;
}
auto min = ToVec(engine_->getProfileDimensions(
i, 0, nvinfer1::OptProfileSelector::kMIN));
auto max = ToVec(engine_->getProfileDimensions(
i, 0, nvinfer1::OptProfileSelector::kMAX));
auto name = std::string(engine_->getBindingName(i));
auto iter = shape_range_info_.find(name);
if (iter == shape_range_info_.end()) {
FDERROR << "There's no input named '" << name << "' in loaded model."
<< std::endl;
return false;
}
iter->second.Update(min);
iter->second.Update(max);
}
FDINFO << "Build TensorRT Engine from cache file: " << trt_engine_file
<< " with shape range information as below," << std::endl;
for (const auto& item : shape_range_info_) {
FDINFO << item.second << std::endl;
}
return true;
}
@@ -167,10 +117,11 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
<< std::endl;
return false;
}
option_ = option;
#ifdef ENABLE_PADDLE_FRONTEND
std::vector<paddle2onnx::CustomOp> custom_ops;
for (auto& item : option.custom_op_info_) {
for (auto& item : option_.custom_op_info_) {
paddle2onnx::CustomOp op;
std::strcpy(op.op_name, item.first.c_str());
std::strcpy(op.export_op_name, item.second.c_str());
@@ -187,7 +138,7 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
return false;
}
if (option.remove_multiclass_nms_) {
if (option_.remove_multiclass_nms_) {
char* new_model = nullptr;
int new_model_size = 0;
if (!paddle2onnx::RemoveMultiClassNMS(model_content_ptr, model_content_size,
@@ -222,7 +173,8 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
<< std::endl;
return false;
}
cudaSetDevice(option.gpu_id);
option_ = option;
cudaSetDevice(option_.gpu_id);
std::string onnx_content = "";
if (!from_memory_buffer) {
@@ -246,43 +198,94 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
outputs_order_.clear();
auto onnx_reader =
paddle2onnx::OnnxReader(onnx_content.c_str(), onnx_content.size());
for (int i = 0; i < onnx_reader.NumOutputs(); ++i) {
std::string name(
onnx_reader.output_names[i],
onnx_reader.output_names[i] + strlen(onnx_reader.output_names[i]));
for (int i = 0; i < onnx_reader.num_outputs; ++i) {
std::string name(onnx_reader.outputs[i].name);
outputs_order_[name] = i;
}
if (!CheckDynamicShapeConfig(onnx_reader, option)) {
FDERROR << "TrtBackend::CheckDynamicShapeConfig failed." << std::endl;
return false;
}
if (option.serialize_file != "") {
std::ifstream fin(option.serialize_file, std::ios::binary | std::ios::in);
if (fin) {
FDINFO << "Detect serialized TensorRT Engine file in "
<< option.serialize_file << ", will load it directly."
<< std::endl;
fin.close();
return InitFromTrt(option.serialize_file, option);
shape_range_info_.clear();
inputs_desc_.clear();
outputs_desc_.clear();
inputs_desc_.resize(onnx_reader.num_inputs);
outputs_desc_.resize(onnx_reader.num_outputs);
for (int i = 0; i < onnx_reader.num_inputs; ++i) {
std::string name(onnx_reader.inputs[i].name);
std::vector<int64_t> shape(onnx_reader.inputs[i].shape,
onnx_reader.inputs[i].shape +
onnx_reader.inputs[i].rank);
inputs_desc_[i].name = name;
inputs_desc_[i].shape.assign(shape.begin(), shape.end());
inputs_desc_[i].dtype = ReaderDtypeToTrtDtype(onnx_reader.inputs[i].dtype);
auto info = ShapeRangeInfo(shape);
info.name = name;
auto iter_min = option.min_shape.find(name);
auto iter_max = option.max_shape.find(name);
auto iter_opt = option.opt_shape.find(name);
if (iter_min != option.min_shape.end()) {
info.min.assign(iter_min->second.begin(), iter_min->second.end());
info.max.assign(iter_max->second.begin(), iter_max->second.end());
info.opt.assign(iter_opt->second.begin(), iter_opt->second.end());
}
shape_range_info_.insert(std::make_pair(name, info));
}
if (!CreateTrtEngine(onnx_content, option)) {
return false;
for (int i = 0; i < onnx_reader.num_outputs; ++i) {
std::string name(onnx_reader.outputs[i].name);
std::vector<int64_t> shape(onnx_reader.outputs[i].shape,
onnx_reader.outputs[i].shape +
onnx_reader.outputs[i].rank);
outputs_desc_[i].name = name;
outputs_desc_[i].shape.assign(shape.begin(), shape.end());
outputs_desc_[i].dtype =
ReaderDtypeToTrtDtype(onnx_reader.outputs[i].dtype);
}
context_ = std::shared_ptr<nvinfer1::IExecutionContext>(
engine_->createExecutionContext());
FDASSERT(cudaStreamCreate(&stream_) == 0,
"[ERROR] Error occurs while calling cudaStreamCreate().");
GetInputOutputInfo();
if (!CreateTrtEngineFromOnnx(onnx_content)) {
FDERROR << "Failed to create tensorrt engine." << std::endl;
return false;
}
initialized_ = true;
return true;
}
int TrtBackend::ShapeRangeInfoUpdated(const std::vector<FDTensor>& inputs) {
bool need_update_engine = false;
for (size_t i = 0; i < inputs.size(); ++i) {
auto iter = shape_range_info_.find(inputs[i].name);
if (iter == shape_range_info_.end()) {
FDERROR << "There's no input named '" << inputs[i].name
<< "' in loaded model." << std::endl;
continue;
}
if (iter->second.Update(inputs[i].shape) == 1) {
need_update_engine = true;
}
}
return need_update_engine;
}
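// Editor's sketch (not part of this commit): ShapeRangeInfo is defined in
// trt_backend.h, which is not shown in this hunk. The struct below is an
// assumed approximation of the Update() contract as it is used in
// ShapeRangeInfoUpdated() and LoadTrtCache(): it widens the recorded
// [min, max] range and returns 1 when the incoming shape falls outside the
// previously recorded range (so the engine has to be rebuilt), 0 otherwise.
// The real class presumably also accepts std::vector<int>, as used in
// LoadTrtCache().
struct ShapeRangeInfoSketch {
  std::string name;
  std::vector<int64_t> shape, min, max, opt;
  int Update(const std::vector<int64_t>& new_shape) {
    int range_changed = 0;
    for (size_t i = 0; i < new_shape.size(); ++i) {
      if (shape[i] >= 0) continue;  // static dimension, nothing to track
      if (min[i] < 0 || new_shape[i] < min[i]) {
        min[i] = new_shape[i];
        range_changed = 1;
      }
      if (max[i] < 0 || new_shape[i] > max[i]) {
        max[i] = new_shape[i];
        range_changed = 1;
      }
    }
    return range_changed;
  }
};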
bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs) {
if (inputs.size() != NumInputs()) {
FDERROR << "Require " << NumInputs() << "inputs, but get " << inputs.size()
<< "." << std::endl;
return false;
}
if (ShapeRangeInfoUpdated(inputs)) {
// The new shape is outside the predefined/collected min/max shape range,
// so rebuild the TensorRT engine
FDWARNING
<< "TensorRT engine will be rebuilt once the shape range information "
"changes; this may take a long time. You can set a proper shape range "
"before loading the model to avoid the rebuilding process. Refer to "
"https://github.com/PaddlePaddle/FastDeploy/docs/backends/"
"tensorrt.md for more details."
BuildTrtEngine();
}
AllocateBufferInDynamicShape(inputs, outputs);
std::vector<void*> input_binds(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
@@ -316,12 +319,14 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
}
void TrtBackend::GetInputOutputInfo() {
std::vector<TrtValueInfo>().swap(inputs_desc_);
std::vector<TrtValueInfo>().swap(outputs_desc_);
inputs_desc_.clear();
outputs_desc_.clear();
auto num_binds = engine_->getNbBindings();
for (auto i = 0; i < num_binds; ++i) {
std::string name = std::string(engine_->getBindingName(i));
auto shape = toVec(engine_->getBindingDimensions(i));
auto shape = ToVec(engine_->getBindingDimensions(i));
auto dtype = engine_->getBindingDataType(i);
if (engine_->bindingIsInput(i)) {
inputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
@@ -355,8 +360,10 @@ void TrtBackend::AllocateBufferInDynamicShape(
// find the original index of output
auto iter = outputs_order_.find(outputs_desc_[i].name);
FDASSERT(iter != outputs_order_.end(),
"Cannot find output: %s of tensorrt network from the original model.", outputs_desc_[i].name.c_str());
FDASSERT(
iter != outputs_order_.end(),
"Cannot find output: %s of tensorrt network from the original model.",
outputs_desc_[i].name.c_str());
auto ori_idx = iter->second;
(*outputs)[ori_idx].dtype = GetFDDataType(outputs_desc_[i].dtype);
(*outputs)[ori_idx].shape.assign(output_dims.d,
@@ -372,32 +379,15 @@ void TrtBackend::AllocateBufferInDynamicShape(
}
}
bool TrtBackend::CreateTrtEngine(const std::string& onnx_model,
const TrtBackendOption& option) {
const auto explicitBatch =
1U << static_cast<uint32_t>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
builder_ = FDUniquePtr<nvinfer1::IBuilder>(
nvinfer1::createInferBuilder(*FDTrtLogger::Get()));
if (!builder_) {
FDERROR << "Failed to call createInferBuilder()." << std::endl;
return false;
}
network_ = FDUniquePtr<nvinfer1::INetworkDefinition>(
builder_->createNetworkV2(explicitBatch));
if (!network_) {
FDERROR << "Failed to call createNetworkV2()." << std::endl;
return false;
}
auto config = FDUniquePtr<nvinfer1::IBuilderConfig>(
builder_->createBuilderConfig());
bool TrtBackend::BuildTrtEngine() {
auto config =
FDUniquePtr<nvinfer1::IBuilderConfig>(builder_->createBuilderConfig());
if (!config) {
FDERROR << "Failed to call createBuilderConfig()." << std::endl;
return false;
}
if (option.enable_fp16) {
if (option_.enable_fp16) {
if (!builder_->platformHasFastFp16()) {
FDWARNING << "Detected FP16 is not supported in the current GPU, "
"will use FP32 instead."
@@ -407,56 +397,52 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model,
}
}
parser_ = FDUniquePtr<nvonnxparser::IParser>(
nvonnxparser::createParser(*network_, *FDTrtLogger::Get()));
if (!parser_) {
FDERROR << "Failed to call createParser()." << std::endl;
return false;
}
if (!parser_->parse(onnx_model.data(), onnx_model.size())) {
FDERROR << "Failed to parse ONNX model by TensorRT." << std::endl;
return false;
}
FDINFO << "Start to building TensorRT Engine..." << std::endl;
bool fp16 = builder_->platformHasFastFp16();
builder_->setMaxBatchSize(option.max_batch_size);
config->setMaxWorkspaceSize(option.max_workspace_size);
if (option.max_shape.size() > 0) {
auto profile = builder_->createOptimizationProfile();
FDASSERT(option.max_shape.size() == option.min_shape.size() &&
option.min_shape.size() == option.opt_shape.size(),
"[TrtBackend] Size of max_shape/opt_shape/min_shape in "
"TrtBackendOption should keep same.");
for (const auto& item : option.min_shape) {
// set min shape
FDASSERT(profile->setDimensions(item.first.c_str(),
nvinfer1::OptProfileSelector::kMIN,
ToDims(item.second)),
"[TrtBackend] Failed to set min_shape for input: %s in TrtBackend.", item.first.c_str());
// set optimization shape
auto iter = option.opt_shape.find(item.first);
FDASSERT(iter != option.opt_shape.end(),
"[TrtBackend] Cannot find input name: %s in TrtBackendOption::opt_shape.", item.first.c_str());
FDASSERT(profile->setDimensions(item.first.c_str(),
nvinfer1::OptProfileSelector::kOPT,
ToDims(iter->second)),
"[TrtBackend] Failed to set opt_shape for input: %s in TrtBackend.", item.first.c_str());
// set max shape
iter = option.max_shape.find(item.first);
FDASSERT(iter != option.max_shape.end(),
"[TrtBackend] Cannot find input name: %s in TrtBackendOption::max_shape.", item.first.c_str());
FDASSERT(profile->setDimensions(item.first.c_str(),
nvinfer1::OptProfileSelector::kMAX,
ToDims(iter->second)),
"[TrtBackend] Failed to set max_shape for input: %s in TrtBackend.", item.first.c_str());
}
config->addOptimizationProfile(profile);
if (context_) {
context_.reset();
engine_.reset();
}
builder_->setMaxBatchSize(option_.max_batch_size);
config->setMaxWorkspaceSize(option_.max_workspace_size);
auto profile = builder_->createOptimizationProfile();
for (const auto& item : shape_range_info_) {
FDASSERT(
profile->setDimensions(item.first.c_str(),
nvinfer1::OptProfileSelector::kMIN,
ToDims(item.second.min)),
"[TrtBackend] Failed to set min_shape for input: %s in TrtBackend.",
item.first.c_str());
FDASSERT(
profile->setDimensions(item.first.c_str(),
nvinfer1::OptProfileSelector::kMAX,
ToDims(item.second.max)),
"[TrtBackend] Failed to set max_shape for input: %s in TrtBackend.",
item.first.c_str());
if (item.second.opt.size() == 0) {
FDASSERT(
profile->setDimensions(item.first.c_str(),
nvinfer1::OptProfileSelector::kOPT,
ToDims(item.second.max)),
"[TrtBackend] Failed to set opt_shape for input: %s in TrtBackend.",
item.first.c_str());
} else {
FDASSERT(
item.second.opt.size() == item.second.shape.size(),
"Require the dimension of opt in shape range information equal to "
"dimension of input: %s in this model, but now it's %zu != %zu.",
item.first.c_str(), item.second.opt.size(), item.second.shape.size());
FDASSERT(
profile->setDimensions(item.first.c_str(),
nvinfer1::OptProfileSelector::kOPT,
ToDims(item.second.opt)),
"[TrtBackend] Failed to set opt_shape for input: %s in TrtBackend.",
item.first.c_str());
}
}
config->addOptimizationProfile(profile);
FDUniquePtr<nvinfer1::IHostMemory> plan{
builder_->buildSerializedNetwork(*network_, *config)};
if (!plan) {
@@ -479,20 +465,24 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model,
return false;
}
context_ = std::shared_ptr<nvinfer1::IExecutionContext>(
engine_->createExecutionContext());
GetInputOutputInfo();
FDINFO << "TensorRT Engine is built succussfully." << std::endl;
if (option.serialize_file != "") {
FDINFO << "Serialize TensorRTEngine to local file " << option.serialize_file
<< "." << std::endl;
std::ofstream engine_file(option.serialize_file.c_str());
if (option_.serialize_file != "") {
FDINFO << "Serialize TensorRTEngine to local file "
<< option_.serialize_file << "." << std::endl;
std::ofstream engine_file(option_.serialize_file.c_str());
if (!engine_file) {
FDERROR << "Failed to open " << option.serialize_file << " to write."
FDERROR << "Failed to open " << option_.serialize_file << " to write."
<< std::endl;
return false;
}
engine_file.write(static_cast<char*>(plan->data()), plan->size());
engine_file.close();
FDINFO << "TensorRTEngine is serialized to local file "
<< option.serialize_file
<< option_.serialize_file
<< ", we can load this model from the seralized engine "
"directly next time."
<< std::endl;
@@ -500,8 +490,81 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model,
return true;
}
bool TrtBackend::CreateTrtEngineFromOnnx(const std::string& onnx_model_buffer) {
const auto explicitBatch =
1U << static_cast<uint32_t>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
builder_ = FDUniquePtr<nvinfer1::IBuilder>(
nvinfer1::createInferBuilder(*FDTrtLogger::Get()));
if (!builder_) {
FDERROR << "Failed to call createInferBuilder()." << std::endl;
return false;
}
network_ = FDUniquePtr<nvinfer1::INetworkDefinition>(
builder_->createNetworkV2(explicitBatch));
if (!network_) {
FDERROR << "Failed to call createNetworkV2()." << std::endl;
return false;
}
parser_ = FDUniquePtr<nvonnxparser::IParser>(
nvonnxparser::createParser(*network_, *FDTrtLogger::Get()));
if (!parser_) {
FDERROR << "Failed to call createParser()." << std::endl;
return false;
}
if (!parser_->parse(onnx_model_buffer.data(), onnx_model_buffer.size())) {
FDERROR << "Failed to parse ONNX model by TensorRT." << std::endl;
return false;
}
if (option_.serialize_file != "") {
std::ifstream fin(option_.serialize_file, std::ios::binary | std::ios::in);
if (fin) {
FDINFO << "Detect serialized TensorRT Engine file in "
<< option_.serialize_file << ", will load it directly."
<< std::endl;
fin.close();
// clear memory buffer of the temporary member
std::string().swap(onnx_model_buffer_);
return LoadTrtCache(option_.serialize_file);
}
}
if (!CanBuildEngine(shape_range_info_)) {
onnx_model_buffer_ = onnx_model_buffer;
FDWARNING << "Cannot build engine right now, because there's dynamic input "
"shape exists, list as below,"
<< std::endl;
for (int i = 0; i < NumInputs(); ++i) {
FDWARNING << "Input " << i << ": " << GetInputInfo(i) << std::endl;
}
FDWARNING
<< "FastDeploy will build the engine while inference with input data, "
"and will also collect the input shape range information. You "
"should be noticed that FastDeploy will rebuild the engine while "
"new input shape is out of the collected shape range, this may "
"bring some time consuming problem, refer "
"https://github.com/PaddlePaddle/FastDeploy/docs/backends/"
"tensorrt.md for more details."
<< std::endl;
initialized_ = true;
return true;
}
if (!BuildTrtEngine()) {
FDERROR << "Failed to build tensorrt engine." << std::endl;
return false;
}
// clear memory buffer of the temporary member
std::string().swap(onnx_model_buffer_);
return true;
}
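// Editor's sketch (not part of this commit): how the serialize_file option
// interacts with CreateTrtEngineFromOnnx() above. The helper name, file path
// and model path are illustrative assumptions.
static bool ExampleInitWithEngineCache(TrtBackend* backend) {
  TrtBackendOption option;
  option.serialize_file = "./model.trt";  // engine cache path (assumed)
  // First run: the ONNX model is parsed, the engine is built (eagerly if the
  // shape range is fully defined, otherwise lazily at Infer()) and serialized
  // to ./model.trt. Later runs: the cache file is detected and LoadTrtCache()
  // is called instead of rebuilding the engine from the ONNX model.
  return backend->InitFromOnnx("model.onnx", option,
                               /*from_memory_buffer=*/false);
}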
TensorInfo TrtBackend::GetInputInfo(int index) {
FDASSERT(index < NumInputs(), "The index: %d should less than the number of inputs: %d.", index, NumInputs());
FDASSERT(index < NumInputs(),
"The index: %d should less than the number of inputs: %d.", index,
NumInputs());
TensorInfo info;
info.name = inputs_desc_[index].name;
info.shape.assign(inputs_desc_[index].shape.begin(),
@@ -512,7 +575,8 @@ TensorInfo TrtBackend::GetInputInfo(int index) {
TensorInfo TrtBackend::GetOutputInfo(int index) {
FDASSERT(index < NumOutputs(),
"The index: %d should less than the number of outputs: %d.", index, NumOutputs());
"The index: %d should less than the number of outputs: %d.", index,
NumOutputs());
TensorInfo info;
info.name = outputs_desc_[index].name;
info.shape.assign(outputs_desc_[index].shape.begin(),
@@ -520,4 +584,4 @@ TensorInfo TrtBackend::GetOutputInfo(int index) {
info.dtype = GetFDDataType(outputs_desc_[index].dtype);
return info;
}
} // namespace fastdeploy