mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 09:31:35 +08:00
[Feature][XPU] add custom kernels for mtp (#3537)
This commit is contained in:
351
custom_ops/xpu_ops/test/test_update_inputs.py
Normal file
351
custom_ops/xpu_ops/test/test_update_inputs.py
Normal file
@@ -0,0 +1,351 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.xpu import update_inputs
|
||||
|
||||
np.random.seed(2023)
|
||||
|
||||
bs = 48
|
||||
max_bs = 64
|
||||
max_input_length = 6144
|
||||
|
||||
stop_flags = np.random.randint(0, 2, max_bs).astype("bool")
|
||||
not_need_stop = np.array([1], "bool")
|
||||
seq_lens_this_time = np.zeros([bs], "int32")
|
||||
seq_lens_encoder = np.zeros([max_bs], "int32")
|
||||
seq_lens_decoder = np.zeros([max_bs], "int32")
|
||||
for i in range(bs):
|
||||
if i % 2 == 0:
|
||||
seq_lens_encoder[i] = i
|
||||
seq_lens_this_time[i] = i
|
||||
else:
|
||||
seq_lens_decoder[i] = i
|
||||
seq_lens_this_time[i] = 1
|
||||
input_ids_np = np.random.randint(1, 10, [max_bs, max_input_length], "int64")
|
||||
stop_nums = np.array([max_bs], "int64")
|
||||
next_tokens = np.random.randint(1, 10, [max_bs], "int64")
|
||||
is_block_step = np.random.randint(0, 2, [max_bs]).astype("bool")
|
||||
|
||||
stop_flags = paddle.to_tensor(stop_flags)
|
||||
not_need_stop = paddle.to_tensor(not_need_stop, place=paddle.CPUPlace())
|
||||
seq_lens_this_time = paddle.to_tensor(seq_lens_this_time)
|
||||
seq_lens_encoder = paddle.to_tensor(seq_lens_encoder)
|
||||
seq_lens_decoder = paddle.to_tensor(seq_lens_decoder)
|
||||
input_ids = paddle.to_tensor(input_ids_np)
|
||||
stop_nums = paddle.to_tensor(stop_nums)
|
||||
next_tokens = paddle.to_tensor(next_tokens)
|
||||
is_block_step = paddle.to_tensor(is_block_step)
|
||||
|
||||
print("stop_flags:\n", stop_flags)
|
||||
print("not_need_stop:\n", not_need_stop)
|
||||
print("seq_lens_this_time:\n", seq_lens_this_time)
|
||||
print("seq_lens_encoder:\n", seq_lens_encoder)
|
||||
print("seq_lens_decoder:\n", seq_lens_decoder)
|
||||
print("input_ids:\n", input_ids)
|
||||
print("stop_nums:\n", stop_nums)
|
||||
print("next_tokens:\n", next_tokens)
|
||||
print("is_block_step:\n", is_block_step)
|
||||
|
||||
update_inputs(
|
||||
stop_flags,
|
||||
not_need_stop,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
input_ids,
|
||||
stop_nums,
|
||||
next_tokens,
|
||||
is_block_step,
|
||||
)
|
||||
|
||||
print("-" * 50)
|
||||
print("stop_flags:\n", stop_flags)
|
||||
print("not_need_stop:\n", not_need_stop)
|
||||
print("seq_lens_this_time:\n", seq_lens_this_time)
|
||||
print("seq_lens_encoder:\n", seq_lens_encoder)
|
||||
print("seq_lens_decoder:\n", seq_lens_decoder)
|
||||
print("input_ids:\n", input_ids)
|
||||
print("stop_nums:\n", stop_nums)
|
||||
print("next_tokens:\n", next_tokens)
|
||||
|
||||
ref_not_need_stop_out = np.array([True])
|
||||
ref_seq_lens_this_time_out = np.array(
|
||||
[
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
],
|
||||
"int32",
|
||||
)
|
||||
ref_seq_lens_encoder_out = np.array(
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
],
|
||||
"int32",
|
||||
)
|
||||
ref_seq_lens_decoder_out = np.array(
|
||||
[
|
||||
0,
|
||||
0,
|
||||
2,
|
||||
0,
|
||||
0,
|
||||
6,
|
||||
0,
|
||||
8,
|
||||
8,
|
||||
10,
|
||||
0,
|
||||
12,
|
||||
12,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
20,
|
||||
22,
|
||||
0,
|
||||
24,
|
||||
24,
|
||||
0,
|
||||
26,
|
||||
28,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
32,
|
||||
32,
|
||||
0,
|
||||
34,
|
||||
0,
|
||||
0,
|
||||
38,
|
||||
0,
|
||||
40,
|
||||
0,
|
||||
0,
|
||||
42,
|
||||
0,
|
||||
0,
|
||||
46,
|
||||
46,
|
||||
48,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
],
|
||||
"int32",
|
||||
)
|
||||
input_ids_np[:, 0] = np.array(
|
||||
[
|
||||
6,
|
||||
5,
|
||||
9,
|
||||
8,
|
||||
6,
|
||||
2,
|
||||
8,
|
||||
1,
|
||||
3,
|
||||
1,
|
||||
3,
|
||||
6,
|
||||
9,
|
||||
8,
|
||||
1,
|
||||
9,
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
6,
|
||||
7,
|
||||
6,
|
||||
5,
|
||||
3,
|
||||
5,
|
||||
9,
|
||||
3,
|
||||
6,
|
||||
3,
|
||||
9,
|
||||
8,
|
||||
8,
|
||||
8,
|
||||
8,
|
||||
4,
|
||||
8,
|
||||
7,
|
||||
4,
|
||||
2,
|
||||
3,
|
||||
5,
|
||||
8,
|
||||
4,
|
||||
2,
|
||||
5,
|
||||
6,
|
||||
8,
|
||||
9,
|
||||
6,
|
||||
7,
|
||||
4,
|
||||
2,
|
||||
4,
|
||||
6,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
9,
|
||||
7,
|
||||
2,
|
||||
1,
|
||||
8,
|
||||
7,
|
||||
8,
|
||||
],
|
||||
"int64",
|
||||
)
|
||||
|
||||
assert not_need_stop.numpy() == ref_not_need_stop_out, "Check not_need_stop failed."
|
||||
assert np.all(seq_lens_this_time.numpy() == ref_seq_lens_this_time_out), "Check seq_lens_this_time failed."
|
||||
assert np.all(seq_lens_encoder.numpy() == ref_seq_lens_encoder_out), "Check seq_lens_encoder failed."
|
||||
assert np.all(seq_lens_decoder.numpy() == ref_seq_lens_decoder_out), "Check seq_lens_decoder failed."
|
||||
assert np.all(input_ids.numpy() == input_ids_np), "Check input_ids failed."
|
Reference in New Issue
Block a user