Files
FastDeploy/paddle2onnx/legacy/op_mapper/op_mapper.py
Jason 6343b0db47 [Build] Support build with source code of Paddle2ONNX (#1559)
* Add notes for tensors

* Optimize some apis

* move some warnings

* Support build with Paddle2ONNX

* Add protobuf support

* Fix compile on mac

* add clean package script

* Add paddle2onnx code

* remove submodule

* Add onnx code

* remove softlink

* add onnx code

* fix error

* Add cmake file

* fix patchelf

* update paddle2onnx

* Delete .gitmodules

---------

Co-authored-by: PaddleCI <paddle_ci@example.com>
Co-authored-by: pangyoki <pangyoki@126.com>
Co-authored-by: jiangjiajun <jiangjiajun@baidu.com>
2023-03-17 10:03:22 +08:00

306 lines
13 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import inspect
import six
import numpy as np
import paddle
from paddle import fluid
from paddle.fluid import layers
from paddle2onnx.legacy.graph import graph_helper, PaddleGraph
from paddle2onnx.utils import logging
from paddle2onnx.legacy.constant.op_mapping_status import *
REGISTER_CUSTOM_PADDLE_OP = {}
def get_max_support_version(versions, opset_version):
max_version = -1
for vs in sorted(versions):
if vs <= opset_version:
max_version = vs
return max_version
def register_op_mapper(paddle_op, mapper_obj):
paddle_op_list = []
if isinstance(paddle_op, six.string_types):
paddle_op_list.append(paddle_op)
elif isinstance(paddle_op, list):
paddle_op_list = paddle_op
else:
raise ValueError('paddle_op must be List or string, but got type {}.'.
format(type(paddle_op)))
if not isinstance(mapper_obj, six.class_types):
raise ValueError('mapper_obj must be Class, but got type {}.'.format(
type(mapper_obj)))
valid_register_func = 0
for k, v in inspect.getmembers(mapper_obj, inspect.ismethod):
if k.startswith("opset_"):
version = int(k.replace("opset_", ""))
if version > 13 or version < 1:
raise Exception(
'the specific method of operator mapper must be named opset_[number](1<=number<=13), such as opset_9, but got {}.'.
format(k))
valid_register_func += 1
if valid_register_func == 0:
raise Exception(
'the specific method of operator mapper must be classmethod, which named opset_[number](1<=number<=13), such as opset_9, but none achieved.'
)
mapper = OpMapper(paddle_op_list)
mapper(mapper_obj)
class OpMapper(object):
OPSETS = {}
REGISTER_CUSTOM_PADDLE_OP = {}
def __init__(self, paddle_op, **kwargs):
if not isinstance(paddle_op, list):
paddle_op = [paddle_op]
self.paddle_op = paddle_op
self.kwargs = kwargs
def __call__(self, cls):
for k, v in inspect.getmembers(cls, inspect.ismethod):
if k.startswith("opset_"):
version = int(k.replace("opset_", ""))
for op in self.paddle_op:
if op not in OpMapper.OPSETS:
OpMapper.OPSETS[op] = {}
opset_dict = OpMapper.OPSETS[op]
opset_dict[version] = (v, self.kwargs)
@staticmethod
def mapping(graph, node, operator_export_type="ONNX"):
try:
if node.type in OpMapper.REGISTER_CUSTOM_PADDLE_OP:
if operator_export_type in ["PaddleFallback"]:
opsets = OpMapper.OPSETS[node.type]
versions = list(opsets.keys())
convert_version = get_max_support_version(
versions, graph.opset_version)
mapper_func, kw = opsets[convert_version]
mapper_func(graph, node, **kw)
else:
custom_paddle_op = OpMapper.REGISTER_CUSTOM_PADDLE_OP[
node.type](node)
custom_paddle_graph, output_results = custom_paddle_op.get_paddle_graph(
)
OpMapper.check_support_status(custom_paddle_graph.node_map,
graph.opset_version)
graph.build_op_nodes(custom_paddle_graph.node_map)
node_output_results = dict()
for k in node.output_names:
custom_outs = output_results[k]
node_outs = node.output(k)
assert len(custom_outs) == len(
node_outs
), "Length of custom implementation operator's outputs is not same with the length of original operator's outputs."
for i in range(len(custom_outs)):
graph.make_node(
"Identity",
inputs=[custom_outs[i]],
outputs=[node_outs[i]])
else:
opsets = OpMapper.OPSETS[node.type]
versions = list(opsets.keys())
convert_version = get_max_support_version(versions,
graph.opset_version)
mapper_func, kw = opsets[convert_version]
mapper_func(graph, node, **kw)
except Exception as e:
raise Exception(
"Error happened when mapping node ['{}'] to onnx, which op_type is '{}' with inputs: {} and outputs: {}, specific error: ".
format(node.layer_name, node.type, node.inputs,
node.outputs) + str(e))
@staticmethod
def get_recommend_opset_version(node_map, opset_version):
recommend_opset_version = OpMapper.check_support_status(
node_map, opset_version, True)
for name, node in list(node_map.items()):
if node.type in OpMapper.REGISTER_CUSTOM_PADDLE_OP: #如果是custom的op获取custom的推荐op
custom_paddle_op = OpMapper.REGISTER_CUSTOM_PADDLE_OP[
node.type](node)
custom_paddle_graph, output_results = custom_paddle_op.get_paddle_graph(
)
custom_recommend_opset_version = OpMapper.check_support_status(
custom_paddle_graph.node_map, opset_version, True)
recommend_opset_version = max(recommend_opset_version,
custom_recommend_opset_version)
if opset_version != recommend_opset_version:
warning_info = "\n======================\n"
warning_info += "\nFor a successful conversion, set the recommended opset version : {}\n".format(
recommend_opset_version)
warning_info += "\n======================\n"
logging.warning(warning_info)
return recommend_opset_version
@staticmethod
def check_support_status(node_map, opset_version, for_check=False):
op_mapping_status = {
OP_MAPPING_NO_REGISTER: [],
OP_MAPPING_NO_VERSION: [],
}
for name, node in list(node_map.items()):
if node.type in OpMapper.REGISTER_CUSTOM_PADDLE_OP:
continue
if node.type not in OpMapper.OPSETS:
op_mapping_status[OP_MAPPING_NO_REGISTER].append(node)
else:
opsets = OpMapper.OPSETS[node.type]
versions = list(opsets.keys())
convert_version = get_max_support_version(versions,
opset_version)
if convert_version == -1:
op_mapping_status[OP_MAPPING_NO_VERSION].append(node)
if len(op_mapping_status[OP_MAPPING_NO_REGISTER]) > 0:
unsupported_op_types = set([
node.type for node in op_mapping_status[OP_MAPPING_NO_REGISTER]
])
error_info = "\nThere's {} ops are not supported yet\n".format(
len(unsupported_op_types))
for op_type in unsupported_op_types:
error_info += "=========== {} ===========\n".format(op_type)
raise NotImplementedError(error_info)
if len(op_mapping_status[OP_MAPPING_NO_VERSION]) > 0:
unsupported_op_types = set([
node.type for node in op_mapping_status[OP_MAPPING_NO_VERSION]
])
recommend_opset_version = -1
for op_type in unsupported_op_types:
opsets = OpMapper.OPSETS[op_type]
if min(opsets.keys()) > recommend_opset_version:
recommend_opset_version = min(opsets.keys())
warning_info = "\nThere are {} ops that are not supported in opset version {}, please set opset version >= {}.\n".format(
len(unsupported_op_types), opset_version,
recommend_opset_version)
for op_type in unsupported_op_types:
warning_info += "=========== {} ===========\n".format(op_type)
if for_check:
logging.warning(warning_info)
return recommend_opset_version
raise NotImplementedError(warning_info)
return opset_version
class CustomPaddleOp(object):
CREATE_TIMES = {}
def __init__(self, node):
self.main_program = paddle.static.Program()
self.startup_program = paddle.static.Program()
self.inputs = self.create_place_holder(node)
self.node = node
def generate_scope_name(self, node):
if node.type in CustomPaddleOp.CREATE_TIMES:
CustomPaddleOp.CREATE_TIMES[node.type] += 1
else:
CustomPaddleOp.CREATE_TIMES[node.type] = 1
scope_prefix = node.type + str(CustomPaddleOp.CREATE_TIMES[node.type] -
1) + '_'
return scope_prefix
def create_place_holder(self, node):
place_holders = {}
with paddle.static.program_guard(self.main_program,
self.startup_program):
for arg_name, idxs in node.inputs.items():
place_holders[arg_name] = []
for idx in range(len(idxs)):
shape = node.input_shape(arg_name, idx)
dtype = node.input_dtype(arg_name, idx)
name = node.input(arg_name, idx)
data = paddle.static.data(
name=name, shape=shape, dtype=dtype)
place_holders[arg_name].append(data)
return place_holders
def input(self, name, idx=None):
if name not in self.inputs:
return None
if idx is None:
return self.inputs[name]
if len(self.inputs[name]) <= idx:
return None
return self.inputs[name][idx]
def get_paddle_graph(self):
scope_prefix = self.generate_scope_name(self.node)
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
with paddle.static.program_guard(self.main_program,
self.startup_program):
with paddle.utils.unique_name.guard(scope_prefix):
res = self.forward()
feed_var_names = [
var.name for vars in self.inputs.values()
for var in vars
]
fetch_vars = [var for vars in res.values() for var in vars]
inference_program = graph_helper.get_program(
self.main_program, feed_var_names, fetch_vars)
paddle_graph = PaddleGraph.build_from_program(
inference_program,
feed_var_names,
fetch_vars,
scope=scope)
output_results = dict()
for arg_name, outs in res.items():
output_results[arg_name] = [out.name for out in outs]
return paddle_graph, output_results
def register_custom_paddle_op(paddle_op, custom_op):
paddle_op_list = []
if isinstance(paddle_op, six.string_types):
paddle_op_list.append(paddle_op)
elif isinstance(paddle_op, list):
paddle_op_list = paddle_op
else:
raise ValueError("paddle_op' must be List or string, but got type {}.".
format(type(paddle_op)))
if not isinstance(custom_op, six.class_types):
raise ValueError("'custom_op' must be Class, but got type {}.".format(
type(custom_op)))
forward = getattr(custom_op, "forward", None)
if not callable(forward):
raise Exception(
"Custom paddle operators must be implemented in function named 'forward'."
)
for op in paddle_op_list:
if op not in OpMapper.REGISTER_CUSTOM_PADDLE_OP:
OpMapper.REGISTER_CUSTOM_PADDLE_OP[op] = custom_op