[BugFix] fix mtp logprob bugs in chunk prefill (#5234)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled

* fix mtp logprob bugs in chunk prefill

* merge code

* fix Request CONFLICT

* Revert "fix Request CONFLICT"

This reverts commit 7a438e4119.

* Revert "merge code"

This reverts commit 3839559b83.

* fix

* remove print

* fix

---------

Co-authored-by: sunlei1024 <sunlei5788@gmail.com>
This commit is contained in:
GoldPancake
2025-11-27 11:32:01 +08:00
committed by GitHub
parent cc588b70ab
commit bbcd92c8a0
5 changed files with 42 additions and 9 deletions

View File

@@ -313,6 +313,7 @@ class MTPProposer(Proposer):
self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"])
self.model_inputs["seq_lens_decoder"] = paddle.clone(self.target_model_inputs["seq_lens_decoder"])
self.model_inputs["prompt_lens"] = paddle.clone(self.target_model_inputs["prompt_lens"])
self.model_inputs["step_idx"] = paddle.clone(self.target_model_inputs["step_idx"])
self.model_inputs["stop_flags"] = paddle.clone(self.target_model_inputs["stop_flags"])
self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
@@ -705,7 +706,7 @@ class MTPProposer(Proposer):
self.parallel_config.use_ep,
)
def _propose(self, step_use_cudagraph: bool = False):
def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
"""
Main process for MTP inference.
Args:
@@ -830,7 +831,12 @@ class MTPProposer(Proposer):
self.model_inputs,
)
if substep == 0 and sampler_output.logprobs_tensors is not None:
if (
not is_dummy_run
and self.parallel_config.tensor_parallel_rank == 0
and substep == 0
and sampler_output.logprobs_tensors is not None
):
real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
speculate_save_output_topk(
sampler_output.sampled_token_ids,
@@ -840,8 +846,11 @@ class MTPProposer(Proposer):
self.model_inputs["batch_token_num"][:real_bsz],
self.model_inputs["cu_batch_token_offset"][:real_bsz],
self.model_inputs["not_need_stop"],
self.model_inputs["seq_lens_decoder"],
self.model_inputs["prompt_lens"],
4, # mtype
self.local_rank,
self.parallel_config.use_ep,
)
if self.parallel_config.tensor_parallel_size > 1:
@@ -949,10 +958,12 @@ class MTPProposer(Proposer):
self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
def _run_impl(self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False):
def _run_impl(
self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False, is_dummy_run: bool = False
):
"""Execute Draft Model"""
self._prepare_inputs(full_hidden_states)
self._propose(step_use_cudagraph=step_use_cudagraph)
self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run)
self._update_status()
if self.hybrid_mode:
self._extend_draft_token_with_ngram_match()