mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
[Executor] Experimental feature: support prefill in CUDA graph (#3459)
* Support prefill in Cudagraph
* Refactor GetBlockShapeAndSplitKVBlock Kernel V2
* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.1
* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.2
* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.3
* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.4
* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.5
* Solve problem about encoder_num_blocks_x_cpu
* Add early-exit mechanism for attention kernel
* fix test case about append-attention
* Update testcode, Add annotations to related tensors
* move get_input_length_list
* solve test_code
* Add annotations about early-exit for attention kernel
* Add annotations about early-exit for attention kernel2
* solve comment
* solve mtp

---------

Co-authored-by: RAM <gstian5555@outlook.com>
This commit is contained in:
@@ -580,6 +580,10 @@ class GraphOptimizationConfig:
|
||||
""" Whether to use a full cuda graph for the entire forward pass rather than
|
||||
splitting certain operations such as attention into subgraphs.
|
||||
Thus this flag cannot be used together with splitting_ops."""
|
||||
self.cudagraph_only_prefill: bool = False
|
||||
"""When cudagraph_only_prefill is False, only capture decode-only.
|
||||
When cudagraph_only_prefill is True, only capture prefill-only.
|
||||
Now don't support capture both decode-only and prefill-only"""
|
||||
self.full_cuda_graph: bool = True
|
||||
|
||||
self.max_capture_size: int = None
|
||||
@@ -592,13 +596,13 @@ class GraphOptimizationConfig:
|
||||
|
||||
self.check_legality_parameters()
|
||||
|
||||
def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
|
||||
def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None:
|
||||
"""
|
||||
Initialize cuda graph capture sizes and
|
||||
pre-compute the mapping from batch size to padded graph size
|
||||
"""
|
||||
# Regular capture sizes
|
||||
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_num_seqs]
|
||||
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size]
|
||||
dedup_sizes = list(set(self.cudagraph_capture_sizes))
|
||||
if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
|
||||
logger.info(
|
||||
@@ -632,7 +636,7 @@ class GraphOptimizationConfig:
|
||||
# Shape [128, 144, ... 240, 256]
|
||||
draft_capture_sizes += [16 * i for i in range(9, 17)]
|
||||
# Shape [256, 288, ... 992, 1024]
|
||||
draft_capture_sizes += [32 * i for i in range(17, 33)]
|
||||
draft_capture_sizes += [32 * i for i in range(9, 33)]
|
||||
|
||||
draft_capture_sizes.append(max_num_seqs)
|
||||
self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
|
||||
@@ -1140,7 +1144,11 @@ class FDConfig:
|
||||
# Initialize cuda graph capture list
|
||||
if self.graph_opt_config.cudagraph_capture_sizes is None:
|
||||
self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
|
||||
self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs)
|
||||
|
||||
if self.graph_opt_config.cudagraph_only_prefill:
|
||||
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=512)
|
||||
else:
|
||||
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=self.parallel_config.max_num_seqs)
|
||||
|
||||
# TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
|
||||
if self.graph_opt_config.graph_opt_level == 2:
|
||||
|
Reference in New Issue
Block a user