mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
[Feature][MTP]update multi-draft-token strategy (#3369)
* update multi-draft-token strategy * fix format --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -676,6 +676,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
|||||||
const paddle::Tensor& batch_drop,
|
const paddle::Tensor& batch_drop,
|
||||||
const paddle::Tensor& accept_tokens,
|
const paddle::Tensor& accept_tokens,
|
||||||
const paddle::Tensor& accept_num,
|
const paddle::Tensor& accept_num,
|
||||||
|
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||||
const paddle::Tensor& base_model_seq_lens_encoder,
|
const paddle::Tensor& base_model_seq_lens_encoder,
|
||||||
const paddle::Tensor& base_model_seq_lens_decoder,
|
const paddle::Tensor& base_model_seq_lens_decoder,
|
||||||
const paddle::Tensor& base_model_step_idx,
|
const paddle::Tensor& base_model_step_idx,
|
||||||
|
@@ -28,6 +28,7 @@ __global__ void process_splitwise_prefill(
|
|||||||
bool* batch_drop,
|
bool* batch_drop,
|
||||||
const int64_t* accept_tokens,
|
const int64_t* accept_tokens,
|
||||||
const int* accept_num,
|
const int* accept_num,
|
||||||
|
const int* base_model_seq_lens_this_time,
|
||||||
const int* base_model_seq_lens_encoder,
|
const int* base_model_seq_lens_encoder,
|
||||||
const int* base_model_seq_lens_decoder,
|
const int* base_model_seq_lens_decoder,
|
||||||
const int64_t* base_model_step_idx,
|
const int64_t* base_model_step_idx,
|
||||||
@@ -94,6 +95,7 @@ __global__ void draft_model_preprocess_kernel(
|
|||||||
bool* batch_drop,
|
bool* batch_drop,
|
||||||
const int64_t* accept_tokens,
|
const int64_t* accept_tokens,
|
||||||
const int* accept_num,
|
const int* accept_num,
|
||||||
|
const int* base_model_seq_lens_this_time,
|
||||||
const int* base_model_seq_lens_encoder,
|
const int* base_model_seq_lens_encoder,
|
||||||
const int* base_model_seq_lens_decoder,
|
const int* base_model_seq_lens_decoder,
|
||||||
const int64_t* base_model_step_idx,
|
const int64_t* base_model_step_idx,
|
||||||
@@ -113,13 +115,15 @@ __global__ void draft_model_preprocess_kernel(
|
|||||||
int tid = threadIdx.x;
|
int tid = threadIdx.x;
|
||||||
|
|
||||||
if (tid < bsz) {
|
if (tid < bsz) {
|
||||||
auto base_model_step_idx_now = base_model_step_idx[tid];
|
const int32_t base_model_step_idx_now = base_model_step_idx[tid];
|
||||||
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
|
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
|
||||||
auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
|
auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
|
||||||
auto accept_num_now = accept_num[tid];
|
const int32_t accept_num_now = accept_num[tid];
|
||||||
auto* input_ids_now = input_ids + tid * input_ids_len;
|
auto* input_ids_now = input_ids + tid * input_ids_len;
|
||||||
auto* base_model_draft_tokens_now =
|
auto* base_model_draft_tokens_now =
|
||||||
base_model_draft_tokens + tid * base_model_draft_tokens_len;
|
base_model_draft_tokens + tid * base_model_draft_tokens_len;
|
||||||
|
auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid];
|
||||||
|
const int32_t base_model_seq_len_this_time = base_model_seq_lens_this_time[tid];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 1; i < base_model_draft_tokens_len; i++) {
|
for (int i = 1; i < base_model_draft_tokens_len; i++) {
|
||||||
base_model_draft_tokens_now[i] = -1;
|
base_model_draft_tokens_now[i] = -1;
|
||||||
@@ -149,25 +153,42 @@ __global__ void draft_model_preprocess_kernel(
|
|||||||
input_ids_now[position] = base_model_first_token;
|
input_ids_now[position] = base_model_first_token;
|
||||||
seq_lens_this_time[tid] = seq_len_encoder + 1;
|
seq_lens_this_time[tid] = seq_len_encoder + 1;
|
||||||
}
|
}
|
||||||
} else if (accept_num_now <=
|
} else {
|
||||||
max_draft_token) /*Accept partial draft tokens*/ {
|
|
||||||
// Base Model reject stop
|
|
||||||
if (stop_flags[tid]) {
|
if (stop_flags[tid]) {
|
||||||
stop_flags[tid] = false;
|
stop_flags[tid] = false;
|
||||||
seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
|
// TODO: check
|
||||||
step_idx[tid] = base_model_step_idx[tid];
|
seq_lens_decoder[tid] = base_model_seq_len_decoder - base_model_seq_len_this_time;
|
||||||
|
step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time;
|
||||||
} else {
|
} else {
|
||||||
seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
|
// 2: Last base model generated token and first MTP token
|
||||||
step_idx[tid] -= max_draft_token - accept_num_now;
|
seq_lens_decoder[tid] -= (base_model_seq_len_this_time - 2);
|
||||||
|
step_idx[tid] -= (base_model_seq_len_this_time - 2);
|
||||||
}
|
}
|
||||||
int64_t modified_token = accept_tokens_now[accept_num_now - 1];
|
for (int i = 0; i < accept_num_now; i++) {
|
||||||
draft_tokens_now[0] = modified_token;
|
draft_tokens_now[i] = accept_tokens_now[i];
|
||||||
seq_lens_this_time[tid] = 1;
|
}
|
||||||
|
seq_lens_this_time[tid] = accept_num_now;
|
||||||
} else /*Accept all draft tokens*/ {
|
|
||||||
draft_tokens_now[1] = accept_tokens_now[max_draft_token];
|
|
||||||
seq_lens_this_time[tid] = 2;
|
|
||||||
}
|
}
|
||||||
|
// (liuzichang): Temperary Reserved for debug
|
||||||
|
// else if (accept_num_now <=
|
||||||
|
// max_draft_token) /*Accept partial draft tokens*/ {
|
||||||
|
// // Base Model reject stop
|
||||||
|
// if (stop_flags[tid]) {
|
||||||
|
// stop_flags[tid] = false;
|
||||||
|
// seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
|
||||||
|
// step_idx[tid] = base_model_step_idx[tid];
|
||||||
|
// } else {
|
||||||
|
// seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
|
||||||
|
// step_idx[tid] -= max_draft_token - accept_num_now;
|
||||||
|
// }
|
||||||
|
// int64_t modified_token = accept_tokens_now[accept_num_now - 1];
|
||||||
|
// draft_tokens_now[0] = modified_token;
|
||||||
|
// seq_lens_this_time[tid] = 1;
|
||||||
|
|
||||||
|
// } else /*Accept all draft tokens*/ {
|
||||||
|
// draft_tokens_now[1] = accept_tokens_now[max_draft_token];
|
||||||
|
// seq_lens_this_time[tid] = 2;
|
||||||
|
// }
|
||||||
} else {
|
} else {
|
||||||
stop_flags[tid] = true;
|
stop_flags[tid] = true;
|
||||||
seq_lens_this_time[tid] = 0;
|
seq_lens_this_time[tid] = 0;
|
||||||
@@ -196,6 +217,7 @@ void DispatchRunner(
|
|||||||
bool* batch_drop,
|
bool* batch_drop,
|
||||||
const int64_t* accept_tokens,
|
const int64_t* accept_tokens,
|
||||||
const int* accept_num,
|
const int* accept_num,
|
||||||
|
const int* base_model_seq_lens_this_time,
|
||||||
const int* base_model_seq_lens_encoder,
|
const int* base_model_seq_lens_encoder,
|
||||||
const int* base_model_seq_lens_decoder,
|
const int* base_model_seq_lens_decoder,
|
||||||
const int64_t* base_model_step_idx,
|
const int64_t* base_model_step_idx,
|
||||||
@@ -224,6 +246,7 @@ void DispatchRunner(
|
|||||||
batch_drop,
|
batch_drop,
|
||||||
accept_tokens,
|
accept_tokens,
|
||||||
accept_num,
|
accept_num,
|
||||||
|
base_model_seq_lens_this_time,
|
||||||
base_model_seq_lens_encoder,
|
base_model_seq_lens_encoder,
|
||||||
base_model_seq_lens_decoder,
|
base_model_seq_lens_decoder,
|
||||||
base_model_step_idx,
|
base_model_step_idx,
|
||||||
@@ -250,6 +273,7 @@ void DispatchRunner(
|
|||||||
batch_drop,
|
batch_drop,
|
||||||
accept_tokens,
|
accept_tokens,
|
||||||
accept_num,
|
accept_num,
|
||||||
|
base_model_seq_lens_this_time,
|
||||||
base_model_seq_lens_encoder,
|
base_model_seq_lens_encoder,
|
||||||
base_model_seq_lens_decoder,
|
base_model_seq_lens_decoder,
|
||||||
base_model_step_idx,
|
base_model_step_idx,
|
||||||
@@ -278,6 +302,7 @@ void DispatchTokenMode(
|
|||||||
bool* batch_drop,
|
bool* batch_drop,
|
||||||
const int64_t* accept_tokens,
|
const int64_t* accept_tokens,
|
||||||
const int* accept_num,
|
const int* accept_num,
|
||||||
|
const int* base_model_seq_lens_this_time,
|
||||||
const int* base_model_seq_lens_encoder,
|
const int* base_model_seq_lens_encoder,
|
||||||
const int* base_model_seq_lens_decoder,
|
const int* base_model_seq_lens_decoder,
|
||||||
const int64_t* base_model_step_idx,
|
const int64_t* base_model_step_idx,
|
||||||
@@ -306,6 +331,7 @@ void DispatchTokenMode(
|
|||||||
batch_drop,
|
batch_drop,
|
||||||
accept_tokens,
|
accept_tokens,
|
||||||
accept_num,
|
accept_num,
|
||||||
|
base_model_seq_lens_this_time,
|
||||||
base_model_seq_lens_encoder,
|
base_model_seq_lens_encoder,
|
||||||
base_model_seq_lens_decoder,
|
base_model_seq_lens_decoder,
|
||||||
base_model_step_idx,
|
base_model_step_idx,
|
||||||
@@ -334,6 +360,7 @@ void DispatchTokenMode(
|
|||||||
batch_drop,
|
batch_drop,
|
||||||
accept_tokens,
|
accept_tokens,
|
||||||
accept_num,
|
accept_num,
|
||||||
|
base_model_seq_lens_this_time,
|
||||||
base_model_seq_lens_encoder,
|
base_model_seq_lens_encoder,
|
||||||
base_model_seq_lens_decoder,
|
base_model_seq_lens_decoder,
|
||||||
base_model_step_idx,
|
base_model_step_idx,
|
||||||
@@ -365,6 +392,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
|||||||
const paddle::Tensor& batch_drop,
|
const paddle::Tensor& batch_drop,
|
||||||
const paddle::Tensor& accept_tokens,
|
const paddle::Tensor& accept_tokens,
|
||||||
const paddle::Tensor& accept_num,
|
const paddle::Tensor& accept_num,
|
||||||
|
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||||
const paddle::Tensor& base_model_seq_lens_encoder,
|
const paddle::Tensor& base_model_seq_lens_encoder,
|
||||||
const paddle::Tensor& base_model_seq_lens_decoder,
|
const paddle::Tensor& base_model_seq_lens_decoder,
|
||||||
const paddle::Tensor& base_model_step_idx,
|
const paddle::Tensor& base_model_step_idx,
|
||||||
@@ -397,6 +425,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
|||||||
const_cast<bool*>(batch_drop.data<bool>()),
|
const_cast<bool*>(batch_drop.data<bool>()),
|
||||||
accept_tokens.data<int64_t>(),
|
accept_tokens.data<int64_t>(),
|
||||||
accept_num.data<int>(),
|
accept_num.data<int>(),
|
||||||
|
base_model_seq_lens_this_time.data<int>(),
|
||||||
base_model_seq_lens_encoder.data<int>(),
|
base_model_seq_lens_encoder.data<int>(),
|
||||||
base_model_seq_lens_decoder.data<int>(),
|
base_model_seq_lens_decoder.data<int>(),
|
||||||
base_model_step_idx.data<int64_t>(),
|
base_model_step_idx.data<int64_t>(),
|
||||||
@@ -431,6 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
|||||||
"batch_drop",
|
"batch_drop",
|
||||||
"accept_tokens",
|
"accept_tokens",
|
||||||
"accept_num",
|
"accept_num",
|
||||||
|
"base_model_seq_lens_this_time",
|
||||||
"base_model_seq_lens_encoder",
|
"base_model_seq_lens_encoder",
|
||||||
"base_model_seq_lens_decoder",
|
"base_model_seq_lens_decoder",
|
||||||
"base_model_step_idx",
|
"base_model_step_idx",
|
||||||
|
@@ -61,20 +61,25 @@ __global__ void ComputeOrderKernel(
|
|||||||
// 4. stopped
|
// 4. stopped
|
||||||
} else if (cur_base_model_seq_lens_this_time == 0 && cur_seq_lens_this_time == 0) /* end */ {
|
} else if (cur_base_model_seq_lens_this_time == 0 && cur_seq_lens_this_time == 0) /* end */ {
|
||||||
} else {
|
} else {
|
||||||
if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
|
for (int i = 0; i < accept_num; i++) {
|
||||||
#ifdef DEBUG_EAGLE_KERNEL
|
position_map[in_offset++] = out_offset++;
|
||||||
printf("batch %d: accept_num <= actual_draft_token_num \n", i);
|
|
||||||
#endif
|
|
||||||
position_map[in_offset + accept_num - 1] = out_offset++;
|
|
||||||
in_offset += cur_base_model_seq_lens_this_time;
|
|
||||||
} else /*Accept all draft tokens*/ {
|
|
||||||
#ifdef DEBUG_EAGLE_KERNEL
|
|
||||||
printf("batch %d: accept_num > actual_draft_token_num \n", i);
|
|
||||||
#endif
|
|
||||||
position_map[in_offset + accept_num - 2] = out_offset++;
|
|
||||||
position_map[in_offset + accept_num - 1] = out_offset++;
|
|
||||||
in_offset += cur_base_model_seq_lens_this_time;
|
|
||||||
}
|
}
|
||||||
|
in_offset += cur_base_model_seq_lens_this_time - accept_num;
|
||||||
|
// (liuzichang): Temperary Reserved for debug
|
||||||
|
// if (accept_num <= actual_draft_token_num) /*Accept partial draft tokens*/ {
|
||||||
|
// #ifdef DEBUG_EAGLE_KERNEL
|
||||||
|
// printf("batch %d: accept_num <= actual_draft_token_num \n", i);
|
||||||
|
// #endif
|
||||||
|
// position_map[in_offset + accept_num - 1] = out_offset++;
|
||||||
|
// in_offset += cur_base_model_seq_lens_this_time;
|
||||||
|
// } else /*Accept all draft tokens*/ {
|
||||||
|
// #ifdef DEBUG_EAGLE_KERNEL
|
||||||
|
// printf("batch %d: accept_num > actual_draft_token_num \n", i);
|
||||||
|
// #endif
|
||||||
|
// position_map[in_offset + accept_num - 2] = out_offset++;
|
||||||
|
// position_map[in_offset + accept_num - 1] = out_offset++;
|
||||||
|
// in_offset += cur_base_model_seq_lens_this_time;
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
output_token_num[0] = out_offset;
|
output_token_num[0] = out_offset;
|
||||||
|
@@ -317,7 +317,9 @@ class MTPProposer(Proposer):
|
|||||||
self.model_inputs["max_len_tensor_cpu"] = None # CPU
|
self.model_inputs["max_len_tensor_cpu"] = None # CPU
|
||||||
|
|
||||||
# Input tokens
|
# Input tokens
|
||||||
self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")
|
self.model_inputs["draft_tokens"] = paddle.full(
|
||||||
|
shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64"
|
||||||
|
)
|
||||||
|
|
||||||
self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"])
|
self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"])
|
||||||
|
|
||||||
@@ -461,6 +463,7 @@ class MTPProposer(Proposer):
|
|||||||
self.model_inputs["batch_drop"],
|
self.model_inputs["batch_drop"],
|
||||||
self.main_model_inputs["accept_tokens"],
|
self.main_model_inputs["accept_tokens"],
|
||||||
self.main_model_inputs["accept_num"],
|
self.main_model_inputs["accept_num"],
|
||||||
|
self.main_model_inputs["seq_lens_this_time"],
|
||||||
self.main_model_inputs["seq_lens_encoder"],
|
self.main_model_inputs["seq_lens_encoder"],
|
||||||
self.main_model_inputs["seq_lens_decoder"],
|
self.main_model_inputs["seq_lens_decoder"],
|
||||||
self.main_model_inputs["step_idx"],
|
self.main_model_inputs["step_idx"],
|
||||||
|
Reference in New Issue
Block a user