From 6eada4929d570ed7a97e1306b2bf90aae4052b6b Mon Sep 17 00:00:00 2001
From: freeliuzc
Date: Mon, 22 Dec 2025 11:34:04 +0800
Subject: [PATCH] [Speculative Decoding]Support multi-step mtp with cudagraph
 (#5624)

* support multi-step mtp with cudagraph

* fix usage

* fix unit test
---
 fastdeploy/config.py                  | 26 +++++++++++++-------
 fastdeploy/worker/gpu_model_runner.py | 34 ++++++++++++---------------
 2 files changed, 33 insertions(+), 27 deletions(-)

diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 01dbbc5d8..f802c53a8 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -917,17 +917,19 @@ class GraphOptimizationConfig:
             self.real_shape_to_captured_size[bs] = end
         self.real_shape_to_captured_size[self.max_capture_size] = self.max_capture_size
 
-    def _set_cudagraph_sizes(self, max_capture_size: int = 0):
+    def _set_cudagraph_sizes(self, max_capture_size: int = 0, dec_token_per_query_per_step: int = 1):
         """
         Calculate a series of candidate capture sizes,
         and then extract a portion of them as the capture list for the CUDA graph based on user input.
         """
-        # Shape [1, 2, 4, 8, 16, ... 120, 128]
-        draft_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
-        # Shape [128, 144, ... 240, 256]
-        draft_capture_sizes += [16 * i for i in range(9, 17)]
-        # Shape [256, 288, ... 992, 1024]
-        draft_capture_sizes += [32 * i for i in range(9, 33)]
+        # Shape [1, 2, 4, 8, 16, ... 120, 128] * dec_token_per_query_per_step
+        draft_capture_sizes = [i * dec_token_per_query_per_step for i in [1, 2, 4]] + [
+            8 * i * dec_token_per_query_per_step for i in range(1, 17)
+        ]
+        # Shape [128, 144, ... 240, 256] * dec_token_per_query_per_step
+        draft_capture_sizes += [16 * i * dec_token_per_query_per_step for i in range(9, 17)]
+        # Shape [256, 288, ... 992, 1024] * dec_token_per_query_per_step
+        draft_capture_sizes += [32 * i * dec_token_per_query_per_step for i in range(9, 33)]
 
         draft_capture_sizes.append(max_capture_size)
         self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
@@ -1598,7 +1600,14 @@ class FDConfig:
             max_capture_shape = min(512, max_capture_shape)
         if self.graph_opt_config.cudagraph_capture_sizes is None:
-            self.graph_opt_config._set_cudagraph_sizes(max_capture_size=max_capture_shape)
+            dec_token_per_query_per_step = (
+                self.speculative_config.num_speculative_tokens + 1
+                if self.speculative_config is not None and self.speculative_config.method is not None
+                else 1
+            )
+            self.graph_opt_config._set_cudagraph_sizes(
+                max_capture_size=max_capture_shape, dec_token_per_query_per_step=dec_token_per_query_per_step
+            )
         self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=max_capture_shape)
 
         self.tokenizer = tokenizer
@@ -1762,6 +1771,7 @@ class FDConfig:
                 logger.info(
                     "Static Graph does not support to be started together with RL Training, and automatically switch to dynamic graph!"
                 )
+
         if not current_platform.is_cuda() and not current_platform.is_maca():
             self.graph_opt_config.use_cudagraph = False
             logger.info("CUDAGraph currently only support on GPU!")
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index f9fed8bfc..9bf533605 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -2135,25 +2135,21 @@ class GPUModelRunner(ModelRunnerBase):
                 )
         elif self.speculative_decoding and self.speculative_method == "mtp":
             # Capture Target Model without bsz 1
-            for batch_size in sorted(capture_sizes, reverse=True):
-                if batch_size == 1:
-                    logger.info("Skip token_num = 1, when capture target model for mtp")
-                else:
-                    assert batch_size % 2 == 0
-                    self._dummy_run(
-                        num_tokens=(
-                            self.scheduler_config.max_num_seqs
-                            * (self.speculative_config.num_speculative_tokens + 1)
-                            if self.scheduler_config.splitwise_role == "decode"
-                            else self.scheduler_config.max_num_batched_tokens
-                        ),
-                        batch_size=int(batch_size / 2),
-                        in_capturing=True,
-                        expected_decode_len=1,
-                    )
-                    logger.info(
-                        f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}"
-                    )
+            for capture_size in sorted(capture_sizes, reverse=True):
+                self._dummy_run(
+                    num_tokens=(
+                        self.scheduler_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1)
+                        if self.scheduler_config.splitwise_role == "decode"
+                        else self.scheduler_config.max_num_batched_tokens
+                    ),
+                    batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
+                    in_capturing=True,
+                    expected_decode_len=self.speculative_config.num_speculative_tokens,
+                    accept_all_drafts=True,
+                )
+                logger.info(
+                    f"Warm up the Target model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
+                )
         if self.graph_opt_config.draft_model_use_cudagraph:
             # Capture Draft Model without bsz 1
             # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
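
Note on the sizing change: with MTP enabled, each query decodes (num_speculative_tokens + 1) tokens per step, so the candidate capture sizes are scaled by that factor and each captured token count maps back to a whole batch size during target-model warmup. Below is a minimal standalone sketch of that rule, assuming num_speculative_tokens = 1 and a 512 cap; it mirrors the patched logic for illustration only and is not part of the diff.

# Sketch only: mirrors the patched _set_cudagraph_sizes; the printed values assume
# num_speculative_tokens = 1, i.e. dec_token_per_query_per_step = 2.
def candidate_capture_sizes(max_capture_size, dec_token_per_query_per_step=1):
    scale = dec_token_per_query_per_step
    sizes = [i * scale for i in [1, 2, 4]]
    sizes += [8 * i * scale for i in range(1, 17)]   # [8, 16, ..., 128] * scale
    sizes += [16 * i * scale for i in range(9, 17)]  # [144, ..., 256] * scale
    sizes += [32 * i * scale for i in range(9, 33)]  # [288, ..., 1024] * scale
    sizes.append(max_capture_size)
    return sorted(sizes)

num_speculative_tokens = 1  # assumed value for illustration
scale = num_speculative_tokens + 1
sizes = candidate_capture_sizes(512, scale)
print(sizes[:6])  # [2, 4, 8, 16, 32, 48]
# Each captured token count maps back to a batch size, matching the runner change:
# batch_size = capture_size // (num_speculative_tokens + 1)
print([s // scale for s in sizes[:6]])  # [1, 2, 4, 8, 16, 24]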