[Speculative Decoding] Support multi-step MTP with CUDA graph (#5624)

* support multi-step mtp with cudagraph

* fix usage

* fix unit test
This commit is contained in:
freeliuzc
2025-12-22 11:34:04 +08:00
committed by GitHub
parent 4f830aa505
commit 6eada4929d
2 changed files with 33 additions and 27 deletions

View File

@@ -2135,25 +2135,21 @@ class GPUModelRunner(ModelRunnerBase):
)
elif self.speculative_decoding and self.speculative_method == "mtp":
# Capture Target Model without bsz 1
for batch_size in sorted(capture_sizes, reverse=True):
if batch_size == 1:
logger.info("Skip token_num = 1, when capture target model for mtp")
else:
assert batch_size % 2 == 0
self._dummy_run(
num_tokens=(
self.scheduler_config.max_num_seqs
* (self.speculative_config.num_speculative_tokens + 1)
if self.scheduler_config.splitwise_role == "decode"
else self.scheduler_config.max_num_batched_tokens
),
batch_size=int(batch_size / 2),
in_capturing=True,
expected_decode_len=1,
)
logger.info(
f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}"
)
for capture_size in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=(
self.scheduler_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1)
if self.scheduler_config.splitwise_role == "decode"
else self.scheduler_config.max_num_batched_tokens
),
batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
in_capturing=True,
expected_decode_len=self.speculative_config.num_speculative_tokens,
accept_all_drafts=True,
)
logger.info(
f"Warm up the Target model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
)
if self.graph_opt_config.draft_model_use_cudagraph:
# Capture Draft Model without bsz 1
# NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph