Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
* support multi-step mtp with cudagraph
* fix usage
* fix unit test
@@ -888,17 +888,19 @@ class GraphOptimizationConfig:
             self.real_shape_to_captured_size[bs] = end
         self.real_shape_to_captured_size[self.max_capture_size] = self.max_capture_size

-    def _set_cudagraph_sizes(self, max_capture_size: int = 0):
+    def _set_cudagraph_sizes(self, max_capture_size: int = 0, dec_token_per_query_per_step: int = 1):
         """
         Calculate a series of candidate capture sizes,
         and then extract a portion of them as the capture list for the CUDA graph based on user input.
         """
-        # Shape [1, 2, 4, 8, 16, ... 120, 128]
-        draft_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
-        # Shape [128, 144, ... 240, 256]
-        draft_capture_sizes += [16 * i for i in range(9, 17)]
-        # Shape [256, 288, ... 992, 1024]
-        draft_capture_sizes += [32 * i for i in range(9, 33)]
+        # Shape [1, 2, 4, 8, 16, ... 120, 128] * dec_token_per_query_per_step
+        draft_capture_sizes = [i * dec_token_per_query_per_step for i in [1, 2, 4]] + [
+            8 * i * dec_token_per_query_per_step for i in range(1, 17)
+        ]
+        # Shape [128, 144, ... 240, 256] * dec_token_per_query_per_step
+        draft_capture_sizes += [16 * i * dec_token_per_query_per_step for i in range(9, 17)]
+        # Shape [256, 288, ... 992, 1024] * dec_token_per_query_per_step
+        draft_capture_sizes += [32 * i * dec_token_per_query_per_step for i in range(9, 33)]

         draft_capture_sizes.append(max_capture_size)
         self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
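For reference, a minimal standalone sketch of the size-list logic in this hunk (the function name and sample values below are illustrative, not FastDeploy API): every candidate size is multiplied by dec_token_per_query_per_step, so one captured graph covers all tokens a decode query emits in a single multi-step MTP step.

    def candidate_capture_sizes(max_capture_size: int, dec_token_per_query_per_step: int = 1) -> list[int]:
        """Replicate the scaled size-list construction from the hunk above."""
        scale = dec_token_per_query_per_step
        # [1, 2, 4, 8, 16, ..., 120, 128] * scale
        sizes = [i * scale for i in (1, 2, 4)] + [8 * i * scale for i in range(1, 17)]
        # [144, 160, ..., 240, 256] * scale
        sizes += [16 * i * scale for i in range(9, 17)]
        # [288, 320, ..., 992, 1024] * scale
        sizes += [32 * i * scale for i in range(9, 33)]
        sizes.append(max_capture_size)
        return sorted(sizes)

    # With one draft token (num_speculative_tokens = 1) each decode query emits
    # 2 tokens per step, so the unscaled candidates [1, 2, 4, 8, ...] become:
    print(candidate_capture_sizes(512, dec_token_per_query_per_step=2)[:6])
    # -> [2, 4, 8, 16, 32, 48]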
@@ -1533,7 +1535,14 @@ class FDConfig:
             max_capture_shape = min(512, max_capture_shape)

         if self.graph_opt_config.cudagraph_capture_sizes is None:
-            self.graph_opt_config._set_cudagraph_sizes(max_capture_size=max_capture_shape)
+            dec_token_per_query_per_step = (
+                self.speculative_config.num_speculative_tokens + 1
+                if self.speculative_config is not None and self.speculative_config.method is not None
+                else 1
+            )
+            self.graph_opt_config._set_cudagraph_sizes(
+                max_capture_size=max_capture_shape, dec_token_per_query_per_step=dec_token_per_query_per_step
+            )
         self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=max_capture_shape)

         self.tokenizer = tokenizer
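The derivation of dec_token_per_query_per_step can be shown in isolation; the dataclass below is a simplified stand-in for FastDeploy's SpeculativeConfig, assuming only the two fields this hunk reads, and the helper name is mine:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class SpeculativeConfig:  # simplified stand-in, not the real FastDeploy class
        method: Optional[str] = None
        num_speculative_tokens: int = 0

    def dec_tokens_per_query_per_step(cfg: Optional[SpeculativeConfig]) -> int:
        # With speculation enabled, each decode query yields its draft tokens
        # plus the target model's own token every step; otherwise exactly one.
        if cfg is not None and cfg.method is not None:
            return cfg.num_speculative_tokens + 1
        return 1

    assert dec_tokens_per_query_per_step(None) == 1
    assert dec_tokens_per_query_per_step(SpeculativeConfig("mtp", 1)) == 2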
@@ -1941,25 +1941,21 @@ class GPUModelRunner(ModelRunnerBase):
             )
         elif self.speculative_decoding and self.speculative_method == "mtp":
-            # Capture Target Model without bsz 1
-            for batch_size in sorted(capture_sizes, reverse=True):
-                if batch_size == 1:
-                    logger.info("Skip token_num = 1, when capture target model for mtp")
-                else:
-                    assert batch_size % 2 == 0
-                    self._dummy_run(
-                        num_tokens=(
-                            self.scheduler_config.max_num_seqs
-                            * (self.speculative_config.num_speculative_tokens + 1)
-                            if self.scheduler_config.splitwise_role == "decode"
-                            else self.scheduler_config.max_num_batched_tokens
-                        ),
-                        batch_size=int(batch_size / 2),
-                        in_capturing=True,
-                        expected_decode_len=1,
-                    )
-                    logger.info(
-                        f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}"
-                    )
+            for capture_size in sorted(capture_sizes, reverse=True):
+                self._dummy_run(
+                    num_tokens=(
+                        self.scheduler_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1)
+                        if self.scheduler_config.splitwise_role == "decode"
+                        else self.scheduler_config.max_num_batched_tokens
+                    ),
+                    batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
+                    in_capturing=True,
+                    expected_decode_len=self.speculative_config.num_speculative_tokens,
+                    accept_all_drafts=True,
+                )
+                logger.info(
+                    f"Warm up the Target model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
+                )
         if self.graph_opt_config.draft_model_use_cudagraph:
             # Capture Draft Model without bsz 1
             # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
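Since every capture_size is now pre-scaled by num_speculative_tokens + 1 (see the first hunk), the warm-up loop divides back to recover the per-query batch size, replacing the old hard-coded batch_size / 2 and its assert batch_size % 2 == 0, while accept_all_drafts=True drives the longest decode path during capture. A toy illustration of that arithmetic, with made-up sizes and an assumed num_speculative_tokens = 2:

    num_speculative_tokens = 2                 # assumed MTP setting
    tokens_per_query = num_speculative_tokens + 1

    # Pre-scaled capture sizes, as _set_cudagraph_sizes would now produce them:
    for capture_size in sorted([3, 6, 12, 24], reverse=True):
        batch_size = capture_size // tokens_per_query
        print(f"capture_size={capture_size} -> warm-up batch_size={batch_size}")
    # capture_size=24 -> warm-up batch_size=8
    # capture_size=12 -> warm-up batch_size=4
    # capture_size=6 -> warm-up batch_size=2
    # capture_size=3 -> warm-up batch_size=1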