Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[BugFix] fix mtp logprob bugs in chunk prefill (#5234)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* fix mtp logprob bugs in chunk prefill
* merge code
* fix Request CONFLICT
* Revert "fix Request CONFLICT". This reverts commit 7a438e4119.
* Revert "merge code". This reverts commit 3839559b83.
* fix
* remove print
* fix

Co-authored-by: sunlei1024 <sunlei5788@gmail.com>
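For context, the heart of the fix is a per-request chunked-prefill check: while a request's decoder length is still below its prompt length, the prompt has only been partially prefilled, so no accepted tokens or logprobs should be reported for it in that step. The snippet below is a minimal standalone sketch of that gating (illustrative only, not the FastDeploy code); the tensor names follow the diff.

# Minimal sketch (illustrative, not the FastDeploy implementation).
def effective_token_num(seq_lens_decoder, prompt_lens, token_num_per_batch):
    """Per-batch token counts, zeroed for requests still in chunked prefill."""
    out = []
    for seq_len, prompt_len, token_num in zip(seq_lens_decoder, prompt_lens, token_num_per_batch):
        # seq_len < prompt_len -> the prompt is only partially prefilled,
        # so nothing should be emitted for this request yet.
        out.append(0 if seq_len < prompt_len else token_num)
    return out

# Example: request 0 is mid-prefill (8 of 16 prompt tokens processed),
# request 1 has finished prefill and accepted 3 speculative tokens.
assert effective_token_num([8, 20], [16, 16], [1, 3]) == [0, 3]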
@@ -46,9 +46,12 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
                               const paddle::Tensor& token_num_per_batch,
                               const paddle::Tensor& cu_batch_token_offset,
                               const paddle::Tensor& not_need_stop,
+                              const paddle::Tensor& seq_lens_decoder,
+                              const paddle::Tensor& prompt_lens,
                               int message_flag,  // Target: 3, Draft: 4
-                              int64_t rank_id) {
-  if (rank_id > 0) {
+                              int64_t rank_id,
+                              bool save_each_rank) {
+  if (!save_each_rank && rank_id > 0) {
     return;
   }
   auto sampled_token_ids_cpu =
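The new save_each_rank attribute relaxes the previous hard-coded "only rank 0 writes" rule; the Python call sites further down pass use_ep for it. A minimal sketch of the gating, illustrative only:

# Sketch of the rank gating added above (illustrative only).
def should_save(rank_id: int, save_each_rank: bool) -> bool:
    # Previously only rank 0 saved; now every rank saves when
    # save_each_rank is set (the call sites pass use_ep for this flag).
    return save_each_rank or rank_id == 0

assert should_save(0, False) and not should_save(1, False)
assert should_save(1, True)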
@@ -61,12 +64,17 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
       token_num_per_batch.copy_to(paddle::CPUPlace(), false);
   auto cu_batch_token_offset_cpu =
       cu_batch_token_offset.copy_to(paddle::CPUPlace(), false);
+  auto seq_lens_decoder_cpu =
+      seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
+  auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
   int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data<int64_t>();
   int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
   float* logprob_scores_data = logprob_scores_cpu.data<float>();
   int64_t* logprob_ranks_data = logprob_ranks_cpu.data<int64_t>();
   int* token_num_per_batch_data = token_num_per_batch_cpu.data<int>();
   int* cu_batch_token_offset_data = cu_batch_token_offset_cpu.data<int>();
+  int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
+  int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();

   static struct msgdata msg_sed;
   int msg_queue_id = 1;
@@ -127,7 +135,12 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
   msg_sed.meta[2] = bsz;
   int max_num_logprobs = logprob_token_ids.shape()[1];
   for (int i = 0; i < bsz; i++) {
-    int cur_token_num = token_num_per_batch_data[i];
+    int cur_token_num;
+    if (seq_lens_decoder_data[i] < prompt_lens_data[i]) {
+      cur_token_num = 0;
+    } else {
+      cur_token_num = token_num_per_batch_data[i];
+    }
     msg_sed.meta[3 + i] = cur_token_num;
     auto* cur_batch_msg_sed = &msg_sed.mtext[i];
     int token_offset = cu_batch_token_offset_data[i];
@@ -198,6 +211,8 @@ PD_BUILD_STATIC_OP(speculate_save_output_topk)
                 "token_num_per_batch",
                 "cu_batch_token_offset",
                 "not_need_stop",
+                "seq_lens_decoder",
+                "prompt_lens",
             })
-    .Attrs({"message_flag: int", "rank_id: int64_t"})
+    .Attrs({"message_flag: int", "rank_id: int64_t", "save_each_rank: bool"})
     .SetKernelFn(PD_KERNEL(SpeculateSaveOutMmsgTopK));
@@ -349,8 +349,11 @@ def post_process_specualate(
        sampler_output.token_num_per_batch,
        sampler_output.cu_batch_token_offset,
        model_output.not_need_stop,
+       model_output.seq_lens_decoder,
+       model_output.prompt_lens,
        3,  # mtype
        model_output.mp_rank,
+       save_each_rank,
    )

    # Update pre_ids through accept tokens
@@ -584,6 +584,8 @@ class TokenProcessor:
            num_accepted_tokens = sum([x - 1 for x in real_accept_num])
            self.num_accepted_tokens += num_accepted_tokens
            num_emitted_tokens = sum(real_accept_num)
+           if num_emitted_tokens == 0:
+               return
            self.num_emitted_tokens += num_emitted_tokens

            main_process_metrics.spec_decode_num_accepted_tokens_total.inc(num_accepted_tokens)
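The early return above covers steps in which nothing was emitted, which can now happen when every request in the batch is still mid chunked prefill. A small illustrative sketch of the guard (simplified, not the TokenProcessor code):

# Sketch (illustrative only): skip further accounting when nothing was emitted.
def emitted_tokens_or_none(real_accept_num):
    num_emitted_tokens = sum(real_accept_num)
    if num_emitted_tokens == 0:
        # All requests were still in chunked prefill; treat this as a no-op step.
        return None
    return num_emitted_tokens

assert emitted_tokens_or_none([0, 0]) is None
assert emitted_tokens_or_none([2, 3]) == 5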
@@ -313,6 +313,7 @@ class MTPProposer(Proposer):

        self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"])
        self.model_inputs["seq_lens_decoder"] = paddle.clone(self.target_model_inputs["seq_lens_decoder"])
+       self.model_inputs["prompt_lens"] = paddle.clone(self.target_model_inputs["prompt_lens"])
        self.model_inputs["step_idx"] = paddle.clone(self.target_model_inputs["step_idx"])
        self.model_inputs["stop_flags"] = paddle.clone(self.target_model_inputs["stop_flags"])
        self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
@@ -705,7 +706,7 @@ class MTPProposer(Proposer):
            self.parallel_config.use_ep,
        )

-   def _propose(self, step_use_cudagraph: bool = False):
+   def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
        """
        Main process for MTP inference.
        Args:
@@ -830,7 +831,12 @@ class MTPProposer(Proposer):
                self.model_inputs,
            )

-           if substep == 0 and sampler_output.logprobs_tensors is not None:
+           if (
+               not is_dummy_run
+               and self.parallel_config.tensor_parallel_rank == 0
+               and substep == 0
+               and sampler_output.logprobs_tensors is not None
+           ):
                real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
                speculate_save_output_topk(
                    sampler_output.sampled_token_ids,
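The widened guard bundles four conditions: skip dummy (warmup/capture) runs, write only from tensor-parallel rank 0, write only on the first draft substep, and only when logprobs were actually requested. A hedged sketch of the same predicate factored into a helper; the helper name is hypothetical and not part of the codebase:

# Hypothetical helper; the condition mirrors the guard in the hunk above.
def _should_save_draft_logprobs(is_dummy_run, tensor_parallel_rank, substep, logprobs_tensors):
    return (
        not is_dummy_run                   # skip warmup / CUDA-graph capture runs
        and tensor_parallel_rank == 0      # only TP rank 0 writes in this path
        and substep == 0                   # only the first draft substep
        and logprobs_tensors is not None   # logprobs were actually requested
    )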
@@ -840,8 +846,11 @@ class MTPProposer(Proposer):
                    self.model_inputs["batch_token_num"][:real_bsz],
                    self.model_inputs["cu_batch_token_offset"][:real_bsz],
                    self.model_inputs["not_need_stop"],
+                   self.model_inputs["seq_lens_decoder"],
+                   self.model_inputs["prompt_lens"],
                    4,  # mtype
                    self.local_rank,
+                   self.parallel_config.use_ep,
                )

            if self.parallel_config.tensor_parallel_size > 1:
@@ -949,10 +958,12 @@ class MTPProposer(Proposer):
        self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
        self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()

-   def _run_impl(self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False):
+   def _run_impl(
+       self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False, is_dummy_run: bool = False
+   ):
        """Execute Draft Model"""
        self._prepare_inputs(full_hidden_states)
-       self._propose(step_use_cudagraph=step_use_cudagraph)
+       self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run)
        self._update_status()
        if self.hybrid_mode:
            self._extend_draft_token_with_ngram_match()
@@ -1281,7 +1281,9 @@ class GPUModelRunner(ModelRunnerBase):
        if self.speculative_decoding:
            if self.speculative_method == "mtp":
                self.proposer.run(
-                   full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
+                   full_hidden_states=model_output,
+                   step_use_cudagraph=self.forward_meta.step_use_cudagraph,
+                   is_dummy_run=True,
                )
            else:
                self.proposer.run(share_inputs=self.share_inputs)
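Taken together, the new is_dummy_run flag is threaded from the dummy-run path in GPUModelRunner through the proposer so that warmup/profile executions never push draft logprobs to the message queue. A condensed sketch of that chain, assuming run() simply forwards its keyword arguments to _run_impl() (illustrative only, not the actual classes):

# Condensed sketch of the flag threading (assumption: run() forwards to _run_impl()).
def run(full_hidden_states, step_use_cudagraph=False, is_dummy_run=False):
    _run_impl(full_hidden_states, step_use_cudagraph, is_dummy_run)

def _run_impl(full_hidden_states, step_use_cudagraph=False, is_dummy_run=False):
    _propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run)

def _propose(step_use_cudagraph=False, is_dummy_run=False):
    if not is_dummy_run:
        # only real decode steps reach the message-queue write
        save_draft_logprobs()

def save_draft_logprobs():
    pass  # stands in for speculate_save_output_topk(...)

run(None, is_dummy_run=True)   # dummy/warmup run: nothing is saved
run(None, is_dummy_run=False)  # normal run: draft logprobs are saved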