[BugFix] fix mtp logprob bugs in chunk prefill (#5234)

* fix mtp logprob bugs in chunk prefill

* merge code

* fix Request CONFLICT

* Revert "fix Request CONFLICT"

This reverts commit 7a438e4119.

* Revert "merge code"

This reverts commit 3839559b83.

* fix

* remove print

* fix

---------

Co-authored-by: sunlei1024 <sunlei5788@gmail.com>
Author: GoldPancake
Committed by: GitHub
Date: 2025-11-27 11:32:01 +08:00
Parent: cc588b70ab
Commit: bbcd92c8a0
5 changed files with 42 additions and 9 deletions


@@ -46,9 +46,12 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
                               const paddle::Tensor& token_num_per_batch,
                               const paddle::Tensor& cu_batch_token_offset,
                               const paddle::Tensor& not_need_stop,
+                              const paddle::Tensor& seq_lens_decoder,
+                              const paddle::Tensor& prompt_lens,
                               int message_flag,  // Target: 3, Draft: 4
-                              int64_t rank_id) {
-  if (rank_id > 0) {
+                              int64_t rank_id,
+                              bool save_each_rank) {
+  if (!save_each_rank && rank_id > 0) {
     return;
   }
   auto sampled_token_ids_cpu =
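
The new save_each_rank attribute relaxes the old rank-0-only gate: when the flag is set, every rank publishes its output instead of only rank 0. Below is a minimal standalone sketch of that gate; only should_skip_save() mirrors the diff, the demo harness around it is hypothetical.

// Standalone sketch of the new rank gate from the hunk above.
#include <cstdint>
#include <cstdio>

// With save_each_rank == false, only rank 0 writes to the message
// queue (the old behavior); with true, every rank writes.
bool should_skip_save(int64_t rank_id, bool save_each_rank) {
  return !save_each_rank && rank_id > 0;
}

int main() {
  for (bool save_each_rank : {false, true}) {
    for (int64_t rank_id : {0, 1, 2}) {
      std::printf("save_each_rank=%d rank_id=%lld -> %s\n",
                  static_cast<int>(save_each_rank),
                  static_cast<long long>(rank_id),
                  should_skip_save(rank_id, save_each_rank) ? "skip" : "save");
    }
  }
}
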
@@ -61,12 +64,17 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
       token_num_per_batch.copy_to(paddle::CPUPlace(), false);
   auto cu_batch_token_offset_cpu =
       cu_batch_token_offset.copy_to(paddle::CPUPlace(), false);
+  auto seq_lens_decoder_cpu =
+      seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
+  auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
   int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data<int64_t>();
   int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
   float* logprob_scores_data = logprob_scores_cpu.data<float>();
   int64_t* logprob_ranks_data = logprob_ranks_cpu.data<int64_t>();
   int* token_num_per_batch_data = token_num_per_batch_cpu.data<int>();
   int* cu_batch_token_offset_data = cu_batch_token_offset_cpu.data<int>();
+  int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
+  int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();

   static struct msgdata msg_sed;
   int msg_queue_id = 1;
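
The two new inputs are copied to host with blocking copies (the second argument to copy_to set to true), so the data is complete before the loop dereferences the raw typed pointers. A sketch of that host-side access pattern, with std::vector standing in for the CPU tensors; the pointer names and element types mirror the diff, all values are invented.

// Sketch of the host-side access pattern: std::vector stands in for
// the CPU copies made by copy_to(paddle::CPUPlace(), /*blocking=*/true),
// and the element types match the .data<int>() / .data<int64_t>() reads.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> seq_lens_decoder = {16, 128, 128};  // decode progress
  std::vector<int64_t> prompt_lens = {128, 128, 64};   // full prompt sizes

  const int* seq_lens_decoder_data = seq_lens_decoder.data();
  const int64_t* prompt_lens_data = prompt_lens.data();

  for (std::size_t i = 0; i < seq_lens_decoder.size(); ++i) {
    std::printf("batch %zu: decoded %d of %lld prompt tokens\n", i,
                seq_lens_decoder_data[i],
                static_cast<long long>(prompt_lens_data[i]));
  }
}
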
@@ -127,7 +135,12 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
   msg_sed.meta[2] = bsz;
   int max_num_logprobs = logprob_token_ids.shape()[1];
   for (int i = 0; i < bsz; i++) {
-    int cur_token_num = token_num_per_batch_data[i];
+    int cur_token_num;
+    if (seq_lens_decoder_data[i] < prompt_lens_data[i]) {
+      cur_token_num = 0;
+    } else {
+      cur_token_num = token_num_per_batch_data[i];
+    }
     msg_sed.meta[3 + i] = cur_token_num;
     auto* cur_batch_msg_sed = &msg_sed.mtext[i];
     int token_offset = cu_batch_token_offset_data[i];
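
This is the core of the chunked-prefill fix: while seq_lens_decoder[i] is still below prompt_lens[i], request i is mid-prefill and has produced no sampled tokens yet, so its per-batch token count is forced to zero instead of whatever token_num_per_batch reports. A self-contained sketch of that guard; tokens_to_save() reproduces the new branch, and the batch values are invented for illustration.

// Self-contained sketch of the chunked-prefill guard from this hunk.
#include <cstdint>
#include <cstdio>
#include <vector>

// A request whose decoder position is still inside its prompt is mid
// chunked prefill: it has no sampled tokens (or logprobs) to publish
// this step, so report zero instead of token_num_per_batch[i].
int tokens_to_save(int seq_len_decoder, int64_t prompt_len,
                   int token_num_this_step) {
  if (seq_len_decoder < prompt_len) {
    return 0;
  }
  return token_num_this_step;
}

int main() {
  std::vector<int> seq_lens_decoder = {16, 128, 128};
  std::vector<int64_t> prompt_lens = {128, 128, 64};
  std::vector<int> token_num_per_batch = {1, 3, 2};

  for (std::size_t i = 0; i < seq_lens_decoder.size(); ++i) {
    std::printf("batch %zu -> cur_token_num=%d\n", i,
                tokens_to_save(seq_lens_decoder[i], prompt_lens[i],
                               token_num_per_batch[i]));
  }
}
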
@@ -198,6 +211,8 @@ PD_BUILD_STATIC_OP(speculate_save_output_topk)
             "token_num_per_batch",
             "cu_batch_token_offset",
             "not_need_stop",
+            "seq_lens_decoder",
+            "prompt_lens",
     })
-    .Attrs({"message_flag: int", "rank_id: int64_t"})
+    .Attrs({"message_flag: int", "rank_id: int64_t", "save_each_rank: bool"})
     .SetKernelFn(PD_KERNEL(SpeculateSaveOutMmsgTopK));
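
For reference, a hypothetical reconstruction of the full registration block after this change. Only the trailing inputs and the Attrs line are confirmed by the hunk above; the leading input names are inferred from the kernel signature in the first hunk, and the snippet compiles only against Paddle's custom-op headers with the SpeculateSaveOutMmsgTopK definition from this file in scope.

// Hypothetical reconstruction; leading input names are inferred, not
// shown in the diff. Inputs must stay in the same order as the kernel
// parameters, and Attrs in the same order as the trailing arguments.
#include "paddle/extension.h"

PD_BUILD_STATIC_OP(speculate_save_output_topk)
    .Inputs({"sampled_token_ids",
             "logprob_token_ids",
             "logprob_scores",
             "logprob_ranks",
             "token_num_per_batch",
             "cu_batch_token_offset",
             "not_need_stop",
             "seq_lens_decoder",
             "prompt_lens"})
    .Attrs({"message_flag: int", "rank_id: int64_t", "save_each_rank: bool"})
    .SetKernelFn(PD_KERNEL(SpeculateSaveOutMmsgTopK));
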