mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-29 13:52:26 +08:00
204 lines
6.7 KiB
Python
204 lines
6.7 KiB
Python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License")
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import paddle
|
|
|
|
from fastdeploy.model_executor.ops.gpu import update_inputs_v1
|
|
|
|
|
|
def update_inputs_kernel_v1(
|
|
not_need_stop,
|
|
seq_lens_this_time,
|
|
seq_lens_encoder,
|
|
seq_lens_decoder,
|
|
step_seq_lens_decoder,
|
|
prompt_lens,
|
|
topk_ids,
|
|
input_ids,
|
|
block_tables,
|
|
stop_nums,
|
|
stop_flags,
|
|
is_block_step,
|
|
next_tokens,
|
|
bsz,
|
|
max_bsz,
|
|
input_ids_stride,
|
|
block_num_per_seq,
|
|
block_size,
|
|
):
|
|
stop_flag_now = False
|
|
stop_flag_now_int = np.zeros([max_bsz])
|
|
for thread_idx in range(max_bsz):
|
|
if thread_idx < bsz:
|
|
stop_flag_now = stop_flags[thread_idx]
|
|
stop_flag_now_int[thread_idx] = int(stop_flag_now)
|
|
else:
|
|
stop_flag_now_int[thread_idx] = 1
|
|
|
|
for thread_idx in range(bsz):
|
|
stop_flag_now = stop_flags[thread_idx]
|
|
if stop_flag_now:
|
|
seq_lens_this_time[thread_idx] = 0 # stop at next step
|
|
seq_lens_decoder[thread_idx] = 0
|
|
seq_lens_encoder[thread_idx] = 0
|
|
else:
|
|
if seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]:
|
|
# decoding
|
|
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx]
|
|
seq_lens_this_time[thread_idx] = 1
|
|
seq_lens_encoder[thread_idx] = 0
|
|
input_ids_now = input_ids[thread_idx]
|
|
input_ids_now[0] = next_tokens[thread_idx]
|
|
|
|
# to judge whether block is not enough
|
|
block_table_now = block_tables[thread_idx]
|
|
if (
|
|
seq_lens_this_time[thread_idx] != 0
|
|
and block_table_now[int(seq_lens_decoder[thread_idx] / block_size)] == -1
|
|
):
|
|
# should be scheduled by server
|
|
is_block_step[thread_idx] = True
|
|
seq_lens_this_time[thread_idx] = 0
|
|
stop_flags[thread_idx] = True
|
|
step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx]
|
|
seq_lens_decoder[thread_idx] = 0
|
|
stop_flag_now_int[thread_idx] = 1
|
|
else:
|
|
stop_flags[thread_idx] = True
|
|
seq_lens_this_time[thread_idx] = 0
|
|
seq_lens_decoder[thread_idx] = 0
|
|
seq_lens_encoder[thread_idx] = 0
|
|
topk_ids[thread_idx] = -1
|
|
stop_flag_now_int[thread_idx] = 1
|
|
stop_sum = np.sum(stop_flag_now_int)
|
|
not_need_stop[0] = stop_sum < stop_nums[0]
|
|
|
|
|
|
def update_inputs_v1_ref(
|
|
stop_flags,
|
|
not_need_stop,
|
|
seq_lens_this_time,
|
|
seq_lens_encoder,
|
|
seq_lens_decoder,
|
|
step_seq_lens_decoder,
|
|
prompt_lens,
|
|
topk_ids,
|
|
input_ids,
|
|
block_tables,
|
|
stop_nums,
|
|
next_tokens,
|
|
is_block_step,
|
|
block_size,
|
|
):
|
|
max_bsz = stop_flags.shape[0]
|
|
now_bsz = seq_lens_this_time.shape[0]
|
|
input_ids_stride = input_ids.shape[1]
|
|
block_num_per_seq = block_tables.shape[1]
|
|
update_inputs_kernel_v1(
|
|
not_need_stop,
|
|
seq_lens_this_time,
|
|
seq_lens_encoder,
|
|
seq_lens_decoder,
|
|
step_seq_lens_decoder,
|
|
prompt_lens,
|
|
topk_ids,
|
|
input_ids,
|
|
block_tables,
|
|
stop_nums,
|
|
stop_flags,
|
|
is_block_step,
|
|
next_tokens,
|
|
now_bsz,
|
|
max_bsz,
|
|
input_ids_stride,
|
|
block_num_per_seq,
|
|
block_size,
|
|
)
|
|
|
|
|
|
class TestUpdateInputsV1(unittest.TestCase):
|
|
def test_update_inputs_v1(self):
|
|
np.random.seed(2023)
|
|
|
|
bs = 48
|
|
max_bs = 64
|
|
max_input_length = 100
|
|
|
|
stop_flags = np.random.randint(0, 2, max_bs).astype("bool")
|
|
not_need_stop = np.array([1], "bool")
|
|
seq_lens_this_time = np.zeros([bs], "int32")
|
|
seq_lens_encoder = np.zeros([max_bs], "int32")
|
|
seq_lens_decoder = np.zeros([max_bs], "int32")
|
|
for i in range(bs):
|
|
if i % 2 == 0:
|
|
seq_lens_encoder[i] = i
|
|
seq_lens_this_time[i] = i
|
|
else:
|
|
seq_lens_decoder[i] = i
|
|
seq_lens_this_time[i] = 1
|
|
step_seq_lens_decoder = np.zeros([bs], "int32")
|
|
prompt_lens = np.random.randint(0, 10, [max_bs], dtype="int64")
|
|
topk_ids = np.zeros([bs], "int64")
|
|
input_ids = np.random.randint(1, 10, [max_bs, max_input_length], "int64")
|
|
block_tables = np.zeros([max_bs, 1], "int32")
|
|
stop_nums = np.array([max_bs], "int64")
|
|
next_tokens = np.random.randint(1, 10, [max_bs], "int64")
|
|
is_block_step = np.random.randint(0, 2, [max_bs]).astype("bool")
|
|
|
|
stop_flags = paddle.to_tensor(stop_flags)
|
|
not_need_stop = paddle.to_tensor(not_need_stop, place=paddle.CPUPlace())
|
|
seq_lens_this_time = paddle.to_tensor(seq_lens_this_time)
|
|
seq_lens_encoder = paddle.to_tensor(seq_lens_encoder)
|
|
seq_lens_decoder = paddle.to_tensor(seq_lens_decoder)
|
|
step_seq_lens_decoder = paddle.to_tensor(step_seq_lens_decoder)
|
|
prompt_lens = paddle.to_tensor(prompt_lens)
|
|
topk_ids = paddle.to_tensor(topk_ids)
|
|
input_ids = paddle.to_tensor(input_ids)
|
|
block_tables = paddle.to_tensor(block_tables)
|
|
stop_nums = paddle.to_tensor(stop_nums)
|
|
next_tokens = paddle.to_tensor(next_tokens)
|
|
is_block_step = paddle.to_tensor(is_block_step)
|
|
block_size = 1024
|
|
|
|
inputs = (
|
|
stop_flags,
|
|
not_need_stop,
|
|
seq_lens_this_time,
|
|
seq_lens_encoder,
|
|
seq_lens_decoder,
|
|
step_seq_lens_decoder,
|
|
prompt_lens,
|
|
topk_ids,
|
|
input_ids,
|
|
block_tables,
|
|
stop_nums,
|
|
next_tokens,
|
|
is_block_step,
|
|
block_size,
|
|
)
|
|
# inplace modify, need to clone inputs
|
|
inputs_clone = [x.clone() if isinstance(x, paddle.Tensor) else x for x in inputs]
|
|
update_inputs_v1(*inputs)
|
|
update_inputs_v1_ref(*inputs_clone)
|
|
compare_indexs = [1, 2, 3, 4, 5, 8]
|
|
for idx in compare_indexs:
|
|
np.testing.assert_allclose(inputs[idx], inputs_clone[idx])
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|