[Graph Optimization] Add the CUDAGraph usage switch for Draft Model (#4601)

* add a switch controlling whether the draft model uses CUDAGraph

* set the default to false

* capture draft model in ci

* fix bug
This commit is contained in:
RAM
2025-10-30 11:44:50 +08:00
committed by GitHub
parent cfdd1600a5
commit cd3b7cc392
4 changed files with 29 additions and 22 deletions

View File

@@ -1884,35 +1884,37 @@ class GPUModelRunner(ModelRunnerBase):
logger.info(
f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}"
)
# Capture Draft Model without bsz 1
# NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
for batch_size in sorted(capture_sizes, reverse=True):
if batch_size == 1:
logger.info("Skip token_num = 1, when capture Draft model for mtp")
else:
assert batch_size % 2 == 0
if self.graph_opt_config.draft_model_use_cudagraph:
# Capture Draft Model without bsz 1
# NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
for batch_size in sorted(capture_sizes, reverse=True):
if batch_size == 1:
logger.info("Skip token_num = 1, when capture Draft model for mtp")
else:
assert batch_size % 2 == 0
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=int(batch_size / 2),
in_capturing=True,
expected_decode_len=3,
accept_all_drafts=True,
)
logger.info(
f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
)
# Capture Draft Model with bsz 1
if 1 in capture_sizes:
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=int(batch_size / 2),
batch_size=int(1),
in_capturing=True,
expected_decode_len=3,
accept_all_drafts=True,
accept_all_drafts=False,
reject_all_drafts=True,
)
logger.info(
f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
)
# Capture Draft Model with bsz 1
if 1 in capture_sizes:
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=int(1),
in_capturing=True,
expected_decode_len=3,
accept_all_drafts=False,
reject_all_drafts=True,
)
logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
else:
for batch_size in sorted(capture_sizes, reverse=True):
self._dummy_run(