[Executor] CUDA Graph support padding batch (#2844)

* Support batch padding for CUDA graph

* Integrate the startup parameters for the graph optimization backend and support user-defined capture sizes.

* Do not insert max_num_seqs when the user specifies a capture list

* Support setting the graph optimization config from a YAML file

* update cuda graph ci

* fix ci bug

* fix ci bug
Author: RAM
Date: 2025-07-16 10:49:01 +08:00
Committed by: GitHub
Parent: 61b3997b85
Commit: 0fad10b35a
30 changed files with 291 additions and 225 deletions


@@ -260,27 +260,84 @@ class DeviceConfig:
         if hasattr(self, key):
             setattr(self, key, value)


 @dataclass
 class GraphOptimizationConfig:
-    def init_with_cudagrpah_size(self,
-                                 cudagraph_capture_sizes: list[int]) -> None:
-        """To complete the initialization of the config,
-        we need to know the cudagraph sizes."""
-        if self.cudagraph_capture_sizes is None:
-            self.cudagraph_capture_sizes = cudagraph_capture_sizes
-        else:
-            dedup_sizes = list(set(self.cudagraph_capture_sizes))
-            if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
-                logger.info(("cudagraph sizes specified by model runner"
-                             " %s are overridden by config %s"),
-                            cudagraph_capture_sizes, dedup_sizes)
-            self.cudagraph_capture_sizes = dedup_sizes
-        # sort to make sure cudagraph capture sizes are in descending order
+    """
+    Configuration for compute-graph-level optimization.
+    """
+
+    """The top-level graph optimization control corresponds to different backends.
+    - 0: dynamic graph
+    - 1: static graph
+    - 2: static graph + CINN compilation backend
+    """
+    graph_opt_level: int = 0
+
+    # CUDA Graph Config
+    """Whether to use cudagraph.
+    - False: cudagraph is not used.
+    - True: cudagraph is used.
+        It requires that all input buffers have fixed addresses, and that all
+        splitting ops write their outputs to input buffers.
+    - With the dynamic graph backend: ...
+    - With the static graph backend: WIP
+    """
+    use_cudagraph: bool = False
+    """Sizes to capture cudagraph.
+    - None (default): capture sizes are inferred from the llm config.
+    - list[int]: capture sizes are specified as given."""
+    cudagraph_capture_sizes: Optional[list[int]] = None
+    """Number of warmup runs for cudagraph."""
+    cudagraph_num_of_warmups: int = 2
+    """Whether to copy input tensors for cudagraph.
+    If the caller can guarantee that the same input buffers
+    are always used, it can set this to False. Otherwise, it should
+    set this to True."""
+    cudagraph_copy_inputs: bool = False
+    """In static graph, this is an operation list that does not need to be
+    captured by the CUDA graph; CudaGraphBackend will split these operations
+    out of the static graph.
+    Example usage:
+        cudagraph_splitting_ops = ["paddle.unified_attention"]
+    Note: to use subgraph capture in a dynamic graph, manually split the
+    model into multiple layers and apply the @support_cuda_graph decorator
+    only to the layers where CUDA graph functionality is required.
+    """
+    cudagraph_splitting_ops: Optional[list[str]] = None
+    """Whether to use a full cuda graph for the entire forward pass rather than
+    splitting certain operations such as attention into subgraphs.
+    This flag therefore cannot be used together with splitting_ops."""
+    full_cuda_graph: bool = True
+
+    max_capture_size: int = field(default=None, init=False)  # type: ignore
+    batch_size_to_captured_size: dict[int, int] = field(default=None,
+                                                        init=False)  # type: ignore
+    # CINN Config ...
+
+    def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
+        """
+        Initialize the cudagraph capture sizes and
+        pre-compute the mapping from batch size to padded graph size.
+        """
+        # Regular capture sizes
+        self.cudagraph_capture_sizes = [
+            size for size in self.cudagraph_capture_sizes if size < max_num_seqs
+        ]
+        dedup_sizes = list(set(self.cudagraph_capture_sizes))
+        if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
+            logger.info(("cudagraph sizes specified by model runner"
+                         " %s are overridden by config %s"),
+                        self.cudagraph_capture_sizes, dedup_sizes)
+        self.cudagraph_capture_sizes = dedup_sizes
+
+        # Sort to make sure cudagraph capture sizes are in descending order
+        self.cudagraph_capture_sizes.sort(reverse=True)
+        self.max_capture_size = self.cudagraph_capture_sizes[
+            0] if self.cudagraph_capture_sizes else 0
+
+        # Pre-compute the mapping from batch size to padded graph size
+        self.batch_size_to_captured_size = {}
+        for end, start in zip(self.cudagraph_capture_sizes,
+                              self.cudagraph_capture_sizes[1:] + [0]):
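The hunk ends before the body of this mapping loop, so the padding behavior itself is not visible here. Below is a minimal standalone sketch of the intended effect, under the assumption that the elided loop body maps every batch size in the interval (start, end] up to the captured size end:

    # Standalone sketch (not the PR's code): reconstruct the padding lookup.
    # Assumption: the elided loop body maps each batch size in (start, end]
    # to the captured size `end`.
    capture_sizes = sorted([1, 2, 4, 8, 16], reverse=True)  # descending order
    max_capture_size = capture_sizes[0]

    batch_size_to_captured_size = {}
    for end, start in zip(capture_sizes, capture_sizes[1:] + [0]):
        for batch_size in range(start + 1, end + 1):
            batch_size_to_captured_size[batch_size] = end
    batch_size_to_captured_size[max_capture_size] = max_capture_size

    assert batch_size_to_captured_size[3] == 4    # batch of 3 runs the size-4 graph
    assert batch_size_to_captured_size[9] == 16   # batch of 9 runs the size-16 graph
    assert batch_size_to_captured_size[16] == 16  # exact match needs no padding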
@@ -292,68 +349,24 @@ class GraphOptimizationConfig:
             self.batch_size_to_captured_size[
                 self.max_capture_size] = self.max_capture_size

-    def __init__(self,
-                 enable_static_graph_inference: bool = False,
-                 max_capture_batch_size: int = 64,
-                 args=None):
-        """The top-level graph optimization control corresponds to different backends.
-        - 0: dynamic graph
-        - 1: static graph
-        - 2: static graph + CINN compilation backend
-        """
-        self.graph_opt_level: int = 0
-        # CUDA Graph Config
-        """Whether to use cudagraph.
-        - False: cudagraph is not used.
-        - True: cudagraph is used.
-            It requires that all input buffers have fixed addresses, and that all
-            splitting ops write their outputs to input buffers.
-        - With the dynamic graph backend: ...
-        - With the static graph backend: WIP
-        """
-        self.use_cudagraph: bool = False
-        """Sizes to capture cudagraph.
-        - None (default): capture sizes are inferred from the llm config.
-        - list[int]: capture sizes are specified as given."""
-        self.cudagraph_capture_sizes: Optional[list[int]] = None
-        """Number of warmup runs for cudagraph."""
-        self.cudagraph_num_of_warmups: int = 2
-        """Whether to copy input tensors for cudagraph.
-        If the caller can guarantee that the same input buffers
-        are always used, it can set this to False. Otherwise, it should
-        set this to True."""
-        self.cudagraph_copy_inputs: bool = False
-        """In static graph, this is an operation list that does not need to be
-        captured by the CUDA graph; CudaGraphBackend will split these operations
-        out of the static graph.
-        Example usage:
-            cudagraph_splitting_ops = ["paddle.unified_attention"]
-        Note: to use subgraph capture in a dynamic graph, manually split the
-        model into multiple layers and apply the @support_cuda_graph decorator
-        only to the layers where CUDA graph functionality is required.
-        """
-        self.cudagraph_splitting_ops: Optional[list[str]] = None
-        """Whether to use a full cuda graph for the entire forward pass rather than
-        splitting certain operations such as attention into subgraphs.
-        This flag therefore cannot be used together with splitting_ops."""
-        self.full_cuda_graph: bool = False
-        self.max_capture_size: int = field(default=None, init=False)  # type: ignore
-        self.batch_size_to_captured_size: dict[int, int] = field(default=None,
-                                                                 init=False)  # type: ignore
-        # CINN Config ...
-        for key, value in args.items():
-            if hasattr(self, key):
-                setattr(self, key, value)
-        capture_size = [i for i in range(1, max_capture_batch_size + 1)]
-        self.init_with_cudagrpah_size(cudagraph_capture_sizes=capture_size)
-        # TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
-        if enable_static_graph_inference:
-            self.graph_opt_level = 1
+    def _set_cudagraph_sizes(self, max_num_seqs: int = 0):
+        """
+        Calculate a series of candidate capture batch sizes, then extract a
+        portion of them as the capture list for the CUDA graph based on user input.
+        """
+        # Batch sizes 1, 2, 4, then 8 through 128 in steps of 8
+        draft_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
+        # Batch sizes 144 through 256 in steps of 16
+        draft_capture_sizes += [16 * i for i in range(9, 17)]
+        # Batch sizes 544 through 1024 in steps of 32
+        draft_capture_sizes += [32 * i for i in range(17, 33)]
+        draft_capture_sizes.append(max_num_seqs)
+        self.cudagraph_capture_sizes = sorted(draft_capture_sizes)

 class LoadConfig:
     """
@@ -422,3 +435,13 @@ class FDConfig:
                                                        init=True)  # type: ignore
     kv_cache_config: KVCacheConfig = field(default=None,
                                            init=True)  # type: ignore
+
+    def __post_init__(self):
+        # Initialize the cuda graph capture list
+        if self.graph_opt_config.cudagraph_capture_sizes is None:
+            self.graph_opt_config._set_cudagraph_sizes(
+                max_num_seqs=self.parallel_config.max_num_seqs)
+        self.graph_opt_config.init_with_cudagrpah_size(
+            max_num_seqs=self.parallel_config.max_num_seqs)
+
+        # TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
+        if self.graph_opt_config.graph_opt_level == 2:
+            self.graph_opt_config.graph_opt_level = 1
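This __post_init__ hook is what implements the "do not insert max_num_seqs when the user specifies a capture list" item from the commit message: the default candidates are generated only when cudagraph_capture_sizes is still None. A minimal sketch of that decision with a hypothetical stand-in name (resolve_capture_sizes is not part of the PR):

    def resolve_capture_sizes(user_sizes, max_num_seqs):
        """User-specified list wins; otherwise derive defaults from max_num_seqs."""
        if user_sizes is None:
            # Default path: candidates derived from max_num_seqs
            sizes = [1, 2, 4] + [8 * i for i in range(1, 17)] + [max_num_seqs]
        else:
            # User path: max_num_seqs is NOT appended
            sizes = list(user_sizes)
        return sorted({s for s in sizes if s < max_num_seqs}, reverse=True)

    print(resolve_capture_sizes(None, 20))       # [16, 8, 4, 2, 1]
    print(resolve_capture_sizes([2, 4, 6], 20))  # [6, 4, 2]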