mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 09:07:10 +08:00
add 'GetOutputInfos' and 'GetInputInfos' interface (#232)
add GetOutputInfos GetInputInfos
This commit is contained in:
@@ -589,6 +589,14 @@ TensorInfo TrtBackend::GetInputInfo(int index) {
|
||||
return info;
|
||||
}
|
||||
|
||||
std::vector<TensorInfo> TrtBackend::GetInputInfos() {
|
||||
std::vector<TensorInfo> infos;
|
||||
for (auto i = 0; i < inputs_desc_.size(); i++) {
|
||||
infos.emplace_back(GetInputInfo(i));
|
||||
}
|
||||
return infos;
|
||||
}
|
||||
|
||||
TensorInfo TrtBackend::GetOutputInfo(int index) {
|
||||
FDASSERT(index < NumOutputs(),
|
||||
"The index: %d should less than the number of outputs: %d.", index,
|
||||
@@ -600,4 +608,13 @@ TensorInfo TrtBackend::GetOutputInfo(int index) {
|
||||
info.dtype = GetFDDataType(outputs_desc_[index].dtype);
|
||||
return info;
|
||||
}
|
||||
|
||||
std::vector<TensorInfo> TrtBackend::GetOutputInfos() {
|
||||
std::vector<TensorInfo> infos;
|
||||
for (auto i = 0; i < outputs_desc_.size(); i++) {
|
||||
infos.emplace_back(GetOutputInfo(i));
|
||||
}
|
||||
return infos;
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
|
Reference in New Issue
Block a user