Files
FastDeploy/paddle2onnx/legacy/graph/graph.py
Jason 6343b0db47 [Build] Support build with source code of Paddle2ONNX (#1559)
* Add notes for tensors

* Optimize some apis

* move some warnings

* Support build with Paddle2ONNX

* Add protobuf support

* Fix compile on mac

* add clean package script

* Add paddle2onnx code

* remove submodule

* Add onnx code

* remove softlink

* add onnx code

* fix error

* Add cmake file

* fix patchelf

* update paddle2onnx

* Delete .gitmodules

---------

Co-authored-by: PaddleCI <paddle_ci@example.com>
Co-authored-by: pangyoki <pangyoki@126.com>
Co-authored-by: jiangjiajun <jiangjiajun@baidu.lcom>
2023-03-17 10:03:22 +08:00

288 lines
9.3 KiB
Python
Executable File

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import os
import copy
import six
import collections
from paddle2onnx.legacy.constant import NodeDomain
class Node(object):
    """A single operator node of a graph.

    Inputs and outputs are stored as lists; any ``Node`` given as an input or
    output is replaced by its ``layer_name``. Hashing and equality are based
    on ``layer_name`` only.

    Attributes:
        domain: Domain tag of the operator (``NodeDomain.RAW`` by default).
        type: Operator type passed as ``op_type``.
        attrs: Mapping of attribute names to values.
        layer_name: Name of the node.
        inputs: List of input names.
        outputs: List of output names.
    """

    def __init__(self,
                 op_type,
                 inputs,
                 outputs,
                 attrs,
                 layer_name,
                 domain=NodeDomain.RAW):
        """Create a node.

        Args:
            op_type: Operator type.
            inputs: A list (of names and/or ``Node`` objects), a single
                string, or a single ``Node``.
            outputs: Same accepted forms as ``inputs``.
            attrs: Mapping of attribute names to values.
            layer_name: Name of this node.
            domain: Domain tag of the operator.

        Raises:
            TypeError: If ``inputs`` or ``outputs`` has an unsupported type.
        """
        self.domain = domain
        self.type = op_type
        self.attrs = attrs
        self.layer_name = layer_name
        self.set_inputs(inputs)
        self.set_outputs(outputs)

    def __hash__(self):
        return hash(self.layer_name)

    def __eq__(self, other):
        # Objects without a ``layer_name`` cannot be compared by name; let
        # Python fall back to its default handling instead of raising
        # AttributeError (e.g. during ``node in mixed_container``).
        if not hasattr(other, 'layer_name'):
            return NotImplemented
        return self.layer_name == other.layer_name

    def __str__(self):
        """Return a one-line textual description of the node."""
        attrs = ''.join(', ' + key + '=' + str(value)
                        for key, value in self.attrs.items())
        return " {} = {}::{}(inputs={}{}) \n".format(
            self.outputs, self.domain, self.type, self.inputs, attrs)

    def input(self, idx=None):
        """Return all inputs, or the input at position ``idx``."""
        if idx is None:
            return self.inputs
        return self.inputs[idx]

    def output(self, idx=None):
        """Return all outputs, or the output at position ``idx``."""
        if idx is None:
            return self.outputs
        return self.outputs[idx]

    def attr(self, name):
        """Return the attribute called ``name``, or ``None`` if absent."""
        return self.attrs.get(name)

    @staticmethod
    def _as_name_list(values, role):
        """Normalise ``values`` to a list, replacing nodes by their names.

        Args:
            values: A list, a string, or a ``Node``.
            role: ``'Inputs'`` or ``'Outputs'``; only used in the error text.

        Raises:
            TypeError: If ``values`` is none of the accepted types.
        """
        if isinstance(values, list):
            return [
                item.layer_name if isinstance(item, Node) else item
                for item in values
            ]
        if isinstance(values, six.string_types):
            return [values]
        if isinstance(values, Node):
            return [values.layer_name]
        raise TypeError(
            '{} of node must be type: list, Node, or String but got {}'.
            format(role, type(values)))

    def set_inputs(self, inputs):
        """Replace the node inputs (list, string, or ``Node``)."""
        self.inputs = self._as_name_list(inputs, 'Inputs')

    def set_outputs(self, outputs):
        """Replace the node outputs (list, string, or ``Node``)."""
        self.outputs = self._as_name_list(outputs, 'Outputs')
class Graph(object):
    """Container of named nodes plus graph-level inputs, outputs and parameters.

    Attributes:
        parameters: Dict of graph parameters.
        node_map: Ordered mapping from ``layer_name`` to node.
        input_nodes: List of graph input nodes.
        output_nodes: List of graph output nodes.
        op_type_count: Number of names generated so far per operator type.
    """

    def __init__(self):
        self.parameters = {}
        self.node_map = collections.OrderedDict()
        self.input_nodes = list()
        self.output_nodes = list()
        self.op_type_count = dict()

    def _identity(self):
        """Return ``self.id`` if it was set, else the object identity.

        ``__init__`` does not define ``id``; without the fallback, hashing or
        comparing a plain ``Graph`` raised AttributeError.
        """
        return getattr(self, 'id', id(self))

    def __hash__(self):
        return hash(self._identity())

    def __eq__(self, other):
        if isinstance(other, Graph):
            return self._identity() == other._identity()
        if hasattr(other, 'id'):
            return self._identity() == other.id
        return NotImplemented

    def __str__(self):
        """Return a textual dump of inputs, outputs and all nodes."""
        pieces = ['graph { \n']
        pieces.extend(" input: {} \n".format(node.layer_name)
                      for node in self.input_nodes)
        pieces.extend(" output: {} \n \n".format(node.layer_name)
                      for node in self.output_nodes)
        pieces.extend(str(node) for node in self.node_map.values())
        pieces.append(' }')
        return ''.join(pieces)

    def set_output_nodes(self, node_list):
        """Replace the list of output nodes.

        Raises:
            TypeError: If ``node_list`` is not a list.
        """
        if not isinstance(node_list, list):
            raise TypeError(
                'output_nodes of Graph must be type: list, but got {}'.format(
                    type(node_list)))
        self.output_nodes = node_list

    def set_node_map(self, node_map):
        """Replace the node map and regenerate the topological order.

        Raises:
            TypeError: If ``node_map`` is not a dict.
        """
        if not isinstance(node_map, dict):
            # The message previously claimed "list" although a dict is
            # required.
            raise TypeError('node_map of Graph must be type: dict, but got {}'.
                            format(type(node_map)))
        self.node_map = node_map
        # NOTE(review): generate_topo_sort is not defined in this class;
        # presumably a subclass provides it -- confirm before calling this
        # on a plain Graph.
        self.generate_topo_sort()

    def set_input_nodes(self, node_list):
        """Replace the list of input nodes.

        Raises:
            TypeError: If ``node_list`` is not a list.
        """
        if not isinstance(node_list, list):
            raise TypeError(
                'input_nodes of Graph must be type: list, but got {}'.format(
                    type(node_list)))
        self.input_nodes = node_list

    def set_parameters(self, parameters):
        """Replace the parameter dict.

        Raises:
            TypeError: If ``parameters`` is not a dict.
        """
        if not isinstance(parameters, dict):
            raise TypeError(
                'parameters of Graph must be type: dict, but got {}'.format(
                    type(parameters)))
        self.parameters = parameters

    def generate_node_name(self, op_type):
        """Return a fresh name ``<op_type>_<k>`` with ``k`` counting from 0."""
        count = self.op_type_count.get(op_type, 0)
        self.op_type_count[op_type] = count + 1
        # layer_name needs to follow
        # https://github.com/onnx/onnx/blob/master/docs/OpConventions.md
        return op_type + '_' + str(count)

    def insert_node(self, node):
        """Register ``node`` under its name; 'feed'/'fetch' nodes are skipped."""
        if node.type not in ['feed', 'fetch']:
            self.node_map[node.layer_name] = node

    def make_node(self,
                  op_type,
                  inputs=None,
                  outputs=None,
                  attrs=None,
                  layer_name=None,
                  domain=None,
                  **kw):
        """Create a ``Node``, insert it into the graph and return it.

        Args:
            op_type: Operator type of the new node.
            inputs: Node inputs; defaults to an empty list.
            outputs: Node outputs; defaults to ``[layer_name]``.
            attrs: Optional attribute mapping. It is copied, so the caller's
                dict is not modified.
            layer_name: Node name; generated from ``op_type`` when omitted.
            domain: Operator domain; when ``None`` the default of ``Node``
                is used.
            **kw: Extra attributes merged on top of ``attrs``.
        """
        if layer_name is None:
            layer_name = self.generate_node_name(op_type)

        merged_attrs = {} if attrs is None else dict(attrs)
        merged_attrs.update(kw)

        if inputs is None:
            inputs = []
        if outputs is None:
            outputs = [layer_name]

        # Keyword arguments guard against the positional mix-up that used to
        # pass layer_name as inputs and the attrs dict as layer_name.
        node_kwargs = dict(
            op_type=op_type,
            inputs=inputs,
            outputs=outputs,
            attrs=merged_attrs,
            layer_name=layer_name)
        if domain is not None:
            node_kwargs['domain'] = domain
        node = Node(**node_kwargs)
        self.insert_node(node)
        return node

    def update_node(self,
                    node,
                    op_type=None,
                    inputs=None,
                    outputs=None,
                    attrs=None,
                    block=None,
                    move_to_end=True,
                    domain=None,
                    **kw):
        """Update fields of ``node`` in place and return it.

        ``op_type``, ``inputs``, ``outputs`` and ``domain`` are only applied
        when not ``None``. The node attributes are always replaced by
        ``attrs`` merged with ``**kw`` (an empty dict when neither is given).
        ``block`` is accepted but not used here.

        Args:
            move_to_end: When true, the node is re-inserted so that it becomes
                the last entry of ``node_map``.

        Raises:
            KeyError: If ``move_to_end`` is true and the node is not
                registered in ``node_map``.
        """
        if op_type is not None:
            node.type = op_type
        if inputs is not None:
            node.set_inputs(inputs)
        if outputs is not None:
            node.set_outputs(outputs)

        # Merge into a copy so the caller's dict is left untouched.
        merged_attrs = {} if attrs is None else dict(attrs)
        merged_attrs.update(kw)
        node.attrs = merged_attrs

        if domain is not None:
            node.domain = domain
        if move_to_end:
            self.node_map.pop(node.layer_name)
            self.node_map[node.layer_name] = node
        return node

    def get_node(self, name, copy=False):
        """Return the node registered as ``name``.

        Args:
            name: Name of the node.
            copy: When true, return a shallow copy instead of the stored node.

        Raises:
            TypeError: If no node called ``name`` exists.
        """
        if name not in self.node_map:
            raise TypeError('Node with name:{} not in graph'.format(name))
        node = self.node_map[name]
        if copy:
            # The ``copy`` parameter shadows the ``copy`` module inside this
            # method, so the module is imported under another name.
            import copy as copy_module
            return copy_module.copy(node)
        return node

    def remove_node_by_name(self, name):
        """Remove and return the node registered as ``name``.

        Raises:
            TypeError: If no node called ``name`` exists.
        """
        if name not in self.node_map:
            raise TypeError('Node with name:{} not in graph'.format(name))
        return self.node_map.pop(name)

    def remove_node(self, node):
        """Remove a node given either the ``Node`` itself or its name."""
        name = node.layer_name if isinstance(node, Node) else node
        return self.remove_node_by_name(name)

    def get_output_nodes_of_node(self, node):
        """Return the ``edge_map`` entry of ``node``, or ``[]`` if it has none.

        ``edge_map`` is not initialised by this class; when the attribute is
        missing it is treated as empty.

        Raises:
            TypeError: From ``get_node`` when the node is not in the graph.
            KeyError: If ``get_node`` returns a falsy value for the node.
        """
        edge_map = getattr(self, 'edge_map', {})
        if node in edge_map:
            return edge_map[node]
        if self.get_node(node.layer_name, copy=False):
            return []
        raise KeyError('Node with layer_name {} not in graph.edge_map'.
                       format(node.layer_name))

    def get_adjacency_map(self):
        """Map each node to the set of nodes consuming one of its outputs.

        Nodes whose outputs are not consumed by any other node do not appear
        as keys.
        """
        # Index producers by output name once instead of rescanning every
        # node for every input.
        producers = {}
        for node in self.node_map.values():
            for out_name in node.outputs:
                producers.setdefault(out_name, []).append(node)

        adjacency_map = {}
        for consumer in self.node_map.values():
            for in_name in consumer.inputs:
                for producer in producers.get(in_name, ()):
                    if consumer == producer:
                        continue
                    adjacency_map.setdefault(producer, set()).add(consumer)
        return adjacency_map

    def get_topo_sort_list(self):
        """Return the nodes ordered so that each node precedes its consumers.

        The list is seeded with nodes that have no consumers; a node is
        appended once all of its consumers have been placed, and the result is
        reversed. Nodes whose consumer set never becomes empty are omitted.
        """
        adjacency_map = self.get_adjacency_map()
        topo_sort_list = [
            node for node in self.node_map.values()
            if node not in adjacency_map
        ]
        idx = 0
        while idx < len(topo_sort_list):
            current_node = topo_sort_list[idx]
            for input_node, output_nodes in adjacency_map.items():
                if current_node in output_nodes:
                    output_nodes.remove(current_node)
                    if len(output_nodes) == 0:
                        topo_sort_list.append(input_node)
            idx += 1
        return topo_sort_list[::-1]