# FastDeploy/tests/operators/test_update_inputs_v1.py

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import update_inputs_v1
def update_inputs_kernel_v1(
    not_need_stop,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    step_seq_lens_decoder,
    prompt_lens,
    topk_ids,
    input_ids,
    block_tables,
    stop_nums,
    stop_flags,
    is_block_step,
    next_tokens,
    bsz,
    max_bsz,
    input_ids_stride,
    block_num_per_seq,
    block_size,
):
    """Python reference implementation of the ``update_inputs_v1`` update step.

    Loops over batch slots one at a time and mutates the given tensors/arrays
    in place. Any container supporting ``x[i]`` / ``x[i, j]`` reads and item
    assignment works (paddle tensors in the test, NumPy arrays as well).

    For each of the first ``bsz`` slots:

    * already stopped (``stop_flags[i]``): zero its three sequence lengths;
    * ``seq_lens_this_time[i] + seq_lens_decoder[i] >= prompt_lens[i]``:
      decoding. Advance ``seq_lens_decoder[i]``, set ``seq_lens_this_time[i]``
      to 1, clear ``seq_lens_encoder[i]`` and write ``next_tokens[i]`` into
      ``input_ids[i, 0]``. If the block-table entry for the new decoder
      length is ``-1``, mark the slot as a block step, move the decoder
      length into ``step_seq_lens_decoder[i]`` and stop the slot;
    * otherwise: stop the slot, zero its lengths and set ``topk_ids[i] = -1``.

    Finally ``not_need_stop[0]`` is set to whether the number of stopped
    slots (slots ``>= bsz`` always count as stopped) is below
    ``stop_nums[0]``.

    ``input_ids_stride`` and ``block_num_per_seq`` are not used here (2-D
    indexing replaces any stride arithmetic); presumably they are kept so the
    signature matches the op's kernel arguments.
    """
    # Slots outside the live batch always count as stopped.
    stop_flag_now_int = np.ones([max_bsz])
    for thread_idx in range(min(bsz, max_bsz)):
        stop_flag_now_int[thread_idx] = int(stop_flags[thread_idx])

    for thread_idx in range(bsz):
        if stop_flags[thread_idx]:
            # stop at next step
            seq_lens_this_time[thread_idx] = 0
            seq_lens_decoder[thread_idx] = 0
            seq_lens_encoder[thread_idx] = 0
            continue

        if seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]:
            # decoding
            seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx]
            seq_lens_this_time[thread_idx] = 1
            seq_lens_encoder[thread_idx] = 0
            # Write through the parent container: ``input_ids[thread_idx]`` is
            # not guaranteed to be a view (e.g. in some Paddle versions), so
            # assigning into that row object could leave ``input_ids`` unchanged.
            input_ids[thread_idx, 0] = next_tokens[thread_idx]
            # to judge whether block is not enough; int(a / b) truncates like
            # C integer division.
            block_idx = int(seq_lens_decoder[thread_idx] / block_size)
            if seq_lens_this_time[thread_idx] != 0 and block_tables[thread_idx, block_idx] == -1:
                # should be scheduled by server
                is_block_step[thread_idx] = True
                seq_lens_this_time[thread_idx] = 0
                stop_flags[thread_idx] = True
                step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx]
                seq_lens_decoder[thread_idx] = 0
                stop_flag_now_int[thread_idx] = 1
        else:
            # Prompt not fully consumed: stop this slot.
            stop_flags[thread_idx] = True
            seq_lens_this_time[thread_idx] = 0
            seq_lens_decoder[thread_idx] = 0
            seq_lens_encoder[thread_idx] = 0
            topk_ids[thread_idx] = -1
            stop_flag_now_int[thread_idx] = 1

    stop_sum = np.sum(stop_flag_now_int)
    not_need_stop[0] = stop_sum < stop_nums[0]
def update_inputs_v1_ref(
    stop_flags,
    not_need_stop,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    step_seq_lens_decoder,
    prompt_lens,
    topk_ids,
    input_ids,
    block_tables,
    stop_nums,
    next_tokens,
    is_block_step,
    block_size,
):
    """Run the Python reference update with the same argument order as the op.

    The batch and stride sizes are read from the tensor shapes:
    ``max_bsz = stop_flags.shape[0]``, ``bsz = seq_lens_this_time.shape[0]``,
    ``input_ids_stride = input_ids.shape[1]`` and
    ``block_num_per_seq = block_tables.shape[1]``. Everything is then forwarded
    to ``update_inputs_kernel_v1``, which mutates the tensors in place.
    Returns ``None``.
    """
    update_inputs_kernel_v1(
        not_need_stop=not_need_stop,
        seq_lens_this_time=seq_lens_this_time,
        seq_lens_encoder=seq_lens_encoder,
        seq_lens_decoder=seq_lens_decoder,
        step_seq_lens_decoder=step_seq_lens_decoder,
        prompt_lens=prompt_lens,
        topk_ids=topk_ids,
        input_ids=input_ids,
        block_tables=block_tables,
        stop_nums=stop_nums,
        stop_flags=stop_flags,
        is_block_step=is_block_step,
        next_tokens=next_tokens,
        bsz=seq_lens_this_time.shape[0],
        max_bsz=stop_flags.shape[0],
        input_ids_stride=input_ids.shape[1],
        block_num_per_seq=block_tables.shape[1],
        block_size=block_size,
    )
class TestUpdateInputsV1(unittest.TestCase):
    """Compare the ``update_inputs_v1`` op against ``update_inputs_v1_ref``."""

    def test_update_inputs_v1(self):
        np.random.seed(2023)
        bs = 48
        max_bs = 64
        max_input_length = 100
        stop_flags = np.random.randint(0, 2, max_bs).astype("bool")
        not_need_stop = np.array([1], "bool")
        seq_lens_this_time = np.zeros([bs], "int32")
        seq_lens_encoder = np.zeros([max_bs], "int32")
        seq_lens_decoder = np.zeros([max_bs], "int32")
        for i in range(bs):
            if i % 2 == 0:
                seq_lens_encoder[i] = i
                seq_lens_this_time[i] = i
            else:
                seq_lens_decoder[i] = i
                seq_lens_this_time[i] = 1
        step_seq_lens_decoder = np.zeros([bs], "int32")
        prompt_lens = np.random.randint(0, 10, [max_bs], dtype="int64")
        topk_ids = np.zeros([bs], "int64")
        input_ids = np.random.randint(1, 10, [max_bs, max_input_length], "int64")
        block_tables = np.zeros([max_bs, 1], "int32")
        # Mark some slots' only block as unallocated (-1) so the block-step
        # branch (is_block_step / step_seq_lens_decoder) is exercised too.
        block_tables[::4] = -1
        stop_nums = np.array([max_bs], "int64")
        next_tokens = np.random.randint(1, 10, [max_bs], "int64")
        is_block_step = np.random.randint(0, 2, [max_bs]).astype("bool")

        stop_flags = paddle.to_tensor(stop_flags)
        not_need_stop = paddle.to_tensor(not_need_stop, place=paddle.CPUPlace())
        seq_lens_this_time = paddle.to_tensor(seq_lens_this_time)
        seq_lens_encoder = paddle.to_tensor(seq_lens_encoder)
        seq_lens_decoder = paddle.to_tensor(seq_lens_decoder)
        step_seq_lens_decoder = paddle.to_tensor(step_seq_lens_decoder)
        prompt_lens = paddle.to_tensor(prompt_lens)
        topk_ids = paddle.to_tensor(topk_ids)
        input_ids = paddle.to_tensor(input_ids)
        block_tables = paddle.to_tensor(block_tables)
        stop_nums = paddle.to_tensor(stop_nums)
        next_tokens = paddle.to_tensor(next_tokens)
        is_block_step = paddle.to_tensor(is_block_step)
        block_size = 1024

        inputs = (
            stop_flags,
            not_need_stop,
            seq_lens_this_time,
            seq_lens_encoder,
            seq_lens_decoder,
            step_seq_lens_decoder,
            prompt_lens,
            topk_ids,
            input_ids,
            block_tables,
            stop_nums,
            next_tokens,
            is_block_step,
            block_size,
        )
        # inplace modify, need to clone inputs
        inputs_clone = [x.clone() if isinstance(x, paddle.Tensor) else x for x in inputs]
        update_inputs_v1(*inputs)
        update_inputs_v1_ref(*inputs_clone)

        # Every tensor the reference mutates, by position in ``inputs``. All are
        # integer/bool tensors, so compare exactly rather than with a tolerance.
        outputs_to_compare = {
            "stop_flags": 0,
            "not_need_stop": 1,
            "seq_lens_this_time": 2,
            "seq_lens_encoder": 3,
            "seq_lens_decoder": 4,
            "step_seq_lens_decoder": 5,
            "topk_ids": 7,
            "input_ids": 8,
            "is_block_step": 12,
        }
        for name, idx in outputs_to_compare.items():
            np.testing.assert_array_equal(
                inputs[idx].numpy(),
                inputs_clone[idx].numpy(),
                err_msg=f"mismatch in {name}",
            )
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()