# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
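
"""Unit test for the ``draft_model_preprocess`` GPU custom op.

Compares the op against a pure-Python reference implementation of the same
MTP (speculative decoding) preprocessing logic, run on cloned copies of the
same randomly generated inputs.
"""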

import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import draft_model_preprocess


def process_splitwise_prefill(
    draft_tokens,
    input_ids,
    stop_flags,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    step_idx,
    not_need_stop,
    is_block_step,
    batch_drop,
    pre_ids,
    accept_tokens,
    accept_num,
    base_model_seq_lens_this_time,
    base_model_seq_lens_encoder,
    base_model_seq_lens_decoder,
    base_model_step_idx,
    base_model_stop_flags,
    base_model_is_block_step,
    base_model_draft_tokens,
    bsz,
    num_model_step,
    base_model_draft_tokens_len,
    truncate_first_token,
    kvcache_scheduler_v1,
):
    """Reference path for splitwise prefill: seed each active request with the
    base model's first accepted token."""
    not_stop_flag_sum = 0

    for tid in range(bsz):
        not_stop_flag = 0
        input_ids_now = input_ids[tid]
        accept_tokens_now = accept_tokens[tid]
        if seq_lens_encoder[tid] > 0:
            # Active prefill request: keep it running.
            not_stop_flag = 1
            seq_len_encoder = seq_lens_encoder[tid]
            stop_flags[tid] = False
            base_model_first_token = accept_tokens_now[0]
            position = seq_len_encoder
            if truncate_first_token:
                # Overwrite the last prompt token with the first accepted token.
                input_ids_now[position - 1] = base_model_first_token
                seq_lens_this_time[tid] = seq_len_encoder
            else:
                # Append the first accepted token after the prompt.
                input_ids_now[position] = base_model_first_token
                seq_lens_this_time[tid] = seq_len_encoder + 1
        else:
            # Idle slot: mark it stopped and reset its lengths.
            stop_flags[tid] = True
            seq_lens_this_time[tid] = 0
            seq_lens_decoder[tid] = 0
            seq_lens_encoder[tid] = 0
        not_stop_flag_sum += not_stop_flag
    not_need_stop[0] = not_stop_flag_sum > 0


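# The decode-path kernel below handles, per request slot (mirroring the
# numbered comments in its body):
#   1. the first token after prefill: seed input_ids/pre_ids with the base
#      model's first accepted token;
#   2. a regular decode step: rewind seq_lens_decoder/step_idx by
#      num_model_step - 1 and copy the accepted tokens into draft_tokens;
#   3. (kvcache scheduler V1 only) recovery of a request paused by a block
#      step.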
def draft_model_preprocess_kernel(
    draft_tokens,
    input_ids,
    stop_flags,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    step_idx,
    not_need_stop,
    is_block_step,
    batch_drop,
    pre_ids,
    accept_tokens,
    accept_num,
    base_model_seq_lens_this_time,
    base_model_seq_lens_encoder,
    base_model_seq_lens_decoder,
    base_model_step_idx,
    base_model_stop_flags,
    base_model_is_block_step,
    base_model_draft_tokens,
    bsz,
    num_model_step,
    base_model_draft_tokens_len,
    truncate_first_token,
    kvcache_scheduler_v1,
):
    """Reference path for the regular decode step."""
    not_stop_flag_sum = 0

    for tid in range(bsz):
        not_stop_flag = 0
        accept_tokens_now = accept_tokens[tid]
        draft_tokens_now = draft_tokens[tid]
        accept_num_now = accept_num[tid]
        input_ids_now = input_ids[tid]
        base_model_draft_tokens_now = base_model_draft_tokens[tid]
        base_model_seq_len_decoder = base_model_seq_lens_decoder[tid]
        base_model_seq_len_this_time = base_model_seq_lens_this_time[tid]
        pre_ids_now = pre_ids[tid]

        # Clear all but the first slot of the base model's draft tokens.
        base_model_draft_tokens_now[1:base_model_draft_tokens_len] = -1

        if kvcache_scheduler_v1:
            # In V1 mode a block-stepped request is paused, not dropped,
            # so it can continue inference later.
            if base_model_stop_flags[tid] and base_model_is_block_step[tid]:
                stop_flags[tid] = True
                is_block_step[tid] = True
        else:
            if base_model_stop_flags[tid] and base_model_is_block_step[tid]:
                batch_drop[tid] = True
                stop_flags[tid] = True

        if not (base_model_stop_flags[tid] or batch_drop[tid]):
            not_stop_flag = 1
            # 1. first token
            if seq_lens_encoder[tid] > 0:
                # Can be extended to the first few tokens
                seq_len_encoder = seq_lens_encoder[tid]
                stop_flags[tid] = False
                base_model_first_token = accept_tokens_now[0]
                pre_ids_now[0] = base_model_first_token
                position = seq_len_encoder
                if truncate_first_token:
                    input_ids_now[position - 1] = base_model_first_token
                    seq_lens_this_time[tid] = seq_len_encoder
                else:
                    input_ids_now[position] = base_model_first_token
                    seq_lens_this_time[tid] = seq_len_encoder + 1
            else:
                if kvcache_scheduler_v1:
                    # 3. try to recover MTP inference in V1 mode
                    if not (base_model_is_block_step[tid] and is_block_step[tid]):
                        is_block_step[tid] = False

                if stop_flags[tid]:
                    stop_flags[tid] = False
                    # TODO: check
                    seq_lens_decoder[tid] = base_model_seq_len_decoder - base_model_seq_len_this_time
                    step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time
                else:
                    # 2. last base-model-generated token and first MTP token
                    seq_lens_decoder[tid] -= num_model_step - 1
                    step_idx[tid] -= num_model_step - 1

                # Copy accepted tokens into the draft buffer and token history.
                for i in range(accept_num_now):
                    draft_tokens_now[i] = accept_tokens_now[i]
                    pre_id_pos = base_model_step_idx[tid] - (accept_num_now - i)
                    pre_ids_now[pre_id_pos] = accept_tokens_now[i]

                seq_lens_this_time[tid] = accept_num_now
        else:
            stop_flags[tid] = True
            seq_lens_this_time[tid] = 0
            seq_lens_decoder[tid] = 0
            seq_lens_encoder[tid] = 0
        not_stop_flag_sum += not_stop_flag
    not_need_stop[0] = not_stop_flag_sum > 0


def DispatchRunner(
    draft_tokens,
    input_ids,
    stop_flags,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    step_idx,
    not_need_stop,
    is_block_step,
    batch_drop,
    pre_ids,
    accept_tokens,
    accept_num,
    base_model_seq_lens_this_time,
    base_model_seq_lens_encoder,
    base_model_seq_lens_decoder,
    base_model_step_idx,
    base_model_stop_flags,
    base_model_is_block_step,
    base_model_draft_tokens,
    bsz,
    num_model_step,
    truncate_first_token,
    splitwise_prefill,
    kvcache_scheduler_v1,
):
    """Dispatch to the prefill or decode reference path; both take the same
    argument list."""
    base_model_draft_tokens_len = base_model_draft_tokens.shape[1]
    runner = process_splitwise_prefill if splitwise_prefill else draft_model_preprocess_kernel
    runner(
        draft_tokens,
        input_ids,
        stop_flags,
        seq_lens_this_time,
        seq_lens_encoder,
        seq_lens_decoder,
        step_idx,
        not_need_stop,
        is_block_step,
        batch_drop,
        pre_ids,
        accept_tokens,
        accept_num,
        base_model_seq_lens_this_time,
        base_model_seq_lens_encoder,
        base_model_seq_lens_decoder,
        base_model_step_idx,
        base_model_stop_flags,
        base_model_is_block_step,
        base_model_draft_tokens,
        bsz,
        num_model_step,
        base_model_draft_tokens_len,
        truncate_first_token,
        kvcache_scheduler_v1,
    )


def draft_model_preprocess_ref(
    draft_tokens,
    input_ids,
    stop_flags,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    step_idx,
    not_need_stop,
    is_block_step,
    batch_drop,
    pre_ids,
    accept_tokens,
    accept_num,
    base_model_seq_lens_this_time,
    base_model_seq_lens_encoder,
    base_model_seq_lens_decoder,
    base_model_step_idx,
    base_model_stop_flags,
    base_model_is_block_step,
    base_model_draft_tokens,
    num_model_step,
    truncate_first_token,
    splitwise_prefill,
    kvcache_scheduler_v1,
):
    """Python reference for the ``draft_model_preprocess`` GPU op; mutates its
    tensor arguments in place."""
    real_bsz = seq_lens_this_time.shape[0]

    DispatchRunner(
        draft_tokens,
        input_ids,
        stop_flags,
        seq_lens_this_time,
        seq_lens_encoder,
        seq_lens_decoder,
        step_idx,
        not_need_stop,
        is_block_step,
        batch_drop,
        pre_ids,
        accept_tokens,
        accept_num,
        base_model_seq_lens_this_time,
        base_model_seq_lens_encoder,
        base_model_seq_lens_decoder,
        base_model_step_idx,
        base_model_stop_flags,
        base_model_is_block_step,
        base_model_draft_tokens,
        real_bsz,
        num_model_step,
        truncate_first_token,
        splitwise_prefill,
        kvcache_scheduler_v1,
    )


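# A minimal, self-contained sketch (not invoked by the test) showing how the
# pure-Python reference path above can be exercised for a single prefill
# request without a GPU. The helper name, shapes, and values here are
# illustrative assumptions, not part of the FastDeploy API.
def _example_reference_prefill():
    bsz = 1
    input_ids = paddle.zeros([bsz, 8], dtype="int64")  # room for an 8-token prompt
    accept_tokens = paddle.full([bsz, 4], 7, dtype="int64")  # base model accepted token 7
    seq_lens_encoder = paddle.to_tensor([5], dtype="int32")  # 5 prompt tokens to prefill
    seq_lens_this_time = paddle.zeros([bsz], dtype="int32")
    seq_lens_decoder = paddle.zeros([bsz], dtype="int32")
    stop_flags = paddle.zeros([bsz], dtype="bool")
    not_need_stop = paddle.zeros([1], dtype="bool").cpu()

    draft_model_preprocess_ref(
        paddle.zeros([bsz, 4], dtype="int64"),  # draft_tokens
        input_ids,
        stop_flags,
        seq_lens_this_time,
        seq_lens_encoder,
        seq_lens_decoder,
        paddle.zeros([bsz], dtype="int64"),  # step_idx
        not_need_stop,
        paddle.zeros([bsz], dtype="bool"),  # is_block_step
        paddle.zeros([bsz], dtype="bool"),  # batch_drop
        input_ids.clone(),  # pre_ids
        accept_tokens,
        paddle.ones([bsz], dtype="int32"),  # accept_num
        seq_lens_this_time.clone(),  # base_model_seq_lens_this_time
        seq_lens_encoder.clone(),  # base_model_seq_lens_encoder
        seq_lens_decoder.clone(),  # base_model_seq_lens_decoder
        paddle.zeros([bsz], dtype="int64"),  # base_model_step_idx
        paddle.zeros([bsz], dtype="bool"),  # base_model_stop_flags
        paddle.zeros([bsz], dtype="bool"),  # base_model_is_block_step
        paddle.zeros([bsz, 4], dtype="int64"),  # base_model_draft_tokens
        num_model_step=1,
        truncate_first_token=True,
        splitwise_prefill=True,
        kvcache_scheduler_v1=False,
    )
    # With truncate_first_token=True, input_ids[0][4] now holds token 7 and
    # seq_lens_this_time[0] == 5.

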
class TestDraftModelPreprocess(unittest.TestCase):
    def _run_tests(self):
        paddle.seed(2022)

        # Define parameters
        bsz = 10
        draft_tokens_len = 4
        input_ids_len = 100
        max_draft_token = 10

        truncate_first_token = True
        splitwise_prefill = False

        draft_tokens = paddle.randint(0, 100, [bsz, draft_tokens_len], dtype="int64")
        input_ids = paddle.randint(0, 100, [bsz, input_ids_len], dtype="int64")
        # randint(0, 1) only produces 0, so stop_flags start out all False.
        stop_flags = paddle.randint(0, 1, [bsz], dtype="int32").cast("bool")
        seq_lens_this_time = paddle.randint(0, 100, [bsz], dtype="int32")
        seq_lens_encoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32")
        seq_lens_decoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32")
        step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
        seq_lens_encoder_record = paddle.randint(0, 100, [bsz], dtype="int32")  # noqa: F841
        seq_lens_decoder_record = paddle.randint(0, 100, [bsz], dtype="int32")  # noqa: F841
        not_need_stop = paddle.zeros([1], dtype="bool").cpu()
        is_block_step = paddle.zeros([bsz], dtype="bool")
        batch_drop = paddle.zeros([bsz], dtype="bool")

        # Base-model-side tensors fed to the op
        accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64")
        accept_num = paddle.randint(1, max_draft_token + 5, [bsz], dtype="int32")
        base_model_seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32")
        base_model_seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32")
        base_model_step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
        base_model_stop_flags = paddle.zeros([bsz], dtype="bool")
        base_model_is_block_step = paddle.zeros([bsz], dtype="bool")
        base_model_draft_tokens = paddle.zeros([bsz, max_draft_token], dtype="int64")

        pre_ids = input_ids.clone()
        base_model_seq_lens_this_time = seq_lens_this_time
        num_model_step = max_draft_token

        kvcache_scheduler_v1 = True
        inputs = (
            draft_tokens,
            input_ids,
            stop_flags,
            seq_lens_this_time,
            seq_lens_encoder,
            seq_lens_decoder,
            step_idx,
            not_need_stop,
            is_block_step,
            batch_drop,
            pre_ids,
            accept_tokens,
            accept_num,
            base_model_seq_lens_this_time,
            base_model_seq_lens_encoder,
            base_model_seq_lens_decoder,
            base_model_step_idx,
            base_model_stop_flags,
            base_model_is_block_step,
            base_model_draft_tokens,
            num_model_step,
            truncate_first_token,
            splitwise_prefill,
            kvcache_scheduler_v1,
        )
        # Both implementations modify their arguments in place, so run each on
        # its own copy of the inputs.
        inputs_clone = [x.clone() if isinstance(x, paddle.Tensor) else x for x in inputs]
        # Run the reference and the GPU op
        draft_model_preprocess_ref(*inputs)
        draft_model_preprocess(*inputs_clone)
        return inputs, inputs_clone

    def test_draft_model_preprocess(self):
        results1, results2 = self._run_tests()
        np.testing.assert_allclose(results1[0], results2[0])  # draft_tokens
        np.testing.assert_allclose(results1[1], results2[1])  # input_ids
        np.testing.assert_allclose(results1[2], results2[2])  # stop_flags
        np.testing.assert_allclose(results1[3], results2[3])  # seq_lens_this_time
        np.testing.assert_allclose(results1[11], results2[11])  # accept_tokens
        np.testing.assert_allclose(results1[12], results2[12])  # accept_num
        np.testing.assert_allclose(results1[7], results2[7])  # not_need_stop


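# Note: running this file end-to-end requires a FastDeploy build that ships
# the GPU custom ops (see the fastdeploy.model_executor.ops.gpu import above).
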
if __name__ == "__main__":
    unittest.main()