Files
FastDeploy/examples/application/js/converter/fuseOps.py
chenqianhe f2619b0546 [Other] Refactor js submodule (#415)
* Refactor js submodule

* Remove change-log

* Update ocr module

* Update ocr-detection module

* Update ocr-detection module

* Remove change-log
2022-10-23 14:05:13 +08:00

76 lines
2.3 KiB
Python

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
def opListFuse(ops):
    """Fuse activation-like ops into the op that precedes them (operator fusion).

    ``ops`` is walked from the last op to the first. An op whose ``type`` is
    listed in ``fuseOpList`` is folded into the op immediately before it when
    all of these hold:

    * the predecessor has exactly one ``outputs['Out']`` variable,
    * that variable appears in this op's ``inputs['X']``,
    * no other input reference anywhere in ``ops`` reads that variable, and
    * folding would not create a duplicate key in ``fuse_opt``.

    Folding removes the op from ``ops``, stores its ``attrs`` in the
    predecessor's ``attrs['fuse_opt'][<op type>]``, hands over any
    ``fuse_opt`` the removed op had accumulated (so chains fuse transitively),
    and gives the predecessor the removed op's ``outputs['Out']``.

    Ops that carry a Paddle Lite style ``act_type`` attribute additionally get
    a ``fuse_opt`` entry named after that activation. For ``hard_swish`` the
    entry holds the values of ``hard_swish_offset`` / ``hard_swish_scale`` /
    ``hard_swish_threshold``. For ``relu6`` it holds ``fuse_brelu_threshold``
    as ``threshold``. Otherwise it is an empty dict.

    Args:
        ops: list of op dicts with ``type``, ``inputs``, ``outputs`` and
            ``attrs`` keys. ``inputs``/``outputs`` are expected to map slot
            names to lists of variable names (assumed from how ``X`` and
            ``Out`` are used here; TODO confirm with the caller).
            The list and its dicts are modified in place.

    Returns:
        None.
    """
    fuseOpList = [
        'relu',
        'relu6',
        'leaky_relu',
        'scale',
        'sigmoid',
        'hard_sigmoid',
        'pow',
        'sqrt',
        'tanh',
        'hard_swish',
        'dropout',
    ]

    def varNames(value):
        """Normalize an inputs/outputs slot value to a list of variable names."""
        if isinstance(value, str):
            return [value]
        return list(value) if value else []

    def opExistSingleNode(opName):
        """Return True if ``opName`` is read by exactly one input reference in ``ops``.

        Every input slot is counted, not only ``X``. A variable that is also
        read through another slot (e.g. the ``Y`` input of an elementwise op)
        must keep its name, so it must not be fused away.
        """
        if not opName:
            return False
        nodeNum = sum(
            varNames(names).count(opName)
            for op in ops
            for names in op['inputs'].values()
        )
        return nodeNum == 1

    for index in reversed(range(len(ops))):
        op = ops[index]
        attrs = op['attrs']

        # Paddle Lite compatibility: translate `act_type` into a fuse_opt
        # entry. This applies to every op, including the first one. Merge
        # rather than replace: ops after this one may already have been fused
        # into this op's fuse_opt earlier in the reverse walk.
        if 'act_type' in attrs:
            name = attrs['act_type']
            actOpt = {}
            if name == 'hard_swish':
                actOpt['offset'] = attrs['hard_swish_offset']
                actOpt['scale'] = attrs['hard_swish_scale']
                actOpt['threshold'] = attrs['hard_swish_threshold']
            if name == 'relu6':
                actOpt['threshold'] = attrs['fuse_brelu_threshold']
            attrs.setdefault('fuse_opt', {})[name] = actOpt

        # The first op has no predecessor to fuse into.
        if index == 0 or op['type'] not in fuseOpList:
            continue

        prevOp = ops[index - 1]
        prevOuts = varNames(prevOp['outputs'].get('Out'))
        # Check the length before indexing, so an empty or missing output
        # list skips fusion instead of raising.
        if len(prevOuts) != 1:
            continue
        prevOut = prevOuts[0]
        # The predecessor's output must actually feed this op, because ops
        # from parallel branches can be adjacent. It must also have no other
        # reader, since its name disappears once the two ops are merged.
        if prevOut not in varNames(op['inputs'].get('X')):
            continue
        if not opExistSingleNode(prevOut):
            continue

        # fuse_opt is keyed by op type, so a repeated type would overwrite an
        # earlier entry and silently drop an op (e.g. scale -> scale). The
        # predecessor's own act_type is added to the same dict when it is
        # visited next, so it must not collide either.
        pendingOpt = attrs.get('fuse_opt', {})
        takenKeys = set(pendingOpt) | {op['type']}
        if op['type'] in pendingOpt or prevOp['attrs'].get('act_type') in takenKeys:
            continue

        fuseOpt = attrs.pop('fuse_opt', {})
        fuseOpt[op['type']] = attrs
        prevOp['attrs']['fuse_opt'] = fuseOpt
        prevOp['outputs']['Out'] = op['outputs']['Out']
        del ops[index]