From a12d0bc5490777cb62691302660a54132b92f39a Mon Sep 17 00:00:00 2001
From: freeliuzc
Date: Mon, 18 Aug 2025 13:59:56 +0800
Subject: [PATCH] [Feature][MTP] Update multi-draft-token strategy (#3369)

* update multi-draft-token strategy

* fix format

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
---
 custom_ops/gpu_ops/cpp_extensions.cc          |  1 +
 .../draft_model/draft_model_preprocess.cu     | 62 ++++++++++++++-----
 .../eagle_get_base_model_hidden_states.cu     | 31 ++++++----
 fastdeploy/spec_decode/mtp.py                 |  5 +-
 4 files changed, 69 insertions(+), 30 deletions(-)

diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 17911252a..bb2e6944e 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -676,6 +676,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                           const paddle::Tensor& batch_drop,
                           const paddle::Tensor& accept_tokens,
                           const paddle::Tensor& accept_num,
+                          const paddle::Tensor& base_model_seq_lens_this_time,
                           const paddle::Tensor& base_model_seq_lens_encoder,
                           const paddle::Tensor& base_model_seq_lens_decoder,
                           const paddle::Tensor& base_model_step_idx,
diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu b/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu
index 0653c8770..1c41750d7 100644
--- a/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu
+++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu
@@ -28,6 +28,7 @@ __global__ void process_splitwise_prefill(
     bool* batch_drop,
     const int64_t* accept_tokens,
     const int* accept_num,
+    const int* base_model_seq_lens_this_time,
     const int* base_model_seq_lens_encoder,
     const int* base_model_seq_lens_decoder,
     const int64_t* base_model_step_idx,
@@ -94,6 +95,7 @@ __global__ void draft_model_preprocess_kernel(
     bool* batch_drop,
     const int64_t* accept_tokens,
     const int* accept_num,
+    const int* base_model_seq_lens_this_time,
     const int* base_model_seq_lens_encoder,
     const int* base_model_seq_lens_decoder,
     const int64_t* base_model_step_idx,
@@ -113,13 +115,15 @@ __global__ void draft_model_preprocess_kernel(
   int tid = threadIdx.x;

   if (tid < bsz) {
-    auto base_model_step_idx_now = base_model_step_idx[tid];
+    const int32_t base_model_step_idx_now = base_model_step_idx[tid];
     auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
     auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
-    auto accept_num_now = accept_num[tid];
+    const int32_t accept_num_now = accept_num[tid];
     auto* input_ids_now = input_ids + tid * input_ids_len;
     auto* base_model_draft_tokens_now =
         base_model_draft_tokens + tid * base_model_draft_tokens_len;
+    auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid];
+    const int32_t base_model_seq_len_this_time = base_model_seq_lens_this_time[tid];
 #pragma unroll
     for (int i = 1; i < base_model_draft_tokens_len; i++) {
       base_model_draft_tokens_now[i] = -1;
@@ -149,25 +153,42 @@ __global__ void draft_model_preprocess_kernel(
         input_ids_now[position] = base_model_first_token;
         seq_lens_this_time[tid] = seq_len_encoder + 1;
       }
-    } else if (accept_num_now <=
-               max_draft_token) /*Accept partial draft tokens*/ {
-      // Base Model reject stop
+    } else {
       if (stop_flags[tid]) {
         stop_flags[tid] = false;
-        seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
-        step_idx[tid] = base_model_step_idx[tid];
+        // TODO: verify this rollback when the base model hit a stop
+        seq_lens_decoder[tid] = base_model_seq_len_decoder - base_model_seq_len_this_time;
+        step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time;
       } else {
-        seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
-        step_idx[tid] -= max_draft_token - accept_num_now;
+        // Keep 2: the last base-model-generated token and the first MTP token
+        seq_lens_decoder[tid] -= (base_model_seq_len_this_time - 2);
+        step_idx[tid] -= (base_model_seq_len_this_time - 2);
       }
-      int64_t modified_token = accept_tokens_now[accept_num_now - 1];
-      draft_tokens_now[0] = modified_token;
-      seq_lens_this_time[tid] = 1;
-
-    } else /*Accept all draft tokens*/ {
-      draft_tokens_now[1] = accept_tokens_now[max_draft_token];
-      seq_lens_this_time[tid] = 2;
+      for (int i = 0; i < accept_num_now; i++) {
+        draft_tokens_now[i] = accept_tokens_now[i];
+      }
+      seq_lens_this_time[tid] = accept_num_now;
     }
+    // (liuzichang): Temporarily reserved for debugging
+    // else if (accept_num_now <=
+    //            max_draft_token) /*Accept partial draft tokens*/ {
+    //   // Base Model reject stop
+    //   if (stop_flags[tid]) {
+    //     stop_flags[tid] = false;
+    //     seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
+    //     step_idx[tid] = base_model_step_idx[tid];
+    //   } else {
+    //     seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
+    //     step_idx[tid] -= max_draft_token - accept_num_now;
+    //   }
+    //   int64_t modified_token = accept_tokens_now[accept_num_now - 1];
+    //   draft_tokens_now[0] = modified_token;
+    //   seq_lens_this_time[tid] = 1;
+
+    // } else /*Accept all draft tokens*/ {
+    //   draft_tokens_now[1] = accept_tokens_now[max_draft_token];
+    //   seq_lens_this_time[tid] = 2;
+    // }
   } else {
     stop_flags[tid] = true;
     seq_lens_this_time[tid] = 0;
@@ -196,6 +217,7 @@ void DispatchRunner(
     bool* batch_drop,
     const int64_t* accept_tokens,
     const int* accept_num,
+    const int* base_model_seq_lens_this_time,
     const int* base_model_seq_lens_encoder,
     const int* base_model_seq_lens_decoder,
     const int64_t* base_model_step_idx,
@@ -224,6 +246,7 @@ void DispatchRunner(
         batch_drop,
         accept_tokens,
         accept_num,
+        base_model_seq_lens_this_time,
         base_model_seq_lens_encoder,
         base_model_seq_lens_decoder,
         base_model_step_idx,
@@ -250,6 +273,7 @@ void DispatchRunner(
         batch_drop,
         accept_tokens,
         accept_num,
+        base_model_seq_lens_this_time,
         base_model_seq_lens_encoder,
         base_model_seq_lens_decoder,
         base_model_step_idx,
@@ -278,6 +302,7 @@ void DispatchTokenMode(
     bool* batch_drop,
    const int64_t* accept_tokens,
     const int* accept_num,
+    const int* base_model_seq_lens_this_time,
     const int* base_model_seq_lens_encoder,
     const int* base_model_seq_lens_decoder,
     const int64_t* base_model_step_idx,
@@ -306,6 +331,7 @@ void DispatchTokenMode(
         batch_drop,
         accept_tokens,
         accept_num,
+        base_model_seq_lens_this_time,
         base_model_seq_lens_encoder,
         base_model_seq_lens_decoder,
         base_model_step_idx,
@@ -334,6 +360,7 @@ void DispatchTokenMode(
         batch_drop,
         accept_tokens,
         accept_num,
+        base_model_seq_lens_this_time,
         base_model_seq_lens_encoder,
         base_model_seq_lens_decoder,
         base_model_step_idx,
@@ -365,6 +392,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                           const paddle::Tensor& batch_drop,
                           const paddle::Tensor& accept_tokens,
                           const paddle::Tensor& accept_num,
+                          const paddle::Tensor& base_model_seq_lens_this_time,
                           const paddle::Tensor& base_model_seq_lens_encoder,
                           const paddle::Tensor& base_model_seq_lens_decoder,
                           const paddle::Tensor& base_model_step_idx,
@@ -397,6 +425,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
       const_cast<bool*>(batch_drop.data<bool>()),
       accept_tokens.data<int64_t>(),
       accept_num.data<int>(),
+      base_model_seq_lens_this_time.data<int>(),
       base_model_seq_lens_encoder.data<int>(),
       base_model_seq_lens_decoder.data<int>(),
       base_model_step_idx.data<int64_t>(),
@@ -431,6 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
              "batch_drop",
              "accept_tokens",
              "accept_num",
+             "base_model_seq_lens_this_time",
              "base_model_seq_lens_encoder",
              "base_model_seq_lens_decoder",
              "base_model_step_idx",
diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/eagle_get_base_model_hidden_states.cu b/custom_ops/gpu_ops/speculate_decoding/draft_model/eagle_get_base_model_hidden_states.cu
index 97d900319..e4b1f1858 100644
--- a/custom_ops/gpu_ops/speculate_decoding/draft_model/eagle_get_base_model_hidden_states.cu
+++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/eagle_get_base_model_hidden_states.cu
@@ -61,20 +61,25 @@ __global__ void ComputeOrderKernel(
         // 4. stopped
       } else if (cur_base_model_seq_lens_this_time == 0 && cur_seq_lens_this_time == 0) /* end */ {
       } else {
-        if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
-#ifdef DEBUG_EAGLE_KERNEL
-          printf("batch %d: accept_num <= actual_draft_token_num \n", i);
-#endif
-          position_map[in_offset + accept_num - 1] = out_offset++;
-          in_offset += cur_base_model_seq_lens_this_time;
-        } else /*Accept all draft tokens*/ {
-#ifdef DEBUG_EAGLE_KERNEL
-          printf("batch %d: accept_num > actual_draft_token_num \n", i);
-#endif
-          position_map[in_offset + accept_num - 2] = out_offset++;
-          position_map[in_offset + accept_num - 1] = out_offset++;
-          in_offset += cur_base_model_seq_lens_this_time;
+        for (int i = 0; i < accept_num; i++) {
+          position_map[in_offset++] = out_offset++;
         }
+        in_offset += cur_base_model_seq_lens_this_time - accept_num;
+// (liuzichang): Temporarily reserved for debugging
+// if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
+// #ifdef DEBUG_EAGLE_KERNEL
+// printf("batch %d: accept_num <= actual_draft_token_num \n", i);
+// #endif
+// position_map[in_offset + accept_num - 1] = out_offset++;
+// in_offset += cur_base_model_seq_lens_this_time;
+// } else /*Accept all draft tokens*/ {
+// #ifdef DEBUG_EAGLE_KERNEL
+// printf("batch %d: accept_num > actual_draft_token_num \n", i);
+// #endif
+// position_map[in_offset + accept_num - 2] = out_offset++;
+// position_map[in_offset + accept_num - 1] = out_offset++;
+// in_offset += cur_base_model_seq_lens_this_time;
+// }
       }
     }
     output_token_num[0] = out_offset;
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index 3033e4146..b6386c601 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -317,7 +317,9 @@ class MTPProposer(Proposer):
         self.model_inputs["max_len_tensor_cpu"] = None  # CPU

         # Input tokens
-        self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")
+        self.model_inputs["draft_tokens"] = paddle.full(
+            shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64"
+        )

         self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"])
@@ -461,6 +463,7 @@ class MTPProposer(Proposer):
             self.model_inputs["batch_drop"],
             self.main_model_inputs["accept_tokens"],
             self.main_model_inputs["accept_num"],
+            self.main_model_inputs["seq_lens_this_time"],
             self.main_model_inputs["seq_lens_encoder"],
             self.main_model_inputs["seq_lens_decoder"],
             self.main_model_inputs["step_idx"],
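
Reviewer note: the kernel change above collapses the old two-case logic (a
partial accept fed 1 draft token, a full accept fed 2) into a single path
that feeds every accepted token to the draft model and, on the non-stop
path, rolls seq_lens_decoder and step_idx back by
(base_model_seq_len_this_time - 2). Below is a minimal host-side C++ sketch
of that bookkeeping, for intuition only; DraftState and on_accept are
hypothetical names and are not part of this patch.

    // Hypothetical standalone sketch: mirrors the per-request bookkeeping
    // done by draft_model_preprocess_kernel above (non-stop path only).
    #include <cstdint>
    #include <iostream>
    #include <vector>

    struct DraftState {
      std::vector<int64_t> draft_tokens;  // input buffer for the MTP draft model
      int seq_len_this_time = 0;
      int seq_len_decoder = 0;
      int64_t step_idx = 0;
    };

    // accept_tokens holds the tokens the base model verified this step;
    // base_seq_len_this_time is how many tokens the base model ran this step.
    void on_accept(DraftState& s,
                   const std::vector<int64_t>& accept_tokens,
                   int base_seq_len_this_time) {
      const int accept_num = static_cast<int>(accept_tokens.size());
      // Roll back all but 2 positions: the last base-model-generated token
      // and the first MTP token stay counted (the "- 2" in the kernel).
      s.seq_len_decoder -= base_seq_len_this_time - 2;
      s.step_idx -= base_seq_len_this_time - 2;
      // New strategy: feed every accepted token to the draft model at once.
      // Assumes draft_tokens was sized to max_draft_token_num + 1, matching
      // the mtp.py change above.
      for (int i = 0; i < accept_num; ++i) {
        s.draft_tokens[i] = accept_tokens[i];
      }
      s.seq_len_this_time = accept_num;
    }

    int main() {
      DraftState s{std::vector<int64_t>(6, -1), 0, 40, 40};
      on_accept(s, {101, 102, 103}, 6);  // base model ran 6 tokens, 3 accepted
      std::cout << s.seq_len_this_time << " " << s.seq_len_decoder << " "
                << s.step_idx << "\n";  // prints: 3 36 36
    }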