Files
FastDeploy/third_party/optimizer/onnxoptimizer/pass.h
Jason 6343b0db47 [Build] Support build with source code of Paddle2ONNX (#1559)
* Add notes for tensors

* Optimize some apis

* move some warnings

* Support build with Paddle2ONNX

* Add protobuf support

* Fix compile on mac

* add clean package script

* Add paddle2onnx code

* remove submodule

* Add onnx code

* remove softlink

* add onnx code

* fix error

* Add cmake file

* fix patchelf

* update paddle2onnx

* Delete .gitmodules

---------

Co-authored-by: PaddleCI <paddle_ci@example.com>
Co-authored-by: pangyoki <pangyoki@126.com>
Co-authored-by: jiangjiajun <jiangjiajun@baidu.lcom>
2023-03-17 10:03:22 +08:00

260 lines
8.8 KiB
C++

/*
* SPDX-License-Identifier: Apache-2.0
*/
// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#pragma once
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include "onnx/common/ir.h"
#include "onnx/onnx_pb.h"
namespace ONNX_NAMESPACE {
namespace optimization {
// Root of the result hierarchy a pass hands back from runPass(). It holds no
// data of its own; concrete analyses (e.g. CountBasedPassAnalysis) derive
// from it. The virtual destructor lets a derived analysis be released safely
// through a pointer to this base.
struct PostPassAnalysis {
  virtual ~PostPassAnalysis() = default;
};
// Broad category describing what kind of optimization a pass performs.
enum PassType {
  // Fuses several operations together.
  Fuse = 0,
  // Removes operations that are useless.
  Nop = 1,
  // Performs some form of separation.
  Separate = 2,
  // Does not modify the graph; also known as an analysis pass.
  Immutable = 3,
  // Any pass that fits none of the categories above.
  Other = 4
};
// Identifies which kind of PostPassAnalysis a pass returns from runPass().
enum PassAnalysisType {
  // No meaningful analysis; most likely a plain PostPassAnalysis.
  Empty = 0,
  // A count-based analysis, most likely a CountBasedPassAnalysis.
  CountBased = 1
};
// States whether a single application of a pass is already a fixed point.
enum PassEfficiency {
  // Running the pass twice in a row is NOT guaranteed to give the same result
  // as running it once, so repeated application may still change the graph.
  Partial = 0,
  // Running the pass twice in a row is guaranteed to be equivalent to running
  // it once.
  Complete = 1
};
// The resource or property an optimization pass is trying to improve.
enum PassOptimizationType {
  // Optimizes nothing; the natural choice for an immutable (analysis) pass.
  None = 0,
  // Targets compute cost.
  Compute = 1,
  // Targets memory usage.
  Memory = 2,
  // Targets both compute cost and memory usage.
  ComputeMemory = 3,
  // Targets numerical stability (e.g. the log-sum-exp trick).
  Stability = 4
};
// Tells the caller of PredicateBasedPass::runTransform how to treat the
// iterator's current node once the transform has finished.
enum NodeDestroyType {
  // Leave the current node alone.
  DestroyZero = 0,
  // Same as calling it.destroyCurrent() one time.
  DestroyOne = 1,
  // Same as calling it.destroyCurrent() two times.
  DestroyTwo = 2
};
// Abstract base class shared by every ONNX optimization pass.
//
// A pass:
//  * carries a PassType, PassEfficiency and PassOptimizationType, fixed at
//    construction and exposed through read-only getters;
//  * reports a unique name through getPassName(), which pass managers and
//    the registry use to identify it;
//  * may override the initializePass()/finalizePass() hooks;
//  * performs its work in place on the graph in runPass().
class Pass {
  PassType pass_type;
  PassEfficiency pass_efficiency;
  PassOptimizationType pass_optimization_type;

 public:
  Pass(PassType pass_type, PassEfficiency pass_efficiency,
       PassOptimizationType pass_optimization_type);
  virtual ~Pass();

  PassType getPassType() const { return pass_type; }
  PassEfficiency getPassEfficiency() const { return pass_efficiency; }
  PassOptimizationType getPassOptimizationType() const {
    return pass_optimization_type;
  }

  // Kind of PostPassAnalysis returned by runPass().
  virtual PassAnalysisType getPassAnalysisType() const = 0;
  // Unique identifier of this pass.
  virtual std::string getPassName() const = 0;

  // Optional hooks; the default implementations do nothing and return false.
  virtual bool initializePass(Graph & /*graph*/) { return false; }
  virtual bool finalizePass(Graph & /*graph*/) { return false; }

  // Applies the pass to `graph` in place and returns its analysis.
  virtual std::shared_ptr<PostPassAnalysis> runPass(Graph &graph) = 0;

 protected:
  // Walks the graphs held in the node's attributes and totals how many times
  // `fn` reports a successful transform.
  unsigned int DescendOnGraphAttributesAndCount(
      Node *n, std::function<unsigned int(Graph &)> fn);
  // Same traversal as above, but `fn` returns nothing, so no count is kept.
  void DescendOnGraphAttributesUnconstrained(Node *n,
                                             std::function<void(Graph &)> fn);
};
// Base class for passes that never modify the graph (analysis passes). The
// pass is tagged as PassType::Immutable, PassEfficiency::Complete and
// PassOptimizationType::None. Subclasses must still implement
// getPassAnalysisType(), getPassName() and runPass(), which remain pure
// virtual in Pass.
//
// Pass is inherited publicly, like PredicateBasedPass and FullGraphBasedPass.
// With the default (private) inheritance of `class`, an ImmutablePass could
// not be converted to `Pass *`/`std::shared_ptr<Pass>`, and Pass's public
// getters were not accessible to callers.
class ImmutablePass : public Pass {
 public:
  explicit ImmutablePass()
      : Pass(PassType::Immutable, PassEfficiency::Complete,
             PassOptimizationType::None) {}
  ~ImmutablePass() override;
};
// Analysis produced after a predicate-based pass has run. It records how many
// transforms succeeded plus two flags for the pass's initialization and
// finalization steps (presumably the results of initializePass() /
// finalizePass() -- confirm against the constructor's callers).
struct CountBasedPassAnalysis : PostPassAnalysis {
  // Have to use raw pointer here. The idea is that the pass will pass <this> as
  // a parameter to the constructor. We could use std::enable_shared_from_this
  // but this complicates the memory model. Also since all passes come from
  // GlobalPassRegistry which already utilizes smart pointers we don't have to
  // worry about memory leaks from passes.
  Pass *pass;
  // Number of successful transform applications.
  unsigned int num_positive_transforms;
  bool initialization_done;
  bool finalization_done;

  explicit CountBasedPassAnalysis(Pass *pass,
                                  unsigned int num_positive_transforms,
                                  bool initialization_done,
                                  bool finalization_done);

  // True when at least one transform succeeded.
  bool graphChanged() const {
    return this->num_positive_transforms > 0;
  }
  // Number of successful transforms. This returns the count itself; the
  // former `bool` return type collapsed every non-zero count to `true`.
  unsigned int numSucceededTransforms() const {
    return this->num_positive_transforms;
  }
  // Whether or not a repeated application of the pass might be useful: the
  // graph changed, and the pass is only partially efficient (a second run is
  // not guaranteed to be a no-op).
  bool fixedPointOptimizationNeeded() const {
    return this->graphChanged() &&
           pass->getPassEfficiency() == PassEfficiency::Partial;
  }
};
// A pass that is based on pattern matching. The majority of passes will
// implement this pass. In order for the pass to work the patternMatchPredicate
// function must be implemented witch matches a subgraph to the respective
// optimization pass. Lastly the runTransform method must also be implemented
// which simply implements the pass on any node which passes
// patternMatchPredicate.
class PredicateBasedPass : public Pass {
public:
explicit PredicateBasedPass(PassType pass_type,
PassEfficiency pass_efficiency,
PassOptimizationType pass_optimization_type)
: Pass(pass_type, pass_efficiency, pass_optimization_type) {}
~PredicateBasedPass() override;
virtual bool patternMatchPredicate(Node *node) = 0;
// Run transform is given the current node in the iterator, a reference to the
// current graph as well as a reference describing how to treat the current
// node in the iterator post transform. Run transform is then responsible for
// running the actual transform as well as describing how to treat the
// iterator node. By default the current node will not call destroy. Do not
// internally delete node instead set the correct destroy_current type.
virtual bool runTransform(Node *node, Graph &graph,
NodeDestroyType &destroy_current) = 0;
std::shared_ptr<PostPassAnalysis> runPass(Graph &graph) override;
PassAnalysisType getPassAnalysisType() const override;
static int getOpsetVersion(const Graph &g) {
// this hack is due to `opset_versions_mutable` doesn't have a const version
Graph &mut_g = const_cast<Graph &>(g);
for (const OpSetID &opset : mut_g.opset_versions_mutable()) {
if (opset.domain() == "") {
return opset.version();
}
}
return 0;
}
private:
unsigned int _runPassInternal(Graph &graph);
};
// Most general kind of pass: it is given the whole graph through runPass()
// and nothing more, leaving the traversal strategy entirely to the subclass.
class FullGraphBasedPass : public Pass {
 public:
  explicit FullGraphBasedPass(PassType pass_type,
                              PassEfficiency pass_efficiency,
                              PassOptimizationType pass_optimization_type)
      : Pass(pass_type, pass_efficiency, pass_optimization_type) {}
  ~FullGraphBasedPass() override;
};
// Returns true when BOTH values sit on a graph boundary. A value is on the
// boundary if it is:
//  * one of its owning graph's outputs;
//  * one of that graph's inputs; or
//  * produced by a kCaptured node.
// Such a pair cannot be merged by redirecting uses from one to the other
// while also keeping the input/output names unchanged.
inline bool areTwoValuesBothInputOrOutput(const Value *value1,
                                          const Value *value2) {
  const auto isBoundaryValue = [](const Value *v) -> bool {
    const auto owner = v->owningGraph();

    const auto outs = owner->outputs();
    if (std::find(outs.rbegin(), outs.rend(), v) != outs.rend()) {
      return true;
    }

    if (v->node()->kind() == kCaptured) {
      return true;
    }

    const auto ins = owner->inputs();
    return std::find(ins.rbegin(), ins.rend(), v) != ins.rend();
  };
  return isBoundaryValue(value1) && isBoundaryValue(value2);
}
// Redirects every use of `oldValue` to `newValue`, unless both are graph
// boundary values (see areTwoValuesBothInputOrOutput). Returns whether the
// replacement was performed.
inline bool tryReplacingAllUsesWith(Value *oldValue, Value *newValue) {
  const bool blocked = areTwoValuesBothInputOrOutput(oldValue, newValue);
  if (!blocked) {
    oldValue->replaceAllUsesWith(newValue);
  }
  return !blocked;
}
// Node-level variant. The two nodes must have the same number of outputs
// (asserted). If any corresponding output pair consists of two graph boundary
// values, nothing is changed and false is returned. Otherwise every use of
// oldNode's outputs is redirected to newNode's outputs and true is returned.
inline bool tryReplacingAllUsesWith(Node *oldNode, Node *newNode) {
  ONNX_ASSERT(oldNode->outputs().size() == newNode->outputs().size());
  const auto fromOutputs = oldNode->outputs();
  const auto toOutputs = newNode->outputs();
  for (size_t idx = 0; idx < fromOutputs.size(); ++idx) {
    const Value *from = fromOutputs[idx];
    const Value *to = toOutputs[idx];
    // Bail out before touching anything, so the replacement is all-or-nothing.
    if (areTwoValuesBothInputOrOutput(from, to)) return false;
  }
  oldNode->replaceAllUsesWith(newNode);
  return true;
}
} // namespace optimization
} // namespace ONNX_NAMESPACE