mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
[Feature][MTP]update multi-draft-token strategy (#3369)
* update multi-draft-token strategy * fix format --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -676,6 +676,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
const paddle::Tensor& base_model_seq_lens_encoder,
|
||||
const paddle::Tensor& base_model_seq_lens_decoder,
|
||||
const paddle::Tensor& base_model_step_idx,
|
||||
|
@@ -28,6 +28,7 @@ __global__ void process_splitwise_prefill(
|
||||
bool* batch_drop,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* base_model_seq_lens_decoder,
|
||||
const int64_t* base_model_step_idx,
|
||||
@@ -94,6 +95,7 @@ __global__ void draft_model_preprocess_kernel(
|
||||
bool* batch_drop,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* base_model_seq_lens_decoder,
|
||||
const int64_t* base_model_step_idx,
|
||||
@@ -113,13 +115,15 @@ __global__ void draft_model_preprocess_kernel(
|
||||
int tid = threadIdx.x;
|
||||
|
||||
if (tid < bsz) {
|
||||
auto base_model_step_idx_now = base_model_step_idx[tid];
|
||||
const int32_t base_model_step_idx_now = base_model_step_idx[tid];
|
||||
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
|
||||
auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
|
||||
auto accept_num_now = accept_num[tid];
|
||||
const int32_t accept_num_now = accept_num[tid];
|
||||
auto* input_ids_now = input_ids + tid * input_ids_len;
|
||||
auto* base_model_draft_tokens_now =
|
||||
base_model_draft_tokens + tid * base_model_draft_tokens_len;
|
||||
auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid];
|
||||
const int32_t base_model_seq_len_this_time = base_model_seq_lens_this_time[tid];
|
||||
#pragma unroll
|
||||
for (int i = 1; i < base_model_draft_tokens_len; i++) {
|
||||
base_model_draft_tokens_now[i] = -1;
|
||||
@@ -149,25 +153,42 @@ __global__ void draft_model_preprocess_kernel(
|
||||
input_ids_now[position] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder + 1;
|
||||
}
|
||||
} else if (accept_num_now <=
|
||||
max_draft_token) /*Accept partial draft tokens*/ {
|
||||
// Base Model reject stop
|
||||
} else {
|
||||
if (stop_flags[tid]) {
|
||||
stop_flags[tid] = false;
|
||||
seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
|
||||
step_idx[tid] = base_model_step_idx[tid];
|
||||
// TODO: check
|
||||
seq_lens_decoder[tid] = base_model_seq_len_decoder - base_model_seq_len_this_time;
|
||||
step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time;
|
||||
} else {
|
||||
seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
|
||||
step_idx[tid] -= max_draft_token - accept_num_now;
|
||||
// 2: Last base model generated token and first MTP token
|
||||
seq_lens_decoder[tid] -= (base_model_seq_len_this_time - 2);
|
||||
step_idx[tid] -= (base_model_seq_len_this_time - 2);
|
||||
}
|
||||
int64_t modified_token = accept_tokens_now[accept_num_now - 1];
|
||||
draft_tokens_now[0] = modified_token;
|
||||
seq_lens_this_time[tid] = 1;
|
||||
|
||||
} else /*Accept all draft tokens*/ {
|
||||
draft_tokens_now[1] = accept_tokens_now[max_draft_token];
|
||||
seq_lens_this_time[tid] = 2;
|
||||
for (int i = 0; i < accept_num_now; i++) {
|
||||
draft_tokens_now[i] = accept_tokens_now[i];
|
||||
}
|
||||
seq_lens_this_time[tid] = accept_num_now;
|
||||
}
|
||||
// (liuzichang): Temperary Reserved for debug
|
||||
// else if (accept_num_now <=
|
||||
// max_draft_token) /*Accept partial draft tokens*/ {
|
||||
// // Base Model reject stop
|
||||
// if (stop_flags[tid]) {
|
||||
// stop_flags[tid] = false;
|
||||
// seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
|
||||
// step_idx[tid] = base_model_step_idx[tid];
|
||||
// } else {
|
||||
// seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
|
||||
// step_idx[tid] -= max_draft_token - accept_num_now;
|
||||
// }
|
||||
// int64_t modified_token = accept_tokens_now[accept_num_now - 1];
|
||||
// draft_tokens_now[0] = modified_token;
|
||||
// seq_lens_this_time[tid] = 1;
|
||||
|
||||
// } else /*Accept all draft tokens*/ {
|
||||
// draft_tokens_now[1] = accept_tokens_now[max_draft_token];
|
||||
// seq_lens_this_time[tid] = 2;
|
||||
// }
|
||||
} else {
|
||||
stop_flags[tid] = true;
|
||||
seq_lens_this_time[tid] = 0;
|
||||
@@ -196,6 +217,7 @@ void DispatchRunner(
|
||||
bool* batch_drop,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* base_model_seq_lens_decoder,
|
||||
const int64_t* base_model_step_idx,
|
||||
@@ -224,6 +246,7 @@ void DispatchRunner(
|
||||
batch_drop,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
@@ -250,6 +273,7 @@ void DispatchRunner(
|
||||
batch_drop,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
@@ -278,6 +302,7 @@ void DispatchTokenMode(
|
||||
bool* batch_drop,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* base_model_seq_lens_decoder,
|
||||
const int64_t* base_model_step_idx,
|
||||
@@ -306,6 +331,7 @@ void DispatchTokenMode(
|
||||
batch_drop,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
@@ -334,6 +360,7 @@ void DispatchTokenMode(
|
||||
batch_drop,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
@@ -365,6 +392,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
const paddle::Tensor& base_model_seq_lens_encoder,
|
||||
const paddle::Tensor& base_model_seq_lens_decoder,
|
||||
const paddle::Tensor& base_model_step_idx,
|
||||
@@ -397,6 +425,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const_cast<bool*>(batch_drop.data<bool>()),
|
||||
accept_tokens.data<int64_t>(),
|
||||
accept_num.data<int>(),
|
||||
base_model_seq_lens_this_time.data<int>(),
|
||||
base_model_seq_lens_encoder.data<int>(),
|
||||
base_model_seq_lens_decoder.data<int>(),
|
||||
base_model_step_idx.data<int64_t>(),
|
||||
@@ -431,6 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
"batch_drop",
|
||||
"accept_tokens",
|
||||
"accept_num",
|
||||
"base_model_seq_lens_this_time",
|
||||
"base_model_seq_lens_encoder",
|
||||
"base_model_seq_lens_decoder",
|
||||
"base_model_step_idx",
|
||||
|
@@ -61,20 +61,25 @@ __global__ void ComputeOrderKernel(
|
||||
// 4. stopped
|
||||
} else if (cur_base_model_seq_lens_this_time == 0 && cur_seq_lens_this_time == 0) /* end */ {
|
||||
} else {
|
||||
if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
|
||||
#ifdef DEBUG_EAGLE_KERNEL
|
||||
printf("batch %d: accept_num <= actual_draft_token_num \n", i);
|
||||
#endif
|
||||
position_map[in_offset + accept_num - 1] = out_offset++;
|
||||
in_offset += cur_base_model_seq_lens_this_time;
|
||||
} else /*Accept all draft tokens*/ {
|
||||
#ifdef DEBUG_EAGLE_KERNEL
|
||||
printf("batch %d: accept_num > actual_draft_token_num \n", i);
|
||||
#endif
|
||||
position_map[in_offset + accept_num - 2] = out_offset++;
|
||||
position_map[in_offset + accept_num - 1] = out_offset++;
|
||||
in_offset += cur_base_model_seq_lens_this_time;
|
||||
for (int i = 0; i < accept_num; i++) {
|
||||
position_map[in_offset++] = out_offset++;
|
||||
}
|
||||
in_offset += cur_base_model_seq_lens_this_time - accept_num;
|
||||
// (liuzichang): Temperary Reserved for debug
|
||||
// if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
|
||||
// #ifdef DEBUG_EAGLE_KERNEL
|
||||
// printf("batch %d: accept_num <= actual_draft_token_num \n", i);
|
||||
// #endif
|
||||
// position_map[in_offset + accept_num - 1] = out_offset++;
|
||||
// in_offset += cur_base_model_seq_lens_this_time;
|
||||
// } else /*Accept all draft tokens*/ {
|
||||
// #ifdef DEBUG_EAGLE_KERNEL
|
||||
// printf("batch %d: accept_num > actual_draft_token_num \n", i);
|
||||
// #endif
|
||||
// position_map[in_offset + accept_num - 2] = out_offset++;
|
||||
// position_map[in_offset + accept_num - 1] = out_offset++;
|
||||
// in_offset += cur_base_model_seq_lens_this_time;
|
||||
// }
|
||||
}
|
||||
}
|
||||
output_token_num[0] = out_offset;
|
||||
|
@@ -317,7 +317,9 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["max_len_tensor_cpu"] = None # CPU
|
||||
|
||||
# Input tokens
|
||||
self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")
|
||||
self.model_inputs["draft_tokens"] = paddle.full(
|
||||
shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64"
|
||||
)
|
||||
|
||||
self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"])
|
||||
|
||||
@@ -461,6 +463,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["batch_drop"],
|
||||
self.main_model_inputs["accept_tokens"],
|
||||
self.main_model_inputs["accept_num"],
|
||||
self.main_model_inputs["seq_lens_this_time"],
|
||||
self.main_model_inputs["seq_lens_encoder"],
|
||||
self.main_model_inputs["seq_lens_decoder"],
|
||||
self.main_model_inputs["step_idx"],
|
||||
|
Reference in New Issue
Block a user