mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[BugFix] fix mtp logprob bugs in chunk prefill (#5234)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* fix mtp logprob bugs in chunk prefill
* merge code
* fix Request CONFLICT
* Revert "fix Request CONFLICT" (this reverts commit 7a438e4119)
* Revert "merge code" (this reverts commit 3839559b83)
* fix
* remove print
* fix
---------
Co-authored-by: sunlei1024 <sunlei5788@gmail.com>
This commit is contained in:
@@ -46,9 +46,12 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
|
||||
const paddle::Tensor& token_num_per_batch,
|
||||
const paddle::Tensor& cu_batch_token_offset,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& prompt_lens,
|
||||
int message_flag, // Target: 3, Draft: 4
|
||||
int64_t rank_id) {
|
||||
if (rank_id > 0) {
|
||||
int64_t rank_id,
|
||||
bool save_each_rank) {
|
||||
if (!save_each_rank && rank_id > 0) {
|
||||
return;
|
||||
}
|
||||
auto sampled_token_ids_cpu =
|
||||
@@ -61,12 +64,17 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
|
||||
token_num_per_batch.copy_to(paddle::CPUPlace(), false);
|
||||
auto cu_batch_token_offset_cpu =
|
||||
cu_batch_token_offset.copy_to(paddle::CPUPlace(), false);
|
||||
auto seq_lens_decoder_cpu =
|
||||
seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
|
||||
auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
|
||||
int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data<int64_t>();
|
||||
int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
|
||||
float* logprob_scores_data = logprob_scores_cpu.data<float>();
|
||||
int64_t* logprob_ranks_data = logprob_ranks_cpu.data<int64_t>();
|
||||
int* token_num_per_batch_data = token_num_per_batch_cpu.data<int>();
|
||||
int* cu_batch_token_offset_data = cu_batch_token_offset_cpu.data<int>();
|
||||
int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
|
||||
int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();
|
||||
|
||||
static struct msgdata msg_sed;
|
||||
int msg_queue_id = 1;
|
||||
@@ -127,7 +135,12 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
|
||||
msg_sed.meta[2] = bsz;
|
||||
int max_num_logprobs = logprob_token_ids.shape()[1];
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
int cur_token_num = token_num_per_batch_data[i];
|
||||
int cur_token_num;
|
||||
if (seq_lens_decoder_data[i] < prompt_lens_data[i]) {
|
||||
cur_token_num = 0;
|
||||
} else {
|
||||
cur_token_num = token_num_per_batch_data[i];
|
||||
}
|
||||
msg_sed.meta[3 + i] = cur_token_num;
|
||||
auto* cur_batch_msg_sed = &msg_sed.mtext[i];
|
||||
int token_offset = cu_batch_token_offset_data[i];
|
||||
@@ -198,6 +211,8 @@ PD_BUILD_STATIC_OP(speculate_save_output_topk)
|
||||
"token_num_per_batch",
|
||||
"cu_batch_token_offset",
|
||||
"not_need_stop",
|
||||
"seq_lens_decoder",
|
||||
"prompt_lens",
|
||||
})
|
||||
.Attrs({"message_flag: int", "rank_id: int64_t"})
|
||||
.Attrs({"message_flag: int", "rank_id: int64_t", "save_each_rank: bool"})
|
||||
.SetKernelFn(PD_KERNEL(SpeculateSaveOutMmsgTopK));
|
||||
|
||||
Reference in New Issue
Block a user