[MTP]support mtp chunk_prefill_v1 (#4365)

* support mtp chunk_prefill_v1

* fix mtp chunk_prefill output

* fix mtp chunk_prefill output, fix unit test

* fix save_output

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
freeliuzc
2025-10-15 15:33:59 +08:00
committed by GitHub
parent 55064b8c57
commit c3499875bd
10 changed files with 118 additions and 59 deletions

View File

@@ -1239,6 +1239,7 @@ class GPUModelRunner(ModelRunnerBase):
reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
prompt_lens=self.share_inputs["prompt_lens"],
)
post_process(
@@ -1579,6 +1580,7 @@ class GPUModelRunner(ModelRunnerBase):
reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
prompt_lens=self.share_inputs["prompt_lens"],
)
if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
@@ -1622,7 +1624,9 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["draft_tokens"],
self.share_inputs["block_tables"],
self.share_inputs["stop_flags"],
self.share_inputs["prompt_lens"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_encoder"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["step_seq_lens_decoder"],
self.share_inputs["step_draft_tokens"],