[Speculative Decoding] Support multi-step mtp with cudagraph (#5624)

* support multi-step mtp with cudagraph

* fix usage

* fix unit test
This commit is contained in:
freeliuzc
2025-12-22 11:34:04 +08:00
committed by GitHub
parent 4f830aa505
commit 6eada4929d
2 changed files with 33 additions and 27 deletions

View File

@@ -917,17 +917,19 @@ class GraphOptimizationConfig:
self.real_shape_to_captured_size[bs] = end
self.real_shape_to_captured_size[self.max_capture_size] = self.max_capture_size
def _set_cudagraph_sizes(self, max_capture_size: int = 0):
def _set_cudagraph_sizes(self, max_capture_size: int = 0, dec_token_per_query_per_step: int = 1):
"""
Calculate a series of candidate capture sizes,
and then extract a portion of them as the capture list for the CUDA graph based on user input.
"""
# Shape [1, 2, 4, 8, 16, ... 120, 128]
draft_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
# Shape [128, 144, ... 240, 256]
draft_capture_sizes += [16 * i for i in range(9, 17)]
# Shape [256, 288, ... 992, 1024]
draft_capture_sizes += [32 * i for i in range(9, 33)]
# Shape [1, 2, 4, 8, 16, ... 120, 128] * dec_token_per_query_per_step
draft_capture_sizes = [i * dec_token_per_query_per_step for i in [1, 2, 4]] + [
8 * i * dec_token_per_query_per_step for i in range(1, 17)
]
# Shape [128, 144, ... 240, 256] * dec_token_per_query_per_step
draft_capture_sizes += [16 * i * dec_token_per_query_per_step for i in range(9, 17)]
# Shape [256, 288, ... 992, 1024] * dec_token_per_query_per_step
draft_capture_sizes += [32 * i * dec_token_per_query_per_step for i in range(9, 33)]
draft_capture_sizes.append(max_capture_size)
self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
@@ -1598,7 +1600,14 @@ class FDConfig:
max_capture_shape = min(512, max_capture_shape)
if self.graph_opt_config.cudagraph_capture_sizes is None:
self.graph_opt_config._set_cudagraph_sizes(max_capture_size=max_capture_shape)
dec_token_per_query_per_step = (
self.speculative_config.num_speculative_tokens + 1
if self.speculative_config is not None and self.speculative_config.method is not None
else 1
)
self.graph_opt_config._set_cudagraph_sizes(
max_capture_size=max_capture_shape, dec_token_per_query_per_step=dec_token_per_query_per_step
)
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=max_capture_shape)
self.tokenizer = tokenizer
@@ -1762,6 +1771,7 @@ class FDConfig:
logger.info(
"Static Graph does not support to be started together with RL Training, and automatically switch to dynamic graph!"
)
if not current_platform.is_cuda() and not current_platform.is_maca():
self.graph_opt_config.use_cudagraph = False
logger.info("CUDAGraph currently only support on GPU!")