mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 17:41:52 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -29,7 +29,7 @@ for i in range(bs):
|
||||
ids_len = seq_lens[i, 0]
|
||||
input_ids[i, 0:ids_len] = np.random.randint(1, 10, seq_lens[i, 0], "int64")
|
||||
|
||||
x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
|
||||
(x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k,) = get_padding_offset(
|
||||
paddle.to_tensor(input_ids),
|
||||
paddle.to_tensor(cum_offset),
|
||||
paddle.to_tensor(token_num),
|
||||
@@ -46,19 +46,14 @@ print("padding_offset:\n", padding_offset)
|
||||
print("cu_seqlens_q:\n", cu_seqlens_q)
|
||||
print("cu_seqlens_k:\n", cu_seqlens_k)
|
||||
|
||||
ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6],
|
||||
"int64")
|
||||
ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64")
|
||||
ref_cum_offsets_out = np.array([0, 6, 13], "int32")
|
||||
ref_padding_offset = np.array([0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13],
|
||||
"int32")
|
||||
ref_padding_offset = np.array([0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13], "int32")
|
||||
ref_cu_seqlens_q = np.array([0, 4, 7, 13], "int32")
|
||||
ref_cu_seqlens_k = np.array([0, 4, 7, 13], "int32")
|
||||
|
||||
assert sum(ref_x_remove_padding -
|
||||
x_remove_padding) == 0, 'Check x_remove_padding failed.'
|
||||
assert sum(ref_cum_offsets_out -
|
||||
cum_offsets_out) == 0, 'Check cum_offsets_out failed.'
|
||||
assert sum(ref_padding_offset -
|
||||
padding_offset) == 0, 'Check padding_offset failed.'
|
||||
assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, 'Check cu_seqlens_q failed.'
|
||||
assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, 'Check cu_seqlens_k failed.'
|
||||
assert sum(ref_x_remove_padding - x_remove_padding) == 0, "Check x_remove_padding failed."
|
||||
assert sum(ref_cum_offsets_out - cum_offsets_out) == 0, "Check cum_offsets_out failed."
|
||||
assert sum(ref_padding_offset - padding_offset) == 0, "Check padding_offset failed."
|
||||
assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, "Check cu_seqlens_q failed."
|
||||
assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, "Check cu_seqlens_k failed."
|
||||
|
@@ -21,10 +21,15 @@ paddle.seed(2023)
|
||||
|
||||
pre_ids = paddle.to_tensor(
|
||||
[[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]],
|
||||
"int64")
|
||||
logits = paddle.to_tensor([[0.1, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.1, 0.1, 0.1],
|
||||
[0.1, 0.9, 0.7, 0.6, 0.5, 0.4, 0.1, 0.1, 0.1, 0.1]],
|
||||
"float32")
|
||||
"int64",
|
||||
)
|
||||
logits = paddle.to_tensor(
|
||||
[
|
||||
[0.1, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.1, 0.1, 0.1],
|
||||
[0.1, 0.9, 0.7, 0.6, 0.5, 0.4, 0.1, 0.1, 0.1, 0.1],
|
||||
],
|
||||
"float32",
|
||||
)
|
||||
penalty_scores = paddle.to_tensor([1.0, 1.0], "float32")
|
||||
frequency_scores = paddle.to_tensor([0.1, 0.1], "float32")
|
||||
presence_scores = paddle.to_tensor([0.0, 0.0], "float32")
|
||||
@@ -88,78 +93,536 @@ ref_logits = np.array(
|
||||
)
|
||||
diff_logits = np.sum(np.abs(ref_logits - logits.numpy()))
|
||||
print("diff_logits\n", diff_logits)
|
||||
assert diff_logits < 1e-6, 'Check failed.'
|
||||
assert diff_logits < 1e-6, "Check failed."
|
||||
|
||||
pre_ids = paddle.to_tensor(
|
||||
[[
|
||||
2, 3, 3, 5, 8, 9, 3, 9, 1, 8, 9, 2, 3, 8, 8, 9, 9, 1, 4, 2, 6, 2, 6, 8,
|
||||
7, 2, 2, 3, 8, 1, 5, 7, 9, 2, 2, 9, 1, 4, 9, 8, 5, 8, 5, 7, 3, 6, 4, 4,
|
||||
9, 9, 8, 5, 5, 2, 2, 9, 4, 8, 1, 9, 6, 9, 2, 2, 7, 2, 2, 9, 4, 6, 4, 6,
|
||||
1, 4, 1, 9, 1, 8, 8, 5, 7, 9, 4, 2, 5, 1, 1, 4, 1, 5, 5, 4, 4, 2, 1, 8,
|
||||
7, 1, 2, 9, 6, 7, 9, 6, 7, 7, 4, 9, 9, 7, 5, 1, 8, 9, 8, 8, 5, 4, 6, 4,
|
||||
7, 5, 5, 7, 6, 9, 3, 9
|
||||
],
|
||||
[
|
||||
7, 8, 1, 3, 1, 7, 6, 3, 5, 3, 8, 3, 1, 9, 7, 1, 1, 9, 5, 4, 9, 6, 1,
|
||||
9, 3, 8, 3, 9, 9, 6, 4, 2, 8, 5, 3, 1, 6, 9, 1, 3, 9, 8, 1, 7, 5, 1,
|
||||
5, 1, 8, 7, 4, 5, 9, 8, 7, 4, 7, 3, 6, 4, 6, 6, 5, 5, 2, 9, 9, 5, 8,
|
||||
8, 4, 8, 2, 8, 1, 3, 9, 1, 8, 5, 8, 3, 8, 8, 2, 7, 3, 7, 5, 7, 2, 6,
|
||||
3, 5, 1, 4, 6, 1, 9, 8, 2, 2, 3, 6, 7, 6, 2, 6, 5, 1, 5, 6, 2, 1, 6,
|
||||
4, 7, 7, 3, 8, 5, 1, 9, 1, 2, 8, 6, 8
|
||||
]])
|
||||
[
|
||||
[
|
||||
2,
|
||||
3,
|
||||
3,
|
||||
5,
|
||||
8,
|
||||
9,
|
||||
3,
|
||||
9,
|
||||
1,
|
||||
8,
|
||||
9,
|
||||
2,
|
||||
3,
|
||||
8,
|
||||
8,
|
||||
9,
|
||||
9,
|
||||
1,
|
||||
4,
|
||||
2,
|
||||
6,
|
||||
2,
|
||||
6,
|
||||
8,
|
||||
7,
|
||||
2,
|
||||
2,
|
||||
3,
|
||||
8,
|
||||
1,
|
||||
5,
|
||||
7,
|
||||
9,
|
||||
2,
|
||||
2,
|
||||
9,
|
||||
1,
|
||||
4,
|
||||
9,
|
||||
8,
|
||||
5,
|
||||
8,
|
||||
5,
|
||||
7,
|
||||
3,
|
||||
6,
|
||||
4,
|
||||
4,
|
||||
9,
|
||||
9,
|
||||
8,
|
||||
5,
|
||||
5,
|
||||
2,
|
||||
2,
|
||||
9,
|
||||
4,
|
||||
8,
|
||||
1,
|
||||
9,
|
||||
6,
|
||||
9,
|
||||
2,
|
||||
2,
|
||||
7,
|
||||
2,
|
||||
2,
|
||||
9,
|
||||
4,
|
||||
6,
|
||||
4,
|
||||
6,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
9,
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
5,
|
||||
7,
|
||||
9,
|
||||
4,
|
||||
2,
|
||||
5,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
5,
|
||||
5,
|
||||
4,
|
||||
4,
|
||||
2,
|
||||
1,
|
||||
8,
|
||||
7,
|
||||
1,
|
||||
2,
|
||||
9,
|
||||
6,
|
||||
7,
|
||||
9,
|
||||
6,
|
||||
7,
|
||||
7,
|
||||
4,
|
||||
9,
|
||||
9,
|
||||
7,
|
||||
5,
|
||||
1,
|
||||
8,
|
||||
9,
|
||||
8,
|
||||
8,
|
||||
5,
|
||||
4,
|
||||
6,
|
||||
4,
|
||||
7,
|
||||
5,
|
||||
5,
|
||||
7,
|
||||
6,
|
||||
9,
|
||||
3,
|
||||
9,
|
||||
],
|
||||
[
|
||||
7,
|
||||
8,
|
||||
1,
|
||||
3,
|
||||
1,
|
||||
7,
|
||||
6,
|
||||
3,
|
||||
5,
|
||||
3,
|
||||
8,
|
||||
3,
|
||||
1,
|
||||
9,
|
||||
7,
|
||||
1,
|
||||
1,
|
||||
9,
|
||||
5,
|
||||
4,
|
||||
9,
|
||||
6,
|
||||
1,
|
||||
9,
|
||||
3,
|
||||
8,
|
||||
3,
|
||||
9,
|
||||
9,
|
||||
6,
|
||||
4,
|
||||
2,
|
||||
8,
|
||||
5,
|
||||
3,
|
||||
1,
|
||||
6,
|
||||
9,
|
||||
1,
|
||||
3,
|
||||
9,
|
||||
8,
|
||||
1,
|
||||
7,
|
||||
5,
|
||||
1,
|
||||
5,
|
||||
1,
|
||||
8,
|
||||
7,
|
||||
4,
|
||||
5,
|
||||
9,
|
||||
8,
|
||||
7,
|
||||
4,
|
||||
7,
|
||||
3,
|
||||
6,
|
||||
4,
|
||||
6,
|
||||
6,
|
||||
5,
|
||||
5,
|
||||
2,
|
||||
9,
|
||||
9,
|
||||
5,
|
||||
8,
|
||||
8,
|
||||
4,
|
||||
8,
|
||||
2,
|
||||
8,
|
||||
1,
|
||||
3,
|
||||
9,
|
||||
1,
|
||||
8,
|
||||
5,
|
||||
8,
|
||||
3,
|
||||
8,
|
||||
8,
|
||||
2,
|
||||
7,
|
||||
3,
|
||||
7,
|
||||
5,
|
||||
7,
|
||||
2,
|
||||
6,
|
||||
3,
|
||||
5,
|
||||
1,
|
||||
4,
|
||||
6,
|
||||
1,
|
||||
9,
|
||||
8,
|
||||
2,
|
||||
2,
|
||||
3,
|
||||
6,
|
||||
7,
|
||||
6,
|
||||
2,
|
||||
6,
|
||||
5,
|
||||
1,
|
||||
5,
|
||||
6,
|
||||
2,
|
||||
1,
|
||||
6,
|
||||
4,
|
||||
7,
|
||||
7,
|
||||
3,
|
||||
8,
|
||||
5,
|
||||
1,
|
||||
9,
|
||||
1,
|
||||
2,
|
||||
8,
|
||||
6,
|
||||
8,
|
||||
],
|
||||
]
|
||||
)
|
||||
logits = paddle.to_tensor(
|
||||
[[
|
||||
0.16274983, 0.61470598, 0.94366980, 0.82005417, 0.50752640, 0.38316748,
|
||||
0.92648441, 0.24050158, 0.05461595, 0.42218581, 0.36270225, 0.15464807,
|
||||
0.13614719, 0.67509544, 0.40315166, 0.10671722, 0.24832056, 0.76091218,
|
||||
0.11598995, 0.10962527, 0.04688513, 0.81536716, 0.72259802, 0.60476679,
|
||||
0.16701800, 0.84160781, 0.79649884, 0.78021604, 0.75329530, 0.98587888,
|
||||
0.13421868, 0.16027625, 0.15269397, 0.06228730, 0.73856270, 0.34721911,
|
||||
0.73683006, 0.78178608, 0.32068327, 0.79906309, 0.44214272, 0.63330448,
|
||||
0.08016958, 0.63367140, 0.19788943, 0.55346787, 0.11142531, 0.90518415,
|
||||
0.21236691, 0.81587470, 0.83752930, 0.70979482, 0.35684183, 0.28715104,
|
||||
0.87162822, 0.17679396, 0.98725849, 0.76129991, 0.04090235, 0.37181064,
|
||||
0.63317049, 0.24689502, 0.21126501, 0.57617670, 0.74346697, 0.40613672,
|
||||
0.56907010, 0.68556929, 0.29032683, 0.17866278, 0.35165095, 0.97015840,
|
||||
0.70785582, 0.54259878, 0.14712237, 0.90483177, 0.02094105, 0.36411613,
|
||||
0.02495066, 0.88874054, 0.88895452, 0.86216462, 0.58062190, 0.95583254,
|
||||
0.20553111, 0.29870346, 0.69652933, 0.36861244, 0.85316223, 0.50240189,
|
||||
0.17566244, 0.61080140, 0.88203174, 0.98675215, 0.24344546, 0.17213407,
|
||||
0.78160852, 0.25165486, 0.48188508, 0.82812423, 0.10199814, 0.90475923,
|
||||
0.66907483, 0.71910626, 0.40660757, 0.59460294, 0.70212913, 0.90841550,
|
||||
0.00329034, 0.11290466, 0.89654654, 0.69114941, 0.29473618, 0.62027222,
|
||||
0.37333879, 0.98911142, 0.46510187, 0.65914583, 0.73022646, 0.12790845,
|
||||
0.12817244, 0.43015456, 0.75011456, 0.43562204, 0.48086026, 0.75587070,
|
||||
0.98481447, 0.77367836
|
||||
],
|
||||
[
|
||||
0.12336024, 0.74152875, 0.09191196, 0.99301219, 0.44764417,
|
||||
0.01848883, 0.78326035, 0.99228370, 0.81447607, 0.02627683,
|
||||
0.51033205, 0.98703283, 0.15247856, 0.77640921, 0.60799915,
|
||||
0.87518770, 0.76818430, 0.86542630, 0.31795895, 0.04829503,
|
||||
0.85567141, 0.30271924, 0.67515039, 0.59728831, 0.78710967,
|
||||
0.75111693, 0.56837374, 0.49085775, 0.91510201, 0.59545547,
|
||||
0.99482232, 0.59036905, 0.58267909, 0.28770933, 0.53237396,
|
||||
0.95318258, 0.93987304, 0.61142951, 0.26737869, 0.52285451,
|
||||
0.03479086, 0.61631846, 0.66777998, 0.15736090, 0.00447258,
|
||||
0.37035006, 0.15281211, 0.95372260, 0.25963321, 0.61036694,
|
||||
0.15020694, 0.19171195, 0.55252832, 0.00391038, 0.31052542,
|
||||
0.96495175, 0.42586124, 0.05630261, 0.99728668, 0.01856293,
|
||||
0.83201504, 0.10701843, 0.56434178, 0.38009524, 0.51095045,
|
||||
0.13202040, 0.07133843, 0.75313550, 0.17111187, 0.80716974,
|
||||
0.00172165, 0.83906764, 0.73240769, 0.85843354, 0.11042888,
|
||||
0.07912333, 0.33689004, 0.22334915, 0.59059596, 0.52789515,
|
||||
0.29831955, 0.39515004, 0.55602801, 0.83818001, 0.05865780,
|
||||
0.25654668, 0.76624149, 0.35190639, 0.04158346, 0.59157544,
|
||||
0.30779791, 0.94609004, 0.10759670, 0.65575141, 0.37828529,
|
||||
0.29571742, 0.76361233, 0.72476572, 0.18568406, 0.85430276,
|
||||
0.02057583, 0.76195669, 0.65507215, 0.69129735, 0.25084621,
|
||||
0.75223947, 0.06064088, 0.20287007, 0.35887691, 0.75043523,
|
||||
0.47575447, 0.40021798, 0.44464844, 0.67975360, 0.40443239,
|
||||
0.71052992, 0.21782248, 0.50568426, 0.89037591, 0.06661721,
|
||||
0.28788096, 0.70773387, 0.42428264, 0.80419677, 0.42710736,
|
||||
0.87317258, 0.88229448, 0.79217333
|
||||
]])
|
||||
[
|
||||
[
|
||||
0.16274983,
|
||||
0.61470598,
|
||||
0.94366980,
|
||||
0.82005417,
|
||||
0.50752640,
|
||||
0.38316748,
|
||||
0.92648441,
|
||||
0.24050158,
|
||||
0.05461595,
|
||||
0.42218581,
|
||||
0.36270225,
|
||||
0.15464807,
|
||||
0.13614719,
|
||||
0.67509544,
|
||||
0.40315166,
|
||||
0.10671722,
|
||||
0.24832056,
|
||||
0.76091218,
|
||||
0.11598995,
|
||||
0.10962527,
|
||||
0.04688513,
|
||||
0.81536716,
|
||||
0.72259802,
|
||||
0.60476679,
|
||||
0.16701800,
|
||||
0.84160781,
|
||||
0.79649884,
|
||||
0.78021604,
|
||||
0.75329530,
|
||||
0.98587888,
|
||||
0.13421868,
|
||||
0.16027625,
|
||||
0.15269397,
|
||||
0.06228730,
|
||||
0.73856270,
|
||||
0.34721911,
|
||||
0.73683006,
|
||||
0.78178608,
|
||||
0.32068327,
|
||||
0.79906309,
|
||||
0.44214272,
|
||||
0.63330448,
|
||||
0.08016958,
|
||||
0.63367140,
|
||||
0.19788943,
|
||||
0.55346787,
|
||||
0.11142531,
|
||||
0.90518415,
|
||||
0.21236691,
|
||||
0.81587470,
|
||||
0.83752930,
|
||||
0.70979482,
|
||||
0.35684183,
|
||||
0.28715104,
|
||||
0.87162822,
|
||||
0.17679396,
|
||||
0.98725849,
|
||||
0.76129991,
|
||||
0.04090235,
|
||||
0.37181064,
|
||||
0.63317049,
|
||||
0.24689502,
|
||||
0.21126501,
|
||||
0.57617670,
|
||||
0.74346697,
|
||||
0.40613672,
|
||||
0.56907010,
|
||||
0.68556929,
|
||||
0.29032683,
|
||||
0.17866278,
|
||||
0.35165095,
|
||||
0.97015840,
|
||||
0.70785582,
|
||||
0.54259878,
|
||||
0.14712237,
|
||||
0.90483177,
|
||||
0.02094105,
|
||||
0.36411613,
|
||||
0.02495066,
|
||||
0.88874054,
|
||||
0.88895452,
|
||||
0.86216462,
|
||||
0.58062190,
|
||||
0.95583254,
|
||||
0.20553111,
|
||||
0.29870346,
|
||||
0.69652933,
|
||||
0.36861244,
|
||||
0.85316223,
|
||||
0.50240189,
|
||||
0.17566244,
|
||||
0.61080140,
|
||||
0.88203174,
|
||||
0.98675215,
|
||||
0.24344546,
|
||||
0.17213407,
|
||||
0.78160852,
|
||||
0.25165486,
|
||||
0.48188508,
|
||||
0.82812423,
|
||||
0.10199814,
|
||||
0.90475923,
|
||||
0.66907483,
|
||||
0.71910626,
|
||||
0.40660757,
|
||||
0.59460294,
|
||||
0.70212913,
|
||||
0.90841550,
|
||||
0.00329034,
|
||||
0.11290466,
|
||||
0.89654654,
|
||||
0.69114941,
|
||||
0.29473618,
|
||||
0.62027222,
|
||||
0.37333879,
|
||||
0.98911142,
|
||||
0.46510187,
|
||||
0.65914583,
|
||||
0.73022646,
|
||||
0.12790845,
|
||||
0.12817244,
|
||||
0.43015456,
|
||||
0.75011456,
|
||||
0.43562204,
|
||||
0.48086026,
|
||||
0.75587070,
|
||||
0.98481447,
|
||||
0.77367836,
|
||||
],
|
||||
[
|
||||
0.12336024,
|
||||
0.74152875,
|
||||
0.09191196,
|
||||
0.99301219,
|
||||
0.44764417,
|
||||
0.01848883,
|
||||
0.78326035,
|
||||
0.99228370,
|
||||
0.81447607,
|
||||
0.02627683,
|
||||
0.51033205,
|
||||
0.98703283,
|
||||
0.15247856,
|
||||
0.77640921,
|
||||
0.60799915,
|
||||
0.87518770,
|
||||
0.76818430,
|
||||
0.86542630,
|
||||
0.31795895,
|
||||
0.04829503,
|
||||
0.85567141,
|
||||
0.30271924,
|
||||
0.67515039,
|
||||
0.59728831,
|
||||
0.78710967,
|
||||
0.75111693,
|
||||
0.56837374,
|
||||
0.49085775,
|
||||
0.91510201,
|
||||
0.59545547,
|
||||
0.99482232,
|
||||
0.59036905,
|
||||
0.58267909,
|
||||
0.28770933,
|
||||
0.53237396,
|
||||
0.95318258,
|
||||
0.93987304,
|
||||
0.61142951,
|
||||
0.26737869,
|
||||
0.52285451,
|
||||
0.03479086,
|
||||
0.61631846,
|
||||
0.66777998,
|
||||
0.15736090,
|
||||
0.00447258,
|
||||
0.37035006,
|
||||
0.15281211,
|
||||
0.95372260,
|
||||
0.25963321,
|
||||
0.61036694,
|
||||
0.15020694,
|
||||
0.19171195,
|
||||
0.55252832,
|
||||
0.00391038,
|
||||
0.31052542,
|
||||
0.96495175,
|
||||
0.42586124,
|
||||
0.05630261,
|
||||
0.99728668,
|
||||
0.01856293,
|
||||
0.83201504,
|
||||
0.10701843,
|
||||
0.56434178,
|
||||
0.38009524,
|
||||
0.51095045,
|
||||
0.13202040,
|
||||
0.07133843,
|
||||
0.75313550,
|
||||
0.17111187,
|
||||
0.80716974,
|
||||
0.00172165,
|
||||
0.83906764,
|
||||
0.73240769,
|
||||
0.85843354,
|
||||
0.11042888,
|
||||
0.07912333,
|
||||
0.33689004,
|
||||
0.22334915,
|
||||
0.59059596,
|
||||
0.52789515,
|
||||
0.29831955,
|
||||
0.39515004,
|
||||
0.55602801,
|
||||
0.83818001,
|
||||
0.05865780,
|
||||
0.25654668,
|
||||
0.76624149,
|
||||
0.35190639,
|
||||
0.04158346,
|
||||
0.59157544,
|
||||
0.30779791,
|
||||
0.94609004,
|
||||
0.10759670,
|
||||
0.65575141,
|
||||
0.37828529,
|
||||
0.29571742,
|
||||
0.76361233,
|
||||
0.72476572,
|
||||
0.18568406,
|
||||
0.85430276,
|
||||
0.02057583,
|
||||
0.76195669,
|
||||
0.65507215,
|
||||
0.69129735,
|
||||
0.25084621,
|
||||
0.75223947,
|
||||
0.06064088,
|
||||
0.20287007,
|
||||
0.35887691,
|
||||
0.75043523,
|
||||
0.47575447,
|
||||
0.40021798,
|
||||
0.44464844,
|
||||
0.67975360,
|
||||
0.40443239,
|
||||
0.71052992,
|
||||
0.21782248,
|
||||
0.50568426,
|
||||
0.89037591,
|
||||
0.06661721,
|
||||
0.28788096,
|
||||
0.70773387,
|
||||
0.42428264,
|
||||
0.80419677,
|
||||
0.42710736,
|
||||
0.87317258,
|
||||
0.88229448,
|
||||
0.79217333,
|
||||
],
|
||||
]
|
||||
)
|
||||
# pre_ids = paddle.to_tensor(np.float32(np.random.random([2, 1024])))
|
||||
# logits = paddle.to_tensor(np.float32(np.random.random([2, 1024])))
|
||||
penalty_scores = paddle.to_tensor([1.0, 1.0], "float32")
|
||||
@@ -195,60 +658,270 @@ print("min_len\n", min_len)
|
||||
print("eos_token_id\n", eos_token_id)
|
||||
|
||||
ref_logits = np.array(
|
||||
[[
|
||||
-10000000000., -10000000000., 1.88733959, 1.64010835, 1.01505280,
|
||||
0.76633495, 1.85296881, 0.48100317, 0.10923190, 0.84437162, 0.72540450,
|
||||
0.30929613, 0.27229437, 1.35019088, 0.80630332, 0.21343444, 0.49664113,
|
||||
1.52182436, 0.23197991, 0.21925054, 0.09377026, 1.63073432, 1.44519603,
|
||||
1.20953357, 0.33403599, 1.68321562, 1.59299767, 1.56043208, 1.50659060,
|
||||
1.97175777, 0.26843736, 0.32055250, 0.30538794, 0.12457460, 1.47712541,
|
||||
0.69443822, 1.47366011, 1.56357217, 0.64136654, 1.59812617, 0.88428545,
|
||||
1.26660895, 0.16033916, 1.26734281, 0.39577886, 1.10693574, 0.22285062,
|
||||
1.81036830, 0.42473382, 1.63174939, 1.67505860, 1.41958964, 0.71368366,
|
||||
0.57430208, 1.74325645, 0.35358793, 1.97451699, 1.52259982, 0.08180470,
|
||||
0.74362129, 1.26634097, 0.49379003, 0.42253003, 1.15235341, 1.48693395,
|
||||
0.81227344, 1.13814020, 1.37113857, 0.58065367, 0.35732555, 0.70330191,
|
||||
1.94031680, 1.41571164, 1.08519757, 0.29424474, 1.80966353, 0.04188210,
|
||||
0.72823226, 0.04990132, 1.77748108, 1.77790904, 1.72432923, 1.16124380,
|
||||
1.91166508, 0.41106221, 0.59740692, 1.39305866, 0.73722488, 1.70632446,
|
||||
1.00480378, 0.35132489, 1.22160280, 1.76406348, 1.97350430, 0.48689091,
|
||||
0.34426814, 1.56321704, 0.50330973, 0.96377015, 1.65624845, 0.20399629,
|
||||
1.80951846, 1.33814967, 1.43821251, 0.81321514, 1.18920588, 1.40425825,
|
||||
1.81683099, 0.00658068, 0.22580932, 1.79309309, 1.38229883, 0.58947235,
|
||||
1.24054444, 0.74667758, 1.97822285, 0.93020374, 1.31829166, 1.46045291,
|
||||
0.25581691, 0.25634488, 0.86030912, 1.50022912, 0.87124407, 0.96172053,
|
||||
1.51174140, 1.96962893, 1.54735672
|
||||
[
|
||||
[
|
||||
-10000000000.0,
|
||||
-10000000000.0,
|
||||
1.88733959,
|
||||
1.64010835,
|
||||
1.01505280,
|
||||
0.76633495,
|
||||
1.85296881,
|
||||
0.48100317,
|
||||
0.10923190,
|
||||
0.84437162,
|
||||
0.72540450,
|
||||
0.30929613,
|
||||
0.27229437,
|
||||
1.35019088,
|
||||
0.80630332,
|
||||
0.21343444,
|
||||
0.49664113,
|
||||
1.52182436,
|
||||
0.23197991,
|
||||
0.21925054,
|
||||
0.09377026,
|
||||
1.63073432,
|
||||
1.44519603,
|
||||
1.20953357,
|
||||
0.33403599,
|
||||
1.68321562,
|
||||
1.59299767,
|
||||
1.56043208,
|
||||
1.50659060,
|
||||
1.97175777,
|
||||
0.26843736,
|
||||
0.32055250,
|
||||
0.30538794,
|
||||
0.12457460,
|
||||
1.47712541,
|
||||
0.69443822,
|
||||
1.47366011,
|
||||
1.56357217,
|
||||
0.64136654,
|
||||
1.59812617,
|
||||
0.88428545,
|
||||
1.26660895,
|
||||
0.16033916,
|
||||
1.26734281,
|
||||
0.39577886,
|
||||
1.10693574,
|
||||
0.22285062,
|
||||
1.81036830,
|
||||
0.42473382,
|
||||
1.63174939,
|
||||
1.67505860,
|
||||
1.41958964,
|
||||
0.71368366,
|
||||
0.57430208,
|
||||
1.74325645,
|
||||
0.35358793,
|
||||
1.97451699,
|
||||
1.52259982,
|
||||
0.08180470,
|
||||
0.74362129,
|
||||
1.26634097,
|
||||
0.49379003,
|
||||
0.42253003,
|
||||
1.15235341,
|
||||
1.48693395,
|
||||
0.81227344,
|
||||
1.13814020,
|
||||
1.37113857,
|
||||
0.58065367,
|
||||
0.35732555,
|
||||
0.70330191,
|
||||
1.94031680,
|
||||
1.41571164,
|
||||
1.08519757,
|
||||
0.29424474,
|
||||
1.80966353,
|
||||
0.04188210,
|
||||
0.72823226,
|
||||
0.04990132,
|
||||
1.77748108,
|
||||
1.77790904,
|
||||
1.72432923,
|
||||
1.16124380,
|
||||
1.91166508,
|
||||
0.41106221,
|
||||
0.59740692,
|
||||
1.39305866,
|
||||
0.73722488,
|
||||
1.70632446,
|
||||
1.00480378,
|
||||
0.35132489,
|
||||
1.22160280,
|
||||
1.76406348,
|
||||
1.97350430,
|
||||
0.48689091,
|
||||
0.34426814,
|
||||
1.56321704,
|
||||
0.50330973,
|
||||
0.96377015,
|
||||
1.65624845,
|
||||
0.20399629,
|
||||
1.80951846,
|
||||
1.33814967,
|
||||
1.43821251,
|
||||
0.81321514,
|
||||
1.18920588,
|
||||
1.40425825,
|
||||
1.81683099,
|
||||
0.00658068,
|
||||
0.22580932,
|
||||
1.79309309,
|
||||
1.38229883,
|
||||
0.58947235,
|
||||
1.24054444,
|
||||
0.74667758,
|
||||
1.97822285,
|
||||
0.93020374,
|
||||
1.31829166,
|
||||
1.46045291,
|
||||
0.25581691,
|
||||
0.25634488,
|
||||
0.86030912,
|
||||
1.50022912,
|
||||
0.87124407,
|
||||
0.96172053,
|
||||
1.51174140,
|
||||
1.96962893,
|
||||
1.54735672,
|
||||
],
|
||||
[
|
||||
-10000000000.0,
|
||||
-10000000000.0,
|
||||
-40000.0,
|
||||
3.97204876,
|
||||
1.79057670,
|
||||
0.07395532,
|
||||
3.13304138,
|
||||
3.96913481,
|
||||
3.25790429,
|
||||
-40000.0,
|
||||
2.04132819,
|
||||
3.94813132,
|
||||
0.60991424,
|
||||
3.10563684,
|
||||
2.43199658,
|
||||
3.50075078,
|
||||
3.07273722,
|
||||
3.46170521,
|
||||
1.27183580,
|
||||
0.19318011,
|
||||
3.42268562,
|
||||
1.21087694,
|
||||
2.70060158,
|
||||
2.38915324,
|
||||
3.14843869,
|
||||
3.00446773,
|
||||
2.27349496,
|
||||
1.96343100,
|
||||
3.66040802,
|
||||
2.38182187,
|
||||
3.97928929,
|
||||
2.36147618,
|
||||
2.33071637,
|
||||
1.15083730,
|
||||
2.12949586,
|
||||
3.81273031,
|
||||
3.75949216,
|
||||
2.44571805,
|
||||
1.06951475,
|
||||
2.09141803,
|
||||
0.13916343,
|
||||
2.46527386,
|
||||
2.67111993,
|
||||
0.62944359,
|
||||
0.01789032,
|
||||
1.48140025,
|
||||
0.61124843,
|
||||
3.81489038,
|
||||
1.03853285,
|
||||
2.44146776,
|
||||
0.60082775,
|
||||
0.76684779,
|
||||
2.21011329,
|
||||
0.01564152,
|
||||
1.24210167,
|
||||
3.85980701,
|
||||
1.70344496,
|
||||
0.22521044,
|
||||
3.98914671,
|
||||
0.07425172,
|
||||
3.32806015,
|
||||
0.42807373,
|
||||
2.25736713,
|
||||
1.52038097,
|
||||
2.04380178,
|
||||
0.52808160,
|
||||
0.28535372,
|
||||
3.01254201,
|
||||
0.68444747,
|
||||
3.22867894,
|
||||
0.00688660,
|
||||
3.35627055,
|
||||
2.92963076,
|
||||
3.43373418,
|
||||
0.44171551,
|
||||
0.31649333,
|
||||
1.34756017,
|
||||
0.89339662,
|
||||
2.36238384,
|
||||
2.11158061,
|
||||
1.19327819,
|
||||
1.58060014,
|
||||
2.22411203,
|
||||
3.35272002,
|
||||
0.23463120,
|
||||
1.02618670,
|
||||
3.06496596,
|
||||
1.40762556,
|
||||
0.16633384,
|
||||
2.36630177,
|
||||
1.23119164,
|
||||
3.78436017,
|
||||
0.43038681,
|
||||
2.62300563,
|
||||
1.51314116,
|
||||
1.18286967,
|
||||
3.05444932,
|
||||
2.89906287,
|
||||
0.74273622,
|
||||
3.41721106,
|
||||
0.08230332,
|
||||
3.04782677,
|
||||
2.62028861,
|
||||
2.76518941,
|
||||
1.00338483,
|
||||
3.00895786,
|
||||
0.24256352,
|
||||
0.81148028,
|
||||
1.43550766,
|
||||
3.00174093,
|
||||
1.90301788,
|
||||
1.60087192,
|
||||
1.77859378,
|
||||
2.71901441,
|
||||
1.61772954,
|
||||
2.84211969,
|
||||
0.87128991,
|
||||
2.02273703,
|
||||
3.56150365,
|
||||
0.26646885,
|
||||
1.15152383,
|
||||
2.83093548,
|
||||
1.69713056,
|
||||
3.21678710,
|
||||
1.70842946,
|
||||
3.49269032,
|
||||
3.52917790,
|
||||
3.16869330,
|
||||
],
|
||||
],
|
||||
[
|
||||
-10000000000., -10000000000., -40000., 3.97204876, 1.79057670,
|
||||
0.07395532, 3.13304138, 3.96913481, 3.25790429, -40000., 2.04132819,
|
||||
3.94813132, 0.60991424, 3.10563684, 2.43199658, 3.50075078,
|
||||
3.07273722, 3.46170521, 1.27183580, 0.19318011, 3.42268562,
|
||||
1.21087694, 2.70060158, 2.38915324, 3.14843869, 3.00446773,
|
||||
2.27349496, 1.96343100, 3.66040802, 2.38182187, 3.97928929,
|
||||
2.36147618, 2.33071637, 1.15083730, 2.12949586, 3.81273031,
|
||||
3.75949216, 2.44571805, 1.06951475, 2.09141803, 0.13916343,
|
||||
2.46527386, 2.67111993, 0.62944359, 0.01789032, 1.48140025,
|
||||
0.61124843, 3.81489038, 1.03853285, 2.44146776, 0.60082775,
|
||||
0.76684779, 2.21011329, 0.01564152, 1.24210167, 3.85980701,
|
||||
1.70344496, 0.22521044, 3.98914671, 0.07425172, 3.32806015,
|
||||
0.42807373, 2.25736713, 1.52038097, 2.04380178, 0.52808160,
|
||||
0.28535372, 3.01254201, 0.68444747, 3.22867894, 0.00688660,
|
||||
3.35627055, 2.92963076, 3.43373418, 0.44171551, 0.31649333,
|
||||
1.34756017, 0.89339662, 2.36238384, 2.11158061, 1.19327819,
|
||||
1.58060014, 2.22411203, 3.35272002, 0.23463120, 1.02618670,
|
||||
3.06496596, 1.40762556, 0.16633384, 2.36630177, 1.23119164,
|
||||
3.78436017, 0.43038681, 2.62300563, 1.51314116, 1.18286967,
|
||||
3.05444932, 2.89906287, 0.74273622, 3.41721106, 0.08230332,
|
||||
3.04782677, 2.62028861, 2.76518941, 1.00338483, 3.00895786,
|
||||
0.24256352, 0.81148028, 1.43550766, 3.00174093, 1.90301788,
|
||||
1.60087192, 1.77859378, 2.71901441, 1.61772954, 2.84211969,
|
||||
0.87128991, 2.02273703, 3.56150365, 0.26646885, 1.15152383,
|
||||
2.83093548, 1.69713056, 3.21678710, 1.70842946, 3.49269032,
|
||||
3.52917790, 3.16869330
|
||||
]],
|
||||
"float32",
|
||||
)
|
||||
diff_logits = np.sum(np.abs(ref_logits - logits.numpy()))
|
||||
print("diff_logits\n", diff_logits)
|
||||
assert diff_logits < 1e-6, 'Check failed.'
|
||||
assert diff_logits < 1e-6, "Check failed."
|
||||
|
@@ -21,19 +21,30 @@ paddle.seed(2023)
|
||||
|
||||
pre_ids_all = paddle.to_tensor(
|
||||
[[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]],
|
||||
"int64")
|
||||
input_ids = paddle.to_tensor([[1, 9, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1],
|
||||
[1, 9, 7, 6, 5, 4, -1, -1, -1, -1, -1, -1, -1]],
|
||||
"int64")
|
||||
"int64",
|
||||
)
|
||||
input_ids = paddle.to_tensor(
|
||||
[
|
||||
[1, 9, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1],
|
||||
[1, 9, 7, 6, 5, 4, -1, -1, -1, -1, -1, -1, -1],
|
||||
],
|
||||
"int64",
|
||||
)
|
||||
seq_lens_this_time = paddle.to_tensor([1, 1], "int32")
|
||||
seq_lens_encoder = paddle.to_tensor([1, 1], "int32")
|
||||
seq_lens_decoder = paddle.to_tensor([1, 1], "int32")
|
||||
step_idx = paddle.to_tensor([1, 1], "int64")
|
||||
stop_flags = paddle.to_tensor([0, 1], "bool")
|
||||
print("pre_ids_all\n", pre_ids_all)
|
||||
set_value_by_flags_and_idx(pre_ids_all, input_ids, seq_lens_this_time,
|
||||
seq_lens_encoder, seq_lens_decoder, step_idx,
|
||||
stop_flags)
|
||||
set_value_by_flags_and_idx(
|
||||
pre_ids_all,
|
||||
input_ids,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
stop_flags,
|
||||
)
|
||||
print("pre_ids_all\n", pre_ids_all)
|
||||
print("input_ids\n", input_ids)
|
||||
print("seq_lens_this_time\n", seq_lens_this_time)
|
||||
@@ -73,4 +84,4 @@ ref_pre_ids_all = np.array(
|
||||
)
|
||||
diff_pre_ids_all = np.sum(np.abs(ref_pre_ids_all - pre_ids_all.numpy()))
|
||||
print("diff_pre_ids_all\n", diff_pre_ids_all)
|
||||
assert diff_pre_ids_all == 0, 'Check failed.'
|
||||
assert diff_pre_ids_all == 0, "Check failed."
|
||||
|
@@ -41,10 +41,7 @@ step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64")
|
||||
max_block_num = block_bs * max_seq_len // block_size
|
||||
free_list_len = int(max_block_num * (1 - block_ratio))
|
||||
free_list_len = np.full([1], free_list_len, "int32")
|
||||
free_list = np.arange(max_block_num - 1,
|
||||
max_block_num - free_list_len - 1,
|
||||
-1,
|
||||
dtype="int32")
|
||||
free_list = np.arange(max_block_num - 1, max_block_num - free_list_len - 1, -1, dtype="int32")
|
||||
|
||||
encoder_block_lens = np.zeros([max_bs], "int32")
|
||||
used_list_len = np.zeros([max_bs], "int32")
|
||||
@@ -53,19 +50,15 @@ encoder_block_id = 0
|
||||
for i in range(bs):
|
||||
enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size
|
||||
encoder_block_lens[i] = enc_block_num
|
||||
dec_block_num = (seq_lens_decoder[i] + block_size -
|
||||
1) // block_size - enc_block_num
|
||||
dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num
|
||||
used_list_len[i] = dec_block_num
|
||||
block_tables[i, :enc_block_num] = np.arange(
|
||||
encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
|
||||
block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
|
||||
encoder_block_id += enc_block_num
|
||||
if dec_block_num > 0:
|
||||
block_tables[
|
||||
i, enc_block_num:enc_block_num +
|
||||
dec_block_num] = free_list[free_list_len[0] - 1 -
|
||||
dec_block_num:free_list_len[0] - 1]
|
||||
free_list[free_list_len[0] - 1 - dec_block_num:free_list_len[0] -
|
||||
1] = -1
|
||||
block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[
|
||||
free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1
|
||||
]
|
||||
free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1
|
||||
free_list_len[0] -= dec_block_num
|
||||
assert free_list_len[0] >= 0
|
||||
|
||||
@@ -137,13 +130,32 @@ first_token_ids = paddle.to_tensor(first_token_ids)
|
||||
# print("step_idx: ", step_idx)
|
||||
# print("next_tokens: ", next_tokens)
|
||||
|
||||
step_paddle(stop_flags, seq_lens_this_time, ori_seq_lens_encoder,
|
||||
seq_lens_encoder, seq_lens_decoder, block_tables,
|
||||
encoder_block_lens, is_block_step, step_block_list, step_lens,
|
||||
recover_block_list, recover_lens, need_block_list, need_block_len,
|
||||
used_list_len, free_list, free_list_len, input_ids, pre_ids,
|
||||
step_idx, next_tokens, first_token_ids, block_size,
|
||||
encoder_decoder_block_num)
|
||||
step_paddle(
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
ori_seq_lens_encoder,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
encoder_block_lens,
|
||||
is_block_step,
|
||||
step_block_list,
|
||||
step_lens,
|
||||
recover_block_list,
|
||||
recover_lens,
|
||||
need_block_list,
|
||||
need_block_len,
|
||||
used_list_len,
|
||||
free_list,
|
||||
free_list_len,
|
||||
input_ids,
|
||||
pre_ids,
|
||||
step_idx,
|
||||
next_tokens,
|
||||
first_token_ids,
|
||||
block_size,
|
||||
encoder_decoder_block_num,
|
||||
)
|
||||
|
||||
print("-" * 50 + "after step op" + "-" * 50)
|
||||
print("stop_flags: ", stop_flags)
|
||||
|
@@ -30,8 +30,7 @@ end_ids = paddle.to_tensor([0, 1, 2, 3, 4, 5], "int64")
|
||||
print("topk_ids\n", topk_ids)
|
||||
print("next_tokens\n", next_tokens)
|
||||
print("stop_flags\n", stop_flags)
|
||||
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens,
|
||||
False)
|
||||
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens, False)
|
||||
print("topk_ids\n", topk_ids)
|
||||
print("next_tokens\n", next_tokens)
|
||||
print("stop_flags\n", stop_flags)
|
||||
@@ -40,44 +39,220 @@ print("end_ids\n", end_ids)
|
||||
|
||||
ref_topk_ids = np.array(
|
||||
[
|
||||
0, 0, 2, 3, -1, 0, 0, 0, 0, 9, 10, 0, 12, 0, -1, 15, 16, 0, 18, 19, 20,
|
||||
0, 22, 23, 0, 25, 26, 27, -1, 29, 30, 31, 0, 0, 0, -1, -1, 37, 38, 39,
|
||||
-1, -1, 0, 0, 0, 0, 46, -1, 0, 49, 50, 0, 52, 53, 0, -1, 0, 57, -1, 59,
|
||||
60, 0, 0, 63
|
||||
0,
|
||||
0,
|
||||
2,
|
||||
3,
|
||||
-1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
9,
|
||||
10,
|
||||
0,
|
||||
12,
|
||||
0,
|
||||
-1,
|
||||
15,
|
||||
16,
|
||||
0,
|
||||
18,
|
||||
19,
|
||||
20,
|
||||
0,
|
||||
22,
|
||||
23,
|
||||
0,
|
||||
25,
|
||||
26,
|
||||
27,
|
||||
-1,
|
||||
29,
|
||||
30,
|
||||
31,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
-1,
|
||||
-1,
|
||||
37,
|
||||
38,
|
||||
39,
|
||||
-1,
|
||||
-1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
46,
|
||||
-1,
|
||||
0,
|
||||
49,
|
||||
50,
|
||||
0,
|
||||
52,
|
||||
53,
|
||||
0,
|
||||
-1,
|
||||
0,
|
||||
57,
|
||||
-1,
|
||||
59,
|
||||
60,
|
||||
0,
|
||||
0,
|
||||
63,
|
||||
],
|
||||
"int64",
|
||||
)
|
||||
ref_next_tokens = np.array(
|
||||
[
|
||||
0, 0, 2, 3, 0, 0, 0, 0, 0, 9, 10, 0, 12, 0, 0, 15, 16, 0, 18, 19, 20,
|
||||
0, 22, 23, 0, 25, 26, 27, 0, 29, 30, 31, 0, 0, 0, 0, 0, 37, 38, 39, 0,
|
||||
0, 0, 0, 0, 0, 46, 0, 0, 49, 50, 0, 52, 53, 0, 0, 0, 57, 0, 59, 60, 0,
|
||||
0, 63
|
||||
0,
|
||||
0,
|
||||
2,
|
||||
3,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
9,
|
||||
10,
|
||||
0,
|
||||
12,
|
||||
0,
|
||||
0,
|
||||
15,
|
||||
16,
|
||||
0,
|
||||
18,
|
||||
19,
|
||||
20,
|
||||
0,
|
||||
22,
|
||||
23,
|
||||
0,
|
||||
25,
|
||||
26,
|
||||
27,
|
||||
0,
|
||||
29,
|
||||
30,
|
||||
31,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
37,
|
||||
38,
|
||||
39,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
46,
|
||||
0,
|
||||
0,
|
||||
49,
|
||||
50,
|
||||
0,
|
||||
52,
|
||||
53,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
57,
|
||||
0,
|
||||
59,
|
||||
60,
|
||||
0,
|
||||
0,
|
||||
63,
|
||||
],
|
||||
"int64",
|
||||
)
|
||||
ref_stop_flags = np.array(
|
||||
[
|
||||
True, True, True, True, True, True, True, True, True, False, False,
|
||||
True, False, True, True, False, False, True, False, False, False, True,
|
||||
False, False, True, False, False, False, True, False, False, False,
|
||||
True, True, True, True, True, False, False, False, True, True, True,
|
||||
True, True, True, False, True, True, False, False, True, False, False,
|
||||
True, True, True, False, True, False, False, True, True, False
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
],
|
||||
"bool",
|
||||
)
|
||||
diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy()))
|
||||
print("diff_topk_ids\n", diff_topk_ids)
|
||||
assert diff_topk_ids == 0, 'Check failed.'
|
||||
assert diff_topk_ids == 0, "Check failed."
|
||||
diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy()))
|
||||
print("diff_next_tokens\n", diff_next_tokens)
|
||||
assert diff_next_tokens == 0, 'Check failed.'
|
||||
diff_stop_flags = np.sum(
|
||||
np.abs(
|
||||
ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
|
||||
assert diff_next_tokens == 0, "Check failed."
|
||||
diff_stop_flags = np.sum(np.abs(ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
|
||||
print("diff_stop_flags\n", diff_stop_flags)
|
||||
assert diff_stop_flags == 0, 'Check failed.'
|
||||
assert diff_stop_flags == 0, "Check failed."
|
||||
|
||||
# test beam_search=True
|
||||
topk_ids = paddle.arange(0, bs, dtype="int64")
|
||||
@@ -88,8 +263,7 @@ end_ids = paddle.to_tensor([0, 1, 2, 3, 4, 5], "int64")
|
||||
print("topk_ids\n", topk_ids)
|
||||
print("next_tokens\n", next_tokens)
|
||||
print("stop_flags\n", stop_flags)
|
||||
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens,
|
||||
True)
|
||||
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens, True)
|
||||
print("topk_ids\n", topk_ids)
|
||||
print("next_tokens\n", next_tokens)
|
||||
print("stop_flags\n", stop_flags)
|
||||
@@ -98,42 +272,217 @@ print("end_ids\n", end_ids)
|
||||
|
||||
ref_topk_ids = np.array(
|
||||
[
|
||||
0, 1, 2, 3, 4, 0, 6, 7, -1, 9, 10, 0, -1, 13, 14, 15, 0, 17, 18, 19,
|
||||
20, 0, 22, 23, 24, 25, -1, -1, 28, 29, 0, 0, -1, 33, 34, 35, 36, 37, 0,
|
||||
-1, 0, 41, -1, 0, 44, 45, 46, 0, 0, 49, 0, 0, 0, 53, 0, 0, 0, 0, 58,
|
||||
-1, 60, 61, -1, 63
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
0,
|
||||
6,
|
||||
7,
|
||||
-1,
|
||||
9,
|
||||
10,
|
||||
0,
|
||||
-1,
|
||||
13,
|
||||
14,
|
||||
15,
|
||||
0,
|
||||
17,
|
||||
18,
|
||||
19,
|
||||
20,
|
||||
0,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25,
|
||||
-1,
|
||||
-1,
|
||||
28,
|
||||
29,
|
||||
0,
|
||||
0,
|
||||
-1,
|
||||
33,
|
||||
34,
|
||||
35,
|
||||
36,
|
||||
37,
|
||||
0,
|
||||
-1,
|
||||
0,
|
||||
41,
|
||||
-1,
|
||||
0,
|
||||
44,
|
||||
45,
|
||||
46,
|
||||
0,
|
||||
0,
|
||||
49,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
53,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
58,
|
||||
-1,
|
||||
60,
|
||||
61,
|
||||
-1,
|
||||
63,
|
||||
],
|
||||
"int64",
|
||||
)
|
||||
ref_next_tokens = np.array(
|
||||
[
|
||||
0, 1, 2, 3, 4, 0, 6, 7, 0, 9, 10, 0, 0, 13, 14, 15, 0, 17, 18, 19, 20,
|
||||
0, 22, 23, 24, 25, 0, 0, 28, 29, 0, 0, 0, 33, 34, 35, 36, 37, 0, 0, 0,
|
||||
41, 0, 0, 44, 45, 46, 0, 0, 49, 0, 0, 0, 53, 0, 0, 0, 0, 58, 0, 60, 61,
|
||||
0, 63
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
0,
|
||||
6,
|
||||
7,
|
||||
0,
|
||||
9,
|
||||
10,
|
||||
0,
|
||||
0,
|
||||
13,
|
||||
14,
|
||||
15,
|
||||
0,
|
||||
17,
|
||||
18,
|
||||
19,
|
||||
20,
|
||||
0,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25,
|
||||
0,
|
||||
0,
|
||||
28,
|
||||
29,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
33,
|
||||
34,
|
||||
35,
|
||||
36,
|
||||
37,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
41,
|
||||
0,
|
||||
0,
|
||||
44,
|
||||
45,
|
||||
46,
|
||||
0,
|
||||
0,
|
||||
49,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
53,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
58,
|
||||
0,
|
||||
60,
|
||||
61,
|
||||
0,
|
||||
63,
|
||||
],
|
||||
"int64",
|
||||
)
|
||||
ref_stop_flags = np.array(
|
||||
[
|
||||
False, False, False, False, False, True, False, False, True, False,
|
||||
False, True, True, False, False, False, True, False, False, False,
|
||||
False, True, False, False, False, False, True, True, False, False,
|
||||
True, True, True, False, False, False, False, False, True, True, True,
|
||||
False, True, True, False, False, False, True, True, False, True, True,
|
||||
True, False, True, True, True, True, False, True, False, False, True,
|
||||
False
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
],
|
||||
"bool",
|
||||
)
|
||||
diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy()))
|
||||
print("diff_topk_ids\n", diff_topk_ids)
|
||||
assert diff_topk_ids == 0, 'Check failed.'
|
||||
assert diff_topk_ids == 0, "Check failed."
|
||||
diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy()))
|
||||
print("diff_next_tokens\n", diff_next_tokens)
|
||||
assert diff_next_tokens == 0, 'Check failed.'
|
||||
diff_stop_flags = np.sum(
|
||||
np.abs(
|
||||
ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
|
||||
assert diff_next_tokens == 0, "Check failed."
|
||||
diff_stop_flags = np.sum(np.abs(ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
|
||||
print("diff_stop_flags\n", diff_stop_flags)
|
||||
assert diff_stop_flags == 0, 'Check failed.'
|
||||
assert diff_stop_flags == 0, "Check failed."
|
||||
|
@@ -60,9 +60,17 @@ print("stop_nums:\n", stop_nums)
|
||||
print("next_tokens:\n", next_tokens)
|
||||
print("is_block_step:\n", is_block_step)
|
||||
|
||||
update_inputs(stop_flags, not_need_stop, seq_lens_this_time, seq_lens_encoder,
|
||||
seq_lens_decoder, input_ids, stop_nums, next_tokens,
|
||||
is_block_step)
|
||||
update_inputs(
|
||||
stop_flags,
|
||||
not_need_stop,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
input_ids,
|
||||
stop_nums,
|
||||
next_tokens,
|
||||
is_block_step,
|
||||
)
|
||||
|
||||
print("-" * 50)
|
||||
print("stop_flags:\n", stop_flags)
|
||||
@@ -75,32 +83,269 @@ print("stop_nums:\n", stop_nums)
|
||||
print("next_tokens:\n", next_tokens)
|
||||
|
||||
ref_not_need_stop_out = np.array([True])
|
||||
ref_seq_lens_this_time_out = np.array([
|
||||
0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1,
|
||||
0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1
|
||||
], "int32")
|
||||
ref_seq_lens_encoder_out = np.array([
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
|
||||
], "int32")
|
||||
ref_seq_lens_decoder_out = np.array([
|
||||
0, 0, 2, 0, 0, 6, 0, 8, 8, 10, 0, 12, 12, 0, 0, 0, 0, 0, 0, 0, 20, 22, 0,
|
||||
24, 24, 0, 26, 28, 0, 0, 0, 32, 32, 0, 34, 0, 0, 38, 0, 40, 0, 0, 42, 0, 0,
|
||||
46, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
|
||||
], "int32")
|
||||
input_ids_np[:, 0] = np.array([
|
||||
6, 5, 9, 8, 6, 2, 8, 1, 3, 1, 3, 6, 9, 8, 1, 9, 1, 8, 8, 6, 7, 6, 5, 3, 5,
|
||||
9, 3, 6, 3, 9, 8, 8, 8, 8, 4, 8, 7, 4, 2, 3, 5, 8, 4, 2, 5, 6, 8, 9, 6, 7,
|
||||
4, 2, 4, 6, 2, 3, 4, 9, 7, 2, 1, 8, 7, 8
|
||||
], "int64")
|
||||
ref_seq_lens_this_time_out = np.array(
|
||||
[
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
],
|
||||
"int32",
|
||||
)
|
||||
ref_seq_lens_encoder_out = np.array(
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
],
|
||||
"int32",
|
||||
)
|
||||
ref_seq_lens_decoder_out = np.array(
|
||||
[
|
||||
0,
|
||||
0,
|
||||
2,
|
||||
0,
|
||||
0,
|
||||
6,
|
||||
0,
|
||||
8,
|
||||
8,
|
||||
10,
|
||||
0,
|
||||
12,
|
||||
12,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
20,
|
||||
22,
|
||||
0,
|
||||
24,
|
||||
24,
|
||||
0,
|
||||
26,
|
||||
28,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
32,
|
||||
32,
|
||||
0,
|
||||
34,
|
||||
0,
|
||||
0,
|
||||
38,
|
||||
0,
|
||||
40,
|
||||
0,
|
||||
0,
|
||||
42,
|
||||
0,
|
||||
0,
|
||||
46,
|
||||
46,
|
||||
48,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
],
|
||||
"int32",
|
||||
)
|
||||
input_ids_np[:, 0] = np.array(
|
||||
[
|
||||
6,
|
||||
5,
|
||||
9,
|
||||
8,
|
||||
6,
|
||||
2,
|
||||
8,
|
||||
1,
|
||||
3,
|
||||
1,
|
||||
3,
|
||||
6,
|
||||
9,
|
||||
8,
|
||||
1,
|
||||
9,
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
6,
|
||||
7,
|
||||
6,
|
||||
5,
|
||||
3,
|
||||
5,
|
||||
9,
|
||||
3,
|
||||
6,
|
||||
3,
|
||||
9,
|
||||
8,
|
||||
8,
|
||||
8,
|
||||
8,
|
||||
4,
|
||||
8,
|
||||
7,
|
||||
4,
|
||||
2,
|
||||
3,
|
||||
5,
|
||||
8,
|
||||
4,
|
||||
2,
|
||||
5,
|
||||
6,
|
||||
8,
|
||||
9,
|
||||
6,
|
||||
7,
|
||||
4,
|
||||
2,
|
||||
4,
|
||||
6,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
9,
|
||||
7,
|
||||
2,
|
||||
1,
|
||||
8,
|
||||
7,
|
||||
8,
|
||||
],
|
||||
"int64",
|
||||
)
|
||||
|
||||
assert not_need_stop.numpy(
|
||||
) == ref_not_need_stop_out, 'Check not_need_stop failed.'
|
||||
assert np.all(seq_lens_this_time.numpy() ==
|
||||
ref_seq_lens_this_time_out), 'Check seq_lens_this_time failed.'
|
||||
assert np.all(seq_lens_encoder.numpy() ==
|
||||
ref_seq_lens_encoder_out), 'Check seq_lens_encoder failed.'
|
||||
assert np.all(seq_lens_decoder.numpy() ==
|
||||
ref_seq_lens_decoder_out), 'Check seq_lens_decoder failed.'
|
||||
assert np.all(input_ids.numpy() == input_ids_np), 'Check input_ids failed.'
|
||||
assert not_need_stop.numpy() == ref_not_need_stop_out, "Check not_need_stop failed."
|
||||
assert np.all(seq_lens_this_time.numpy() == ref_seq_lens_this_time_out), "Check seq_lens_this_time failed."
|
||||
assert np.all(seq_lens_encoder.numpy() == ref_seq_lens_encoder_out), "Check seq_lens_encoder failed."
|
||||
assert np.all(seq_lens_decoder.numpy() == ref_seq_lens_decoder_out), "Check seq_lens_decoder failed."
|
||||
assert np.all(input_ids.numpy() == input_ids_np), "Check input_ids failed."
|
||||
|
@@ -29,16 +29,15 @@ def np_quant_weight_int4(weight_np):
|
||||
weight = np.transpose(weight_np, [1, 0]) # n,k
|
||||
max_value = np.max(np.abs(weight), axis=1).reshape(-1, 1) # k => k,1
|
||||
quanted_weight = np_clip_and_round(weight / max_value * 7.0, 7) # n,k
|
||||
quanted_weight = (quanted_weight[:, 1::2] & 0xF) << 4 | (
|
||||
quanted_weight[:, ::2] & 0xF) # pack int4, [n,k//2]
|
||||
quanted_weight = (quanted_weight[:, 1::2] & 0xF) << 4 | (quanted_weight[:, ::2] & 0xF) # pack int4, [n,k//2]
|
||||
weight_scales = (max_value).astype(weight_np.dtype).reshape(-1)
|
||||
return quanted_weight, weight_scales.astype(np.float32)
|
||||
|
||||
|
||||
def np_quant_weight(weight_np, algo='weight_only_int8'):
|
||||
def np_quant_weight(weight_np, algo="weight_only_int8"):
|
||||
assert weight_np.dtype == np.float32
|
||||
|
||||
if algo == 'weight_only_int4':
|
||||
if algo == "weight_only_int4":
|
||||
return np_quant_weight_int4(weight_np)
|
||||
|
||||
weight = np.transpose(weight_np, [1, 0])
|
||||
@@ -56,7 +55,7 @@ def int8_to_bin_np(value):
|
||||
def int8_to_bin(value):
|
||||
if not -128 <= value <= 127:
|
||||
raise ValueError("int8 值必须在 -128 到 127 之间")
|
||||
return format(value & 0xFF, '08b') # '08b' 表示 8 位二进制,高位补零
|
||||
return format(value & 0xFF, "08b") # '08b' 表示 8 位二进制,高位补零
|
||||
|
||||
|
||||
# 1) preparation
|
||||
@@ -70,7 +69,7 @@ w_np = (np.random.random((k, n)).astype(np.float32) - 0.5) * 10
|
||||
qw_np, wscale_np = np_quant_weight(w_np, algo)
|
||||
|
||||
# 3) xpu calculation
|
||||
dtype = 'float32'
|
||||
dtype = "float32"
|
||||
x_pd = paddle.to_tensor(w_np, dtype=dtype)
|
||||
qw_pd, wscale_pd = weight_quantize_xpu(x_pd, algo, -1, -1)
|
||||
qw_pd_trans = paddle.transpose(qw_pd, [1, 0])
|
||||
@@ -83,12 +82,7 @@ qw_pd_trans = paddle.transpose(qw_pd, [1, 0])
|
||||
# comparation
|
||||
print(f"wscale_pd, mean={wscale_pd.mean()}, std={wscale_pd.std()}")
|
||||
print(f"wscale_np, mean={wscale_np.mean()}, std={wscale_np.std()}")
|
||||
print(
|
||||
f"qw_np, mean={qw_np.astype(np.float32).mean()}, std={qw_np.astype(np.float32).std()}"
|
||||
)
|
||||
print(
|
||||
f"qw_pd_trans, mean={qw_pd_trans.astype('float32').mean()}, std={qw_pd_trans.astype('float32').std()}"
|
||||
)
|
||||
sum_diff = np.sum(
|
||||
np.abs(qw_pd_trans.astype("float32").numpy() - qw_np.astype("float32")))
|
||||
print(f"qw_np, mean={qw_np.astype(np.float32).mean()}, std={qw_np.astype(np.float32).std()}")
|
||||
print(f"qw_pd_trans, mean={qw_pd_trans.astype('float32').mean()}, std={qw_pd_trans.astype('float32').std()}")
|
||||
sum_diff = np.sum(np.abs(qw_pd_trans.astype("float32").numpy() - qw_np.astype("float32")))
|
||||
print(f"sum_diff: {sum_diff}")
|
||||
|
Reference in New Issue
Block a user