mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-12 20:11:20 +08:00
[Other] Refactor js submodule (#415)
* Refactor js submodule * Remove change-log * Update ocr module * Update ocr-detection module * Update ocr-detection module * Remove change-log
This commit is contained in:
273
examples/application/js/converter/rnn.py
Normal file
273
examples/application/js/converter/rnn.py
Normal file
@@ -0,0 +1,273 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
def splice_rnn_op(model_info, rnn_index):
|
||||
|
||||
global input_shape
|
||||
global weight_0_shape
|
||||
global weight_1_shape
|
||||
global rnn_input_name
|
||||
ops = model_info['ops']
|
||||
vars = model_info['vars']
|
||||
op = ops[rnn_index]
|
||||
|
||||
rnn_input_name = op['inputs']['Input'][0]
|
||||
rnn_output_name = op['outputs']['Out'][0]
|
||||
|
||||
is_bidirec = 2 if op['attrs']['is_bidirec'] else 1
|
||||
num_layers = op['attrs']['num_layers']
|
||||
hidden_size = op['attrs']['hidden_size']
|
||||
layer_num = num_layers * is_bidirec
|
||||
# concat input最大值
|
||||
max_concat_num = 15
|
||||
|
||||
def concat_mul(index, list, num):
|
||||
global rnn_input_name
|
||||
end = len(list)
|
||||
|
||||
if end < max_concat_num:
|
||||
concat_output_name = 'lstm_' + str(index - 1) + '_' + str(num) + '.tmp_concat'
|
||||
# 非最后一层遍历,将concat作为下一层输入
|
||||
if index < is_bidirec * num_layers - 1:
|
||||
rnn_input_name = concat_output_name
|
||||
|
||||
# 最后一层遍历,将rnn_output_name赋给最后一个concat
|
||||
else:
|
||||
concat_output_name = rnn_output_name
|
||||
|
||||
concat_op = {
|
||||
'attrs': {
|
||||
'axis': 0
|
||||
},
|
||||
'inputs': {
|
||||
'X': []
|
||||
},
|
||||
'outputs': {'Out': [concat_output_name]},
|
||||
'type': 'concat'
|
||||
}
|
||||
|
||||
concat_output_shape = 0
|
||||
|
||||
for x in range(0, end):
|
||||
x_input_name = 'lstm_' + str(index - 1) + '_' + str(list[x]) + '.tmp_concat'
|
||||
concat_op['inputs']['X'].append(x_input_name)
|
||||
concat_output_shape += vars[x_input_name]['shape'][0]
|
||||
|
||||
concat_var = {
|
||||
'name': concat_output_name,
|
||||
'persistable': False,
|
||||
'shape': [concat_output_shape, 1, weight_1_shape[1] * 2]
|
||||
}
|
||||
|
||||
ops.append(concat_op)
|
||||
|
||||
if index < is_bidirec * num_layers - 1:
|
||||
vars[concat_output_name] = concat_var
|
||||
return
|
||||
|
||||
# concat新列表
|
||||
new_list = []
|
||||
|
||||
for i in range(0, end, max_concat_num):
|
||||
if i + max_concat_num > end:
|
||||
for n in range(i, end):
|
||||
new_list.append(list[n])
|
||||
break
|
||||
|
||||
concat_output_name = 'lstm_' + str(index - 1) + '_' + str(num) + '.tmp_concat'
|
||||
# concat_list长度为max_concat_num && 最后一层遍历,将rnn_output_name赋给最后一个concat
|
||||
if end == max_concat_num and index == is_bidirec * num_layers - 1:
|
||||
concat_output_name = rnn_output_name
|
||||
|
||||
concat_op = {
|
||||
'attrs': {
|
||||
'axis': 0
|
||||
},
|
||||
'inputs': {
|
||||
'X': []
|
||||
},
|
||||
'outputs': {'Out': [concat_output_name]},
|
||||
'type': 'concat'
|
||||
}
|
||||
|
||||
concat_output_shape = 0
|
||||
|
||||
for x in range(0, max_concat_num):
|
||||
x_input_name = 'lstm_' + str(index - 1) + '_' + str(list[i + x]) + '.tmp_concat'
|
||||
concat_op['inputs']['X'].append(x_input_name)
|
||||
concat_output_shape += vars[x_input_name]['shape'][0]
|
||||
|
||||
concat_var = {
|
||||
'name': concat_output_name,
|
||||
'persistable': False,
|
||||
'shape': [concat_output_shape, 1, weight_1_shape[1] * 2]
|
||||
}
|
||||
|
||||
ops.append(concat_op)
|
||||
vars[concat_output_name] = concat_var
|
||||
new_list.append(num)
|
||||
|
||||
# 若concat_list长度为max_concat_num,在下一次递归时直接修改rnn_input_name,结束递归,num无需+1
|
||||
if end != max_concat_num:
|
||||
num += 1
|
||||
|
||||
concat_mul(index, new_list, num)
|
||||
|
||||
for index in range(layer_num):
|
||||
last_hidden = op['inputs']['PreState'][0]
|
||||
last_cell = op['inputs']['PreState'][1]
|
||||
weight_list_0 = op['inputs']['WeightList'][index * 2]
|
||||
weight_list_1 = op['inputs']['WeightList'][index * 2 + 1]
|
||||
weight_list_2 = op['inputs']['WeightList'][(index + num_layers * is_bidirec) * 2]
|
||||
weight_list_3 = op['inputs']['WeightList'][(index + num_layers * is_bidirec) * 2 + 1]
|
||||
output_name = 'rnn_origin_' + str(index)
|
||||
input_shape = vars[rnn_input_name]['shape']
|
||||
batch = input_shape[0]
|
||||
|
||||
if vars[weight_list_0]:
|
||||
weight_0_shape = vars[weight_list_0]['shape']
|
||||
|
||||
if vars[weight_list_1]:
|
||||
weight_1_shape = vars[weight_list_1]['shape']
|
||||
|
||||
if batch == 0:
|
||||
continue
|
||||
|
||||
origin_op = {
|
||||
'attrs': {
|
||||
'state_axis': index
|
||||
},
|
||||
'inputs': {
|
||||
'Input': [rnn_input_name],
|
||||
'PreState': [last_hidden],
|
||||
'WeightList': [
|
||||
weight_list_0,
|
||||
weight_list_1,
|
||||
weight_list_2,
|
||||
weight_list_3
|
||||
]
|
||||
},
|
||||
'outputs': {'Out': [output_name]},
|
||||
'type': 'rnn_origin'
|
||||
}
|
||||
|
||||
origin_var = {
|
||||
'name': output_name,
|
||||
'persistable': False,
|
||||
'shape': [input_shape[0], input_shape[1], weight_0_shape[0]]
|
||||
}
|
||||
ops.append(origin_op)
|
||||
vars[output_name] = origin_var
|
||||
|
||||
for bat in range(batch):
|
||||
matmul_output_name = 'lstm_' + str(index) + '_' + str(bat) + '.tmp_matmul'
|
||||
cell_output_name = 'lstm_' + str(index) + '_' + str(bat) + '.tmp_c'
|
||||
hidden_output_name = 'lstm_' + str(index) + '_' + str(bat) + '.tmp_h'
|
||||
|
||||
matmul_op = {
|
||||
'attrs': {
|
||||
'input_axis': bat,
|
||||
'state_axis': index if bat == 0 else 0,
|
||||
'batch': batch,
|
||||
'reverse': False if index % 2 == 0 else True
|
||||
},
|
||||
'inputs': {
|
||||
'Input': [output_name],
|
||||
'PreState': [last_hidden],
|
||||
'WeightList': [weight_list_1]
|
||||
},
|
||||
'outputs': {'Out': [matmul_output_name]},
|
||||
'type': 'rnn_matmul'
|
||||
}
|
||||
|
||||
matmul_var = {
|
||||
'name': matmul_output_name,
|
||||
'persistable': False,
|
||||
'shape': [1, 1, weight_0_shape[0]]
|
||||
}
|
||||
|
||||
ops.append(matmul_op)
|
||||
vars[matmul_output_name] = matmul_var
|
||||
|
||||
cell_op = {
|
||||
'attrs': {
|
||||
'state_axis': index if bat == 0 else 0,
|
||||
'hidden_size': hidden_size
|
||||
},
|
||||
'inputs': {
|
||||
'X': [matmul_output_name],
|
||||
'Y': [last_cell]
|
||||
},
|
||||
'outputs': {'Out': [cell_output_name]},
|
||||
'type': 'rnn_cell'
|
||||
}
|
||||
|
||||
cell_var = {
|
||||
'name': cell_output_name,
|
||||
'persistable': False,
|
||||
'shape': [1, 1, weight_1_shape[1]]
|
||||
}
|
||||
|
||||
ops.append(cell_op)
|
||||
vars[cell_output_name] = cell_var
|
||||
|
||||
hidden_op = {
|
||||
'attrs': {
|
||||
'state_axis': index if bat == 0 else 0,
|
||||
'hidden_size': hidden_size
|
||||
},
|
||||
'inputs': {
|
||||
'X': [matmul_output_name],
|
||||
'Y': [last_cell]
|
||||
},
|
||||
'outputs': {'Out': [hidden_output_name]},
|
||||
'type': 'rnn_hidden'
|
||||
}
|
||||
|
||||
hidden_var = {
|
||||
'name': hidden_output_name,
|
||||
'persistable': False,
|
||||
'shape': [1, 1, weight_1_shape[1]]
|
||||
}
|
||||
|
||||
ops.append(hidden_op)
|
||||
vars[hidden_output_name] = hidden_var
|
||||
|
||||
last_hidden = hidden_output_name
|
||||
last_cell = cell_output_name
|
||||
|
||||
# concat
|
||||
if index % 2 == 1:
|
||||
|
||||
concat_list = []
|
||||
concat_num = 0
|
||||
# concat forword and backword
|
||||
for bat in range(batch):
|
||||
x_input_name_0 = 'lstm_' + str(index - 1) + '_' + str(bat) + '.tmp_h'
|
||||
x_input_name_1 = 'lstm_' + str(index) + '_' + str(batch - bat - 1) + '.tmp_h'
|
||||
concat_output_name = 'lstm_' + str(index - 1) + '_' + str(bat) + '.tmp_concat'
|
||||
concat_op = {
|
||||
'attrs': {
|
||||
'axis': 2
|
||||
},
|
||||
'inputs': {
|
||||
'X': [x_input_name_0, x_input_name_1]
|
||||
},
|
||||
'outputs': {'Out': [concat_output_name]},
|
||||
'type': 'concat'
|
||||
}
|
||||
|
||||
concat_var = {
|
||||
'name': concat_output_name,
|
||||
'persistable': False,
|
||||
'shape': [1, 1, weight_1_shape[1] * 2]
|
||||
}
|
||||
ops.append(concat_op)
|
||||
vars[concat_output_name] = concat_var
|
||||
concat_list.append(bat)
|
||||
concat_num += 1
|
||||
|
||||
concat_mul(index, concat_list, concat_num)
|
||||
|
||||
# 删除rnn op
|
||||
del ops[rnn_index]
|
Reference in New Issue
Block a user