[Feature][MTP] update multi-draft-token strategy (#3369)

* update multi-draft-token strategy: feed all accepted tokens back to the draft model in one step, instead of at most one (partial accept) or two (full accept)

* fix format

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Author: freeliuzc
Date: 2025-08-18 13:59:56 +08:00
Committed by: GitHub
Parent: 3ee6053e5d
Commit: a12d0bc549
4 changed files with 69 additions and 30 deletions


@@ -676,6 +676,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                           const paddle::Tensor& batch_drop,
                           const paddle::Tensor& accept_tokens,
                           const paddle::Tensor& accept_num,
+                          const paddle::Tensor& base_model_seq_lens_this_time,
                           const paddle::Tensor& base_model_seq_lens_encoder,
                           const paddle::Tensor& base_model_seq_lens_decoder,
                           const paddle::Tensor& base_model_step_idx,


@@ -28,6 +28,7 @@ __global__ void process_splitwise_prefill(
     bool* batch_drop,
     const int64_t* accept_tokens,
     const int* accept_num,
+    const int* base_model_seq_lens_this_time,
     const int* base_model_seq_lens_encoder,
     const int* base_model_seq_lens_decoder,
     const int64_t* base_model_step_idx,
@@ -94,6 +95,7 @@ __global__ void draft_model_preprocess_kernel(
     bool* batch_drop,
     const int64_t* accept_tokens,
     const int* accept_num,
+    const int* base_model_seq_lens_this_time,
     const int* base_model_seq_lens_encoder,
     const int* base_model_seq_lens_decoder,
     const int64_t* base_model_step_idx,
@@ -113,13 +115,15 @@ __global__ void draft_model_preprocess_kernel(
   int tid = threadIdx.x;
   if (tid < bsz) {
-    auto base_model_step_idx_now = base_model_step_idx[tid];
+    const int32_t base_model_step_idx_now = base_model_step_idx[tid];
     auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
     auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
-    auto accept_num_now = accept_num[tid];
+    const int32_t accept_num_now = accept_num[tid];
     auto* input_ids_now = input_ids + tid * input_ids_len;
     auto* base_model_draft_tokens_now =
         base_model_draft_tokens + tid * base_model_draft_tokens_len;
+    auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid];
+    const int32_t base_model_seq_len_this_time = base_model_seq_lens_this_time[tid];
 #pragma unroll
     for (int i = 1; i < base_model_draft_tokens_len; i++) {
       base_model_draft_tokens_now[i] = -1;
@@ -149,25 +153,42 @@ __global__ void draft_model_preprocess_kernel(
         input_ids_now[position] = base_model_first_token;
         seq_lens_this_time[tid] = seq_len_encoder + 1;
       }
-    } else if (accept_num_now <=
-               max_draft_token) /*Accept partial draft tokens*/ {
-      // Base Model reject stop
+    } else {
       if (stop_flags[tid]) {
         stop_flags[tid] = false;
-        seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
-        step_idx[tid] = base_model_step_idx[tid];
+        // TODO: check
+        seq_lens_decoder[tid] = base_model_seq_len_decoder - base_model_seq_len_this_time;
+        step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time;
       } else {
-        seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
-        step_idx[tid] -= max_draft_token - accept_num_now;
+        // 2: Last base model generated token and first MTP token
+        seq_lens_decoder[tid] -= (base_model_seq_len_this_time - 2);
+        step_idx[tid] -= (base_model_seq_len_this_time - 2);
       }
-      int64_t modified_token = accept_tokens_now[accept_num_now - 1];
-      draft_tokens_now[0] = modified_token;
-      seq_lens_this_time[tid] = 1;
-    } else /*Accept all draft tokens*/ {
-      draft_tokens_now[1] = accept_tokens_now[max_draft_token];
-      seq_lens_this_time[tid] = 2;
+      for (int i = 0; i < accept_num_now; i++) {
+        draft_tokens_now[i] = accept_tokens_now[i];
+      }
+      seq_lens_this_time[tid] = accept_num_now;
     }
+    // (liuzichang): Temporarily reserved for debug
+    // else if (accept_num_now <=
+    //     max_draft_token) /*Accept partial draft tokens*/ {
+    //   // Base Model reject stop
+    //   if (stop_flags[tid]) {
+    //     stop_flags[tid] = false;
+    //     seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
+    //     step_idx[tid] = base_model_step_idx[tid];
+    //   } else {
+    //     seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
+    //     step_idx[tid] -= max_draft_token - accept_num_now;
+    //   }
+    //   int64_t modified_token = accept_tokens_now[accept_num_now - 1];
+    //   draft_tokens_now[0] = modified_token;
+    //   seq_lens_this_time[tid] = 1;
+    // } else /*Accept all draft tokens*/ {
+    //   draft_tokens_now[1] = accept_tokens_now[max_draft_token];
+    //   seq_lens_this_time[tid] = 2;
+    // }
   } else {
     stop_flags[tid] = true;
     seq_lens_this_time[tid] = 0;
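
This hunk carries the strategy change. The old kernel kept at most one token after a partial accept (seq_lens_this_time = 1) or two after a full accept (seq_lens_this_time = 2); the new else branch copies every accepted token into draft_tokens and rewinds seq_lens_decoder and step_idx by base_model_seq_len_this_time - 2, where the 2 accounts for the last base-model token plus the first MTP token. A minimal Python sketch of the per-request update, mirroring the kernel's names (an illustration of the logic, not the CUDA code itself):

    def update_draft_state(accept_tokens_now, accept_num_now, stop_flag,
                           seq_len_decoder, step_idx,
                           base_model_seq_len_decoder,
                           base_model_step_idx_now,
                           base_model_seq_len_this_time):
        if stop_flag:
            # Base model hit a stop: rewind past everything it emitted this step.
            stop_flag = False
            seq_len_decoder = base_model_seq_len_decoder - base_model_seq_len_this_time
            step_idx = base_model_step_idx_now - base_model_seq_len_this_time
        else:
            # Keep the last base-model token and the first MTP token (the "2"),
            # dropping the rest of this step's speculative positions.
            seq_len_decoder -= base_model_seq_len_this_time - 2
            step_idx -= base_model_seq_len_this_time - 2
        # New strategy: all accepted tokens become the draft model's next input,
        # instead of only one (partial accept) or two (full accept).
        draft_tokens_now = list(accept_tokens_now[:accept_num_now])
        seq_len_this_time = accept_num_now
        return draft_tokens_now, seq_len_this_time, seq_len_decoder, step_idx, stop_flag
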
@@ -196,6 +217,7 @@ void DispatchRunner(
     bool* batch_drop,
     const int64_t* accept_tokens,
     const int* accept_num,
+    const int* base_model_seq_lens_this_time,
     const int* base_model_seq_lens_encoder,
     const int* base_model_seq_lens_decoder,
     const int64_t* base_model_step_idx,
@@ -224,6 +246,7 @@ void DispatchRunner(
         batch_drop,
         accept_tokens,
         accept_num,
+        base_model_seq_lens_this_time,
         base_model_seq_lens_encoder,
         base_model_seq_lens_decoder,
         base_model_step_idx,
@@ -250,6 +273,7 @@ void DispatchRunner(
         batch_drop,
         accept_tokens,
         accept_num,
+        base_model_seq_lens_this_time,
         base_model_seq_lens_encoder,
         base_model_seq_lens_decoder,
         base_model_step_idx,
@@ -278,6 +302,7 @@ void DispatchTokenMode(
     bool* batch_drop,
     const int64_t* accept_tokens,
     const int* accept_num,
+    const int* base_model_seq_lens_this_time,
     const int* base_model_seq_lens_encoder,
     const int* base_model_seq_lens_decoder,
     const int64_t* base_model_step_idx,
@@ -306,6 +331,7 @@ void DispatchTokenMode(
         batch_drop,
         accept_tokens,
         accept_num,
+        base_model_seq_lens_this_time,
         base_model_seq_lens_encoder,
         base_model_seq_lens_decoder,
         base_model_step_idx,
@@ -334,6 +360,7 @@ void DispatchTokenMode(
         batch_drop,
         accept_tokens,
         accept_num,
+        base_model_seq_lens_this_time,
         base_model_seq_lens_encoder,
         base_model_seq_lens_decoder,
         base_model_step_idx,
@@ -365,6 +392,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                           const paddle::Tensor& batch_drop,
                           const paddle::Tensor& accept_tokens,
                           const paddle::Tensor& accept_num,
+                          const paddle::Tensor& base_model_seq_lens_this_time,
                           const paddle::Tensor& base_model_seq_lens_encoder,
                           const paddle::Tensor& base_model_seq_lens_decoder,
                           const paddle::Tensor& base_model_step_idx,
@@ -397,6 +425,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
       const_cast<bool*>(batch_drop.data<bool>()),
       accept_tokens.data<int64_t>(),
       accept_num.data<int>(),
+      base_model_seq_lens_this_time.data<int>(),
       base_model_seq_lens_encoder.data<int>(),
       base_model_seq_lens_decoder.data<int>(),
       base_model_step_idx.data<int64_t>(),
@@ -431,6 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
               "batch_drop",
               "accept_tokens",
               "accept_num",
+              "base_model_seq_lens_this_time",
               "base_model_seq_lens_encoder",
               "base_model_seq_lens_decoder",
               "base_model_step_idx",


@@ -61,20 +61,25 @@ __global__ void ComputeOrderKernel(
           // 4. stopped
         } else if (cur_base_model_seq_lens_this_time == 0 && cur_seq_lens_this_time == 0) /* end */ {
         } else {
-          if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
-#ifdef DEBUG_EAGLE_KERNEL
-            printf("batch %d: accept_num <= actual_draft_token_num \n", i);
-#endif
-            position_map[in_offset + accept_num - 1] = out_offset++;
-            in_offset += cur_base_model_seq_lens_this_time;
-          } else /*Accept all draft tokens*/ {
-#ifdef DEBUG_EAGLE_KERNEL
-            printf("batch %d: accept_num > actual_draft_token_num \n", i);
-#endif
-            position_map[in_offset + accept_num - 2] = out_offset++;
-            position_map[in_offset + accept_num - 1] = out_offset++;
-            in_offset += cur_base_model_seq_lens_this_time;
+          for (int i = 0; i < accept_num; i++) {
+            position_map[in_offset++] = out_offset++;
           }
+          in_offset += cur_base_model_seq_lens_this_time - accept_num;
+          // (liuzichang): Temporarily reserved for debug
+          // if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
+          //   #ifdef DEBUG_EAGLE_KERNEL
+          //   printf("batch %d: accept_num <= actual_draft_token_num \n", i);
+          //   #endif
+          //   position_map[in_offset + accept_num - 1] = out_offset++;
+          //   in_offset += cur_base_model_seq_lens_this_time;
+          // } else /*Accept all draft tokens*/ {
+          //   #ifdef DEBUG_EAGLE_KERNEL
+          //   printf("batch %d: accept_num > actual_draft_token_num \n", i);
+          //   #endif
+          //   position_map[in_offset + accept_num - 2] = out_offset++;
+          //   position_map[in_offset + accept_num - 1] = out_offset++;
+          //   in_offset += cur_base_model_seq_lens_this_time;
+          // }
         }
       }
       output_token_num[0] = out_offset;
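
ComputeOrderKernel is updated to match: instead of emitting only the last one or two accepted positions, it now maps all accept_num input positions straight through to consecutive output slots, then skips whatever the base model emitted this step that was rejected. A small sketch of that indexing for a single request (a hypothetical standalone helper, assuming the same offset bookkeeping as the kernel):

    def map_accepted_positions(position_map, in_offset, out_offset,
                               accept_num, cur_base_model_seq_lens_this_time):
        # Each accepted input position maps to the next free output slot.
        for _ in range(accept_num):
            position_map[in_offset] = out_offset
            in_offset += 1
            out_offset += 1
        # Skip the rejected tail of this step's base-model positions.
        in_offset += cur_base_model_seq_lens_this_time - accept_num
        return in_offset, out_offset
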


@@ -317,7 +317,9 @@ class MTPProposer(Proposer):
         self.model_inputs["max_len_tensor_cpu"] = None  # CPU
         # Input tokens
-        self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")
+        self.model_inputs["draft_tokens"] = paddle.full(
+            shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64"
+        )
         self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"])
@@ -461,6 +463,7 @@ class MTPProposer(Proposer):
             self.model_inputs["batch_drop"],
             self.main_model_inputs["accept_tokens"],
             self.main_model_inputs["accept_num"],
+            self.main_model_inputs["seq_lens_this_time"],
             self.main_model_inputs["seq_lens_encoder"],
             self.main_model_inputs["seq_lens_decoder"],
             self.main_model_inputs["step_idx"],