Files
FastDeploy/third_party/optimizer/onnxoptimizer/passes/fuse_consecutive_transposes.h
Jason 6343b0db47 [Build] Support build with source code of Paddle2ONNX (#1559)
* Add notes for tensors

* Optimize some apis

* move some warnings

* Support build with Paddle2ONNX

* Add protobuf support

* Fix compile on mac

* add clean package script

* Add paddle2onnx code

* remove submodule

* Add onnx code

* remove softlink

* add onnx code

* fix error

* Add cmake file

* fix patchelf

* update paddle2onnx

* Delete .gitmodules

---------

Co-authored-by: PaddleCI <paddle_ci@example.com>
Co-authored-by: pangyoki <pangyoki@126.com>
Co-authored-by: jiangjiajun <jiangjiajun@baidu.lcom>
2023-03-17 10:03:22 +08:00

80 lines
2.6 KiB
C++

/*
* SPDX-License-Identifier: Apache-2.0
*/
// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#pragma once
#include "onnxoptimizer/pass.h"
namespace ONNX_NAMESPACE {
namespace optimization {
struct FuseConsecutiveTransposes final : public PredicateBasedPass {
explicit FuseConsecutiveTransposes()
: PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete,
PassOptimizationType::Compute) {}
std::string getPassName() const override {
return "fuse_consecutive_transposes";
}
// returns a vector `ret` such that transposing by `ret` is equivalent
// to transposing by `t1` and then by `t2`
std::vector<int64_t> compose_transposes(const std::vector<int64_t>& t1,
const std::vector<int64_t>& t2) {
ONNX_ASSERT(t1.size() == t2.size());
std::vector<int64_t> ret;
ret.reserve(t1.size());
for (size_t i = 0; i < t1.size(); i++) {
ONNX_ASSERT(t2[i] < static_cast<int64_t>(t1.size()));
ONNX_ASSERT(t1[static_cast<size_t>(t2[i])] <
static_cast<int64_t>(t1.size()));
ret.push_back(t1[static_cast<size_t>(t2[i])]);
}
return ret;
}
bool patternMatchPredicate(Node* node) override {
return node->kind() == kTranspose &&
node->input()->node()->kind() == kTranspose;
}
bool runTransform(Node* n, Graph&,
NodeDestroyType& destroy_current) override {
auto origInput = n->input();
if (!n->hasAttribute(kperm) && !origInput->node()->hasAttribute(kperm)) {
// One special case (two consecutive transposes with no perm,
// since we do not have the shape information here, we have
// to eliminate two transpose together.
if (n->output()->has_sizes()) {
origInput->node()->input()->setSizes(n->output()->sizes());
}
const bool replacing_success =
tryReplacingAllUsesWith(n, origInput->node()->input()->node());
if (!replacing_success) {
return false;
}
destroy_current = NodeDestroyType::DestroyTwo;
return true;
}
if (!n->hasAttribute(kperm) || !origInput->node()->hasAttribute(kperm)) {
destroy_current = NodeDestroyType::DestroyZero;
return false;
}
n->is_(kperm,
compose_transposes(origInput->node()->is(kperm), n->is(kperm)));
n->replaceInput(0, origInput->node()->input());
if (origInput->uses().size() == 0) {
origInput->node()->destroy();
}
destroy_current = NodeDestroyType::DestroyZero;
return false;
}
};
} // namespace optimization
} // namespace ONNX_NAMESPACE