# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import draft_model_preprocess


def process_splitwise_prefill(
    draft_tokens,
    input_ids,
    stop_flags,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    step_idx,
    not_need_stop,
    is_block_step,
    batch_drop,
    pre_ids,
    accept_tokens,
    accept_num,
    base_model_seq_lens_this_time,
    base_model_seq_lens_encoder,
    base_model_seq_lens_decoder,
    base_model_step_idx,
    base_model_stop_flags,
    base_model_is_block_step,
    base_model_draft_tokens,
    bsz,
    num_model_step,
    base_model_draft_tokens_len,
    truncate_first_token,
    kvcache_scheduler_v1,
):
    """Reference for the splitwise-prefill path: only requests that still hold
    an encoder (prefill) length stay alive; every other request is stopped."""
    not_stop_flag_sum = 0
    for tid in range(bsz):
        not_stop_flag = 0
        input_ids_now = input_ids[tid]
        accept_tokens_now = accept_tokens[tid]
        if seq_lens_encoder[tid] > 0:
            not_stop_flag = 1
            seq_len_encoder = seq_lens_encoder[tid]
            stop_flags[tid] = False
            # Seed the draft model's prompt with the base model's first accepted token.
            base_model_first_token = accept_tokens_now[0]
            position = seq_len_encoder
            if truncate_first_token:
                input_ids_now[position - 1] = base_model_first_token
                seq_lens_this_time[tid] = seq_len_encoder
            else:
                input_ids_now[position] = base_model_first_token
                seq_lens_this_time[tid] = seq_len_encoder + 1
        else:
            stop_flags[tid] = True
            seq_lens_this_time[tid] = 0
            seq_lens_decoder[tid] = 0
            seq_lens_encoder[tid] = 0
            not_stop_flag = 0
        not_stop_flag_sum = not_stop_flag_sum + not_stop_flag
    not_need_stop[0] = not_stop_flag_sum > 0
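

# Reference for the decode path of draft_model_preprocess. Per request it
# mirrors the kernel's numbered branches: (1) a request still in prefill seeds
# its prompt with the base model's first accepted token; (2) a decode request
# copies the accepted tokens into draft_tokens and rewinds step_idx /
# seq_lens_decoder by (num_model_step - 1); (3) under the v1 KV-cache
# scheduler, a request parked by block-step is recovered from the base model's
# counters. Stopped or dropped requests have their lengths zeroed.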
def draft_model_preprocess_kernel(
    draft_tokens,
    input_ids,
    stop_flags,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    step_idx,
    not_need_stop,
    is_block_step,
    batch_drop,
    pre_ids,
    accept_tokens,
    accept_num,
    base_model_seq_lens_this_time,
    base_model_seq_lens_encoder,
    base_model_seq_lens_decoder,
    base_model_step_idx,
    base_model_stop_flags,
    base_model_is_block_step,
    base_model_draft_tokens,
    bsz,
    num_model_step,
    base_model_draft_tokens_len,
    truncate_first_token,
    kvcache_scheduler_v1,
):
    not_stop_flag_sum = 0
    for tid in range(bsz):
        not_stop_flag = 0
        accept_tokens_now = accept_tokens[tid]
        draft_tokens_now = draft_tokens[tid]
        accept_num_now = accept_num[tid]
        input_ids_now = input_ids[tid]
        base_model_draft_tokens_now = base_model_draft_tokens[tid]
        base_model_seq_len_decoder = base_model_seq_lens_decoder[tid]
        base_model_seq_len_this_time = base_model_seq_lens_this_time[tid]
        pre_ids_now = pre_ids[tid]
        # Invalidate all but the first slot of the base model's draft tokens.
        base_model_draft_tokens_now[1:base_model_draft_tokens_len] = -1

        if kvcache_scheduler_v1:
            if base_model_stop_flags[tid] and base_model_is_block_step[tid]:
                stop_flags[tid] = True
                is_block_step[tid] = True  # Need to continue infer
        else:
            if base_model_stop_flags[tid] and base_model_is_block_step[tid]:
                batch_drop[tid] = True
                stop_flags[tid] = True

        if not (base_model_stop_flags[tid] or batch_drop[tid]):
            not_stop_flag = 1
            # 1. first token
            if seq_lens_encoder[tid] > 0:  # Can be extended to first few tokens
                seq_len_encoder = seq_lens_encoder[tid]
                stop_flags[tid] = False
                base_model_first_token = accept_tokens_now[0]
                pre_ids_now[0] = base_model_first_token
                position = seq_len_encoder
                if truncate_first_token:
                    input_ids_now[position - 1] = base_model_first_token
                    seq_lens_this_time[tid] = seq_len_encoder
                else:
                    input_ids_now[position] = base_model_first_token
                    seq_lens_this_time[tid] = seq_len_encoder + 1
            else:
                if kvcache_scheduler_v1:
                    # 3. try to recover mtp infer in V1 mode
                    if not (base_model_is_block_step[tid] and is_block_step[tid]):
                        is_block_step[tid] = False
                        if stop_flags[tid]:
                            stop_flags[tid] = False
                            # TODO: check
                            seq_lens_decoder[tid] = base_model_seq_len_decoder - base_model_seq_len_this_time
                            step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time
                        else:
                            # 2: Last base model generated token and first MTP token
                            seq_lens_decoder[tid] -= num_model_step - 1
                            step_idx[tid] -= num_model_step - 1
                for i in range(accept_num_now):
                    draft_tokens_now[i] = accept_tokens_now[i]
                    pre_id_pos = base_model_step_idx[tid] - (accept_num_now - i)
                    accept_token = accept_tokens_now[i]
                    pre_ids_now[pre_id_pos] = accept_token
                seq_lens_this_time[tid] = accept_num_now
        else:
            stop_flags[tid] = True
            seq_lens_this_time[tid] = 0
            seq_lens_decoder[tid] = 0
            seq_lens_encoder[tid] = 0
        not_stop_flag_sum = not_stop_flag_sum + not_stop_flag
    not_need_stop[0] = not_stop_flag_sum > 0


def DispatchRunner(
    draft_tokens,
    input_ids,
    stop_flags,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    step_idx,
    not_need_stop,
    is_block_step,
    batch_drop,
    pre_ids,
    accept_tokens,
    accept_num,
    base_model_seq_lens_this_time,
    base_model_seq_lens_encoder,
    base_model_seq_lens_decoder,
    base_model_step_idx,
    base_model_stop_flags,
    base_model_is_block_step,
    base_model_draft_tokens,
    bsz,
    num_model_step,
    truncate_first_token,
    splitwise_prefill,
    kvcache_scheduler_v1,
):
    # Mirror the CUDA-side dispatch: route to the prefill or the decode reference.
    base_model_draft_tokens_len = base_model_draft_tokens.shape[1]
    if splitwise_prefill:
        process_splitwise_prefill(
            draft_tokens,
            input_ids,
            stop_flags,
            seq_lens_this_time,
            seq_lens_encoder,
            seq_lens_decoder,
            step_idx,
            not_need_stop,
            is_block_step,
            batch_drop,
            pre_ids,
            accept_tokens,
            accept_num,
            base_model_seq_lens_this_time,
            base_model_seq_lens_encoder,
            base_model_seq_lens_decoder,
            base_model_step_idx,
            base_model_stop_flags,
            base_model_is_block_step,
            base_model_draft_tokens,
            bsz,
            num_model_step,
            base_model_draft_tokens_len,
            truncate_first_token,
            kvcache_scheduler_v1,
        )
    else:
        draft_model_preprocess_kernel(
            draft_tokens,
            input_ids,
            stop_flags,
            seq_lens_this_time,
            seq_lens_encoder,
            seq_lens_decoder,
            step_idx,
            not_need_stop,
            is_block_step,
            batch_drop,
            pre_ids,
            accept_tokens,
            accept_num,
            base_model_seq_lens_this_time,
            base_model_seq_lens_encoder,
            base_model_seq_lens_decoder,
            base_model_step_idx,
            base_model_stop_flags,
            base_model_is_block_step,
            base_model_draft_tokens,
            bsz,
            num_model_step,
            base_model_draft_tokens_len,
            truncate_first_token,
            kvcache_scheduler_v1,
        )


def draft_model_preprocess_ref(
    draft_tokens,
    input_ids,
    stop_flags,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    step_idx,
    not_need_stop,
    is_block_step,
    batch_drop,
    pre_ids,
    accept_tokens,
    accept_num,
    base_model_seq_lens_this_time,
    base_model_seq_lens_encoder,
    base_model_seq_lens_decoder,
    base_model_step_idx,
    base_model_stop_flags,
    base_model_is_block_step,
    base_model_draft_tokens,
    num_model_step,
    truncate_first_token,
    splitwise_prefill,
    kvcache_scheduler_v1,
):
    real_bsz = seq_lens_this_time.shape[0]
    DispatchRunner(
        draft_tokens,
        input_ids,
        stop_flags,
        seq_lens_this_time,
        seq_lens_encoder,
        seq_lens_decoder,
        step_idx,
        not_need_stop,
        is_block_step,
        batch_drop,
        pre_ids,
        accept_tokens,
        accept_num,
        base_model_seq_lens_this_time,
        base_model_seq_lens_encoder,
        base_model_seq_lens_decoder,
        base_model_step_idx,
        base_model_stop_flags,
        base_model_is_block_step,
        base_model_draft_tokens,
        real_bsz,
        num_model_step,
        truncate_first_token,
        splitwise_prefill,
        kvcache_scheduler_v1,
    )
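

# Both the CUDA op and the Python reference mutate their argument tensors in
# place and report the batch-level "keep running" flag through not_need_stop[0],
# so the test below runs each of them on its own clone of the inputs and then
# compares the mutated copies tensor by tensor.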
class TestDraftModelPreprocess(unittest.TestCase):
    def _run_tests(self):
        paddle.seed(2022)

        # Define parameters
        bsz = 10
        draft_tokens_len = 4
        input_ids_len = 100
        max_draft_token = 10
        truncate_first_token = True
        splitwise_prefill = False

        draft_tokens = paddle.randint(0, 100, [bsz, draft_tokens_len], dtype="int64")
        input_ids = paddle.randint(0, 100, [bsz, input_ids_len], dtype="int64")
        # randint's upper bound is exclusive, so this is all zeros: no request starts stopped.
        stop_flags = paddle.randint(0, 1, [bsz], dtype="int").cast("bool")
        seq_lens_this_time = paddle.randint(0, 100, [bsz], dtype="int32")
        seq_lens_encoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32")
        seq_lens_decoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32")
        step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
        seq_lens_encoder_record = paddle.randint(0, 100, [bsz], dtype="int32")  # noqa: F841
        seq_lens_decoder_record = paddle.randint(0, 100, [bsz], dtype="int32")  # noqa: F841
        not_need_stop = paddle.zeros([1], dtype="bool").cpu()
        is_block_step = paddle.zeros([bsz], dtype="bool")
        batch_drop = paddle.zeros([bsz], dtype="bool")

        # Base model outputs consumed by the draft model
        accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64")
        accept_num = paddle.randint(1, max_draft_token + 5, [bsz], dtype="int32")
        base_model_seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32")
        base_model_seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32")
        base_model_step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
        base_model_stop_flags = paddle.zeros([bsz], dtype="bool")
        base_model_is_block_step = paddle.zeros([bsz], dtype="bool")
        base_model_draft_tokens = paddle.zeros([bsz, max_draft_token], dtype="int64")

        # Run the op
        pre_ids = input_ids.clone()
        base_model_seq_lens_this_time = seq_lens_this_time
        num_model_step = max_draft_token
        kvcache_scheduler_v1 = True
        inputs = (
            draft_tokens,
            input_ids,
            stop_flags,
            seq_lens_this_time,
            seq_lens_encoder,
            seq_lens_decoder,
            step_idx,
            not_need_stop,
            is_block_step,
            batch_drop,
            pre_ids,
            accept_tokens,
            accept_num,
            base_model_seq_lens_this_time,
            base_model_seq_lens_encoder,
            base_model_seq_lens_decoder,
            base_model_step_idx,
            base_model_stop_flags,
            base_model_is_block_step,
            base_model_draft_tokens,
            num_model_step,
            truncate_first_token,
            splitwise_prefill,
            kvcache_scheduler_v1,
        )
        # inplace modify, need to clone inputs
        inputs_clone = [x.clone() if isinstance(x, paddle.Tensor) else x for x in inputs]
        draft_model_preprocess_ref(*inputs)
        draft_model_preprocess(*inputs_clone)
        return inputs, inputs_clone

    def test_draft_model_preprocess(self):
        results1, results2 = self._run_tests()
        np.testing.assert_allclose(results1[0], results2[0])  # draft_tokens
        np.testing.assert_allclose(results1[1], results2[1])  # input_ids
        np.testing.assert_allclose(results1[2], results2[2])  # stop_flags
        np.testing.assert_allclose(results1[3], results2[3])  # seq_lens_this_time
        np.testing.assert_allclose(results1[11], results2[11])  # accept_tokens
        np.testing.assert_allclose(results1[12], results2[12])  # accept_num
        np.testing.assert_allclose(results1[7], results2[7])  # not_need_stop


if __name__ == "__main__":
    unittest.main()