mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-16 13:41:30 +08:00
add 'GetOutputInfos' and 'GetInputInfos' interface (#232)
add GetOutputInfos GetInputInfos
This commit is contained in:
@@ -589,6 +589,14 @@ TensorInfo TrtBackend::GetInputInfo(int index) {
|
||||
return info;
|
||||
}
|
||||
|
||||
std::vector<TensorInfo> TrtBackend::GetInputInfos() {
|
||||
std::vector<TensorInfo> infos;
|
||||
for (auto i = 0; i < inputs_desc_.size(); i++) {
|
||||
infos.emplace_back(GetInputInfo(i));
|
||||
}
|
||||
return infos;
|
||||
}
|
||||
|
||||
TensorInfo TrtBackend::GetOutputInfo(int index) {
|
||||
FDASSERT(index < NumOutputs(),
|
||||
"The index: %d should less than the number of outputs: %d.", index,
|
||||
@@ -600,4 +608,13 @@ TensorInfo TrtBackend::GetOutputInfo(int index) {
|
||||
info.dtype = GetFDDataType(outputs_desc_[index].dtype);
|
||||
return info;
|
||||
}
|
||||
|
||||
std::vector<TensorInfo> TrtBackend::GetOutputInfos() {
|
||||
std::vector<TensorInfo> infos;
|
||||
for (auto i = 0; i < outputs_desc_.size(); i++) {
|
||||
infos.emplace_back(GetOutputInfo(i));
|
||||
}
|
||||
return infos;
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
|
@@ -72,6 +72,8 @@ class TrtBackend : public BaseBackend {
|
||||
int NumOutputs() const { return outputs_desc_.size(); }
|
||||
TensorInfo GetInputInfo(int index);
|
||||
TensorInfo GetOutputInfo(int index);
|
||||
std::vector<TensorInfo> GetInputInfos() override;
|
||||
std::vector<TensorInfo> GetOutputInfos() override;
|
||||
|
||||
~TrtBackend() {
|
||||
if (parser_) {
|
||||
|
Reference in New Issue
Block a user