Files
FastDeploy/third_party/optimizer/onnxoptimizer/passes/fuse_consecutive_log_softmax.h
Jason 6343b0db47 [Build] Support build with source code of Paddle2ONNX (#1559)
* Add notes for tensors

* Optimize some apis

* move some warnings

* Support build with Paddle2ONNX

* Add protobuf support

* Fix compile on mac

* add clean package script

* Add paddle2onnx code

* remove submodule

* Add onnx code

* remove softlink

* add onnx code

* fix error

* Add cmake file

* fix patchelf

* update paddle2onnx

* Delete .gitmodules

---------

Co-authored-by: PaddleCI <paddle_ci@example.com>
Co-authored-by: pangyoki <pangyoki@126.com>
Co-authored-by: jiangjiajun <jiangjiajun@baidu.lcom>
2023-03-17 10:03:22 +08:00

54 lines
1.7 KiB
C++

/*
* SPDX-License-Identifier: Apache-2.0
*/
// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#pragma once
#include "onnxoptimizer/pass.h"
namespace ONNX_NAMESPACE {
namespace optimization {
/// Fuses a `Softmax -> Log` chain into a single `LogSoftmax` node.
///
/// Matches a `Log` node whose only input is produced by a `Softmax` node that
/// has no other consumer, and replaces the pair with one `LogSoftmax` applied
/// to the Softmax's input with the same axis.
struct FuseConsecutiveLogSoftmax final : public PredicateBasedPass {
  explicit FuseConsecutiveLogSoftmax()
      : PredicateBasedPass(PassType::Fuse, PassEfficiency::Complete,
                           PassOptimizationType::Compute) {}

  std::string getPassName() const override {
    return "fuse_consecutive_log_softmax";
  }

  /// Returns true when `node` is `Log(Softmax(x))` and the Softmax output is
  /// used only by this Log, so the Softmax can be removed after fusion.
  bool patternMatchPredicate(Node* node) override {
    return node->kind() == kLog && node->input()->node()->kind() == kSoftmax &&
           node->input()->uses().size() == 1;
  }

  /// Rewrites the matched `Log(Softmax(x))` into `LogSoftmax(x)`.
  ///
  /// @param log_node        the matched Log node (the pass's iteration cursor).
  /// @param graph           graph owning the nodes; used to create LogSoftmax.
  /// @param destroy_current tells the pass driver how many nodes to destroy at
  ///                        the cursor after this call.
  /// @return true if the graph was rewritten; false leaves it unchanged.
  bool runTransform(Node* log_node, Graph& graph,
                    NodeDestroyType& destroy_current) override {
    Value* log_node_output = log_node->output();
    Node* softmax_node = log_node->inputs()[0]->node();

    Node* log_softmax_node = graph.create(kLogSoftmax, 1);
    // Copy the axis only when Softmax sets it explicitly: reading a missing
    // attribute with i() fails, and Softmax/LogSoftmax share the same default
    // axis within an opset, so leaving it unset preserves semantics.
    if (softmax_node->hasAttribute(kaxis)) {
      log_softmax_node->i_(kaxis, softmax_node->i(kaxis));
    }
    log_softmax_node->addInput(softmax_node->input());
    // The Softmax's input is defined before the Softmax, so inserting here
    // keeps the node list topologically ordered.
    log_softmax_node->insertBefore(softmax_node);

    // Mirror the Log output's type info. Copy sizes only when they are known:
    // setSizes() marks the shape as known, so copying an empty "unknown" shape
    // would wrongly declare the output a scalar.
    if (log_node_output->has_sizes()) {
      log_softmax_node->output()->setSizes(log_node_output->sizes());
    }
    log_softmax_node->output()->setElemType(log_node_output->elemType());

    if (!tryReplacingAllUsesWith(log_node, log_softmax_node)) {
      // Roll back: the new node has no users yet, so remove it instead of
      // leaving a dangling node in the graph.
      log_softmax_node->destroy();
      destroy_current = NodeDestroyType::DestroyZero;
      return false;
    }

    // Detach Log from the Softmax output. The predicate guaranteed Log was its
    // only consumer, so the Softmax is now dead. Destroy it explicitly rather
    // than assuming it is the node right before Log in the node list, which
    // need not hold when unrelated nodes are interleaved between them.
    log_node->removeAllInputs();
    softmax_node->destroy();
    destroy_current = NodeDestroyType::DestroyOne;
    return true;
  }
};
} // namespace optimization
} // namespace ONNX_NAMESPACE