[MTP]support mtp chunk_prefill_v1 (#4366)

* support mtp chunk_prefill_v1

* fix mtp chunkprefill output, fix unit test

* fix unit test

* fix save_output
This commit is contained in:
freeliuzc
2025-10-15 13:21:32 +08:00
committed by GitHub
parent ffe7af8a97
commit 582aebd48b
11 changed files with 118 additions and 58 deletions

View File

@@ -709,8 +709,11 @@ void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& prompt_lens,
int64_t rank_id,
bool save_each_rank);
bool save_each_rank,
bool skip_prefill);
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
@@ -719,7 +722,9 @@ void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_flags,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &step_draft_tokens,

View File

@@ -28,9 +28,12 @@
void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& prompt_lens,
int64_t rank_id,
int msg_queue_id,
int save_each_rank) {
int save_each_rank,
bool skip_prefill) {
// printf("enter save output");
if (!save_each_rank && rank_id > 0) {
return;
@@ -43,6 +46,11 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
int64_t* accept_tokens_data = accept_tokens_cpu.data<int64_t>();
int* accept_num_data = accept_num_cpu.data<int>();
auto seq_lens_decoder_cpu = seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
@@ -95,7 +103,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
msg_sed.mtext[1] = bsz;
for (int i = 2; i < MAX_BSZ + 2; i++) {
if (i - 2 >= bsz) {
if (i - 2 >= bsz || (skip_prefill && seq_lens_decoder_data[i - 2] < prompt_lens_data[i - 2])) {
msg_sed.mtext[i] = 0;
} else {
msg_sed.mtext[i] = (int)accept_num_data[i - 2];
@@ -125,32 +133,38 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& prompt_lens,
int64_t rank_id,
bool save_each_rank) {
bool save_each_rank,
bool skip_prefill) {
SpeculateSaveWithOutputMsg(
accept_tokens, accept_num, not_need_stop, rank_id, 1, save_each_rank);
accept_tokens, accept_num, not_need_stop, seq_lens_decoder, prompt_lens, rank_id, 1, save_each_rank, skip_prefill);
}
void SpeculateSaveWithOutputMsgDynamic(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& prompt_lens,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
bool save_each_rank,
bool skip_prefill) {
SpeculateSaveWithOutputMsg(
accept_tokens, accept_num, not_need_stop, rank_id, msg_queue_id, save_each_rank);
accept_tokens, accept_num, not_need_stop, seq_lens_decoder, prompt_lens, rank_id, msg_queue_id, save_each_rank, skip_prefill);
}
PD_BUILD_STATIC_OP(speculate_save_output)
.Inputs({"accept_tokens", "accept_num", "not_need_stop"})
.Attrs({"rank_id: int64_t", "save_each_rank: bool"})
.Inputs({"accept_tokens", "accept_num", "not_need_stop", "seq_lens_decoder", "prompt_lens"})
.Attrs({"rank_id: int64_t", "save_each_rank: bool", "skip_prefill: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"accept_tokens", "x_out"}})
.SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgStatic));
PD_BUILD_STATIC_OP(speculate_save_output_dynamic)
.Inputs({"accept_tokens", "accept_num", "not_need_stop"})
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
.Inputs({"accept_tokens", "accept_num", "not_need_stop", "seq_lens_decoder", "prompt_lens"})
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool", "skip_prefill: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"accept_tokens", "x_out"}})
.SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgDynamic));

View File

@@ -19,7 +19,9 @@ __global__ void speculate_schedula_cache(
const int64_t *draft_tokens,
int *block_tables,
bool *stop_flags,
const int64_t* prompt_lens,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *step_draft_tokens,
@@ -44,23 +46,37 @@ __global__ void speculate_schedula_cache(
int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
int *block_table_now = block_tables + bid * block_num_per_seq;
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
is_block_step[bid] = true;
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
seq_lens_this_time[bid] = 0;
if (seq_lens_decoder[bid] >= prompt_lens[bid]) {
// decoder
const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
is_block_step[bid] = true;
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
seq_lens_this_time[bid] = 0;
stop_flags[bid] = true;
stop_flag_now_int = 1;
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
seq_lens_decoder[bid] = 0;
accept_num[bid] = 0;
for (int i = 0; i < accept_tokens_len; i++) {
accept_tokens_now[i] = -1;
}
for (int i = 0; i < draft_tokens_len; i++) {
step_draft_tokens_now[i] = draft_tokens_now[i];
}
}
} else {
// prefill
stop_flags[bid] = true;
stop_flag_now_int = 1;
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
seq_lens_this_time[bid] = 0;
seq_lens_decoder[bid] = 0;
seq_lens_encoder[bid] = 0;
accept_num[bid] = 0;
for (int i = 0; i < accept_tokens_len; i++) {
accept_tokens_now[i] = -1;
}
for (int i = 0; i < draft_tokens_len; i++) {
step_draft_tokens_now[i] = draft_tokens_now[i];
}
stop_flag_now_int = 1;
}
} else {
stop_flag_now_int = 1;
}
@@ -83,7 +99,9 @@ __global__ void speculate_schedula_cache(
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_flags,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &step_draft_tokens,
@@ -109,7 +127,9 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
draft_tokens.data<int64_t>(),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
prompt_lens.data<int64_t>(),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
@@ -138,7 +158,9 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache)
.Inputs({"draft_tokens",
"block_tables",
"stop_flags",
"prompt_lens",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_seq_lens_decoder",
"step_draft_tokens",
@@ -153,6 +175,7 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache)
"block_tables_out",
"stop_flags_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"step_seq_lens_decoder_out",
"step_draft_tokens_out",
@@ -165,6 +188,7 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache)
{"block_tables", "block_tables_out"},
{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
{"step_draft_tokens", "step_draft_tokens_out"},

View File

@@ -20,30 +20,33 @@
__global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all,
const int64_t *accept_tokens,
const int *accept_num,
int *accept_num,
const bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
int *seq_lens_decoder,
const int64_t *step_idx,
int bs,
int length,
int max_draft_tokens) {
int tid = threadIdx.x;
if (tid < bs && !stop_flags[tid]) {
int64_t *pre_ids_all_now = pre_ids_all + tid * length;
const int64_t *accept_tokens_now =
accept_tokens + tid * max_draft_tokens;
const int seq_len_dec = seq_lens_decoder[tid];
const int seq_len_enc = seq_lens_encoder[tid];
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stopped
// printf("step_idx[tid] %d\n", step_idx[tid]);
if (step_idx[tid] >= 0) {
for (int i = 0; i < accept_num[tid]; i++) {
pre_ids_all_now[step_idx[tid] - i] =
accept_tokens_now[accept_num[tid] - 1 - i];
// printf("pre_ids_all_now[step_idx[tid] - i] %d \n",
// pre_ids_all_now[step_idx[tid] - i]);
if (tid < bs) {
if (!stop_flags[tid]) {
int64_t *pre_ids_all_now = pre_ids_all + tid * length;
const int64_t *accept_tokens_now =
accept_tokens + tid * max_draft_tokens;
const int seq_len_dec = seq_lens_decoder[tid];
const int seq_len_enc = seq_lens_encoder[tid];
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped
if (step_idx[tid] >= 0) {
for (int i = 0; i < accept_num[tid]; i++) {
pre_ids_all_now[step_idx[tid] - i] =
accept_tokens_now[accept_num[tid] - 1 - i];
}
}
} else {
accept_num[tid] = 0;
seq_lens_decoder[tid] = 0;
}
}
}
@@ -67,10 +70,10 @@ void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
speculate_set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>(
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
const_cast<int*>(accept_num.data<int>()),
stop_flags.data<bool>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
const_cast<int*>(seq_lens_decoder.data<int>()),
step_idx.data<int64_t>(),
bs,
length,
@@ -86,6 +89,9 @@ PD_BUILD_STATIC_OP(speculate_set_value_by_flags_and_idx)
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx"})
.Outputs({"pre_ids_all_out"})
.SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
.Outputs({"pre_ids_all_out", "accept_num_out", "seq_lens_decoder_out"})
.SetInplaceMap({
{"pre_ids_all", "pre_ids_all_out"},
{"accept_num", "accept_num_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"}})
.SetKernelFn(PD_KERNEL(SpeculateSetValueByFlagsAndIdx));

View File

@@ -71,9 +71,6 @@ __global__ void speculate_update(int *seq_lens_encoder,
}
draft_tokens[bid * max_draft_tokens] =
accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
if (stop_flag_now_int) {
seq_lens_decoder[bid] = 0;
}
} else if (bid >= real_bsz && bid < max_bsz) {
stop_flag_now_int = 1;
}