[Executor] CUDA Graph support padding batch (#2844)

* cuda graph support padding batch

* Integrate the startup parameters for the graph optimization backend and provide support for user - defined capture sizes.

* Do not insert max_num_seqs when the user specifies a capture list

* Support set graph optimization config from YAML file

* update cuda graph ci

* fix ci bug

* fix ci bug
This commit is contained in:
RAM
2025-07-16 10:49:01 +08:00
committed by GitHub
parent 61b3997b85
commit 0fad10b35a
30 changed files with 291 additions and 225 deletions

View File

@@ -430,6 +430,77 @@ class SpeculativeConfig:
llm_logger.info(
"=============================================================")
def __str__(self) -> str:
return self.to_json_string()
class GraphOptimizationConfig:
def __init__(
self,
graph_opt_level: Optional[int] = 0,
use_cudagraph: Optional[bool] = None,
cudagraph_capture_sizes: Optional[List[int]] = None,
**kwargs
):
"""
Graph Optimization Configuration class.
Attributes:
graph_opt_level: Compute graph optimization level
use_cudagraph: Use CUDA Graph or not
cudagraph_capture_sizes: Batch size list will be captured by CUDA Graph
"""
self.check_legality_parameters(graph_opt_level, use_cudagraph, cudagraph_capture_sizes, **kwargs)
self.graph_opt_level = graph_opt_level
self.use_cudagraph = use_cudagraph
self.cudagraph_capture_sizes = cudagraph_capture_sizes
def to_json_string(self):
"""
Convert speculative_config to json string.
"""
return json.dumps({
key: value
for key, value in self.__dict__.items()
})
def __str__(self) -> str:
return self.to_json_string()
def check_legality_parameters(
self,
graph_opt_level: Optional[int] = None,
use_cudagraph: Optional[bool] = None,
cudagraph_capture_sizes: Optional[List[int]] = None,
**kwargs
) -> None:
""" Check the legality of parameters passed in from the command line """
if graph_opt_level is not None:
assert graph_opt_level in [0, 1, 2], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2."
if use_cudagraph is not None:
assert type(use_cudagraph) is bool, "In graph optimization config, type of use_cudagraph must is bool."
if cudagraph_capture_sizes is not None:
assert type(cudagraph_capture_sizes) is list, "In graph optimization config, type of cudagraph_capture_sizes must is list."
assert len(cudagraph_capture_sizes) > 0, "In graph optimization config, When opening the CUDA graph, it is forbidden to set the capture sizes to an empty list."
for key, value in kwargs.items():
raise ValueError(f"Invalid --graph-optimization-config parameter {key}")
def update_use_cudagraph(self, argument:bool):
"""
Unified user specifies the use_cudagraph parameter through two methods,
'--use-cudagraph' and '--graph-optimization-config'
"""
if self.use_cudagraph is None:
# User only set '--use-cudagraph'
self.use_cudagraph = argument
else:
# User both set '--use-cudagraph' and '--graph-optimization-config'
if self.use_cudagraph is False and argument is True:
raise ValueError("Invalid parameter: Cannot set --use-cudagraph and --graph-optimization-config '{\"use_cudagraph\":false}' simultaneously.")
argument = self.use_cudagraph
class ParallelConfig:
"""
@@ -573,6 +644,7 @@ class Config:
max_num_batched_tokens: Optional[int] = None,
pod_ips: Optional[List[str]] = None,
speculative_config: Optional[Dict[str, Any]] = None,
graph_optimization_config: Optional[Dict[str, Any]] = None,
use_warmup: bool = False,
engine_worker_queue_port: int = 8002,
limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
@@ -584,9 +656,6 @@ class Config:
max_long_partial_prefills: int = 1,
long_prefill_token_threshold: int = 0,
reasoning_parser: str = None,
enable_static_graph_inference: bool = False,
use_cudagraph: bool = False,
max_capture_batch_size: int = 64,
guided_decoding_backend: Optional[str] = None,
disable_any_whitespace: bool = False,
enable_custom_all_reduce: bool = False,
@@ -609,6 +678,7 @@ class Config:
pod_ips (Optional[List[str]]): List of POD IPs. Default is None.
mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor. Default is None.
speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration. Default is None.
graph_optimization_config (Optional[Dict[str, Any]]): Graph optimizaion backend execution configuration. Default is None.
use_warmup (bool): Flag to use warmup. Default is False.
engine_worker_queue_port (int): Engine worker queue port. Default is 8002.
enable_mm (bool): Flag to enable multi-modal processing. Default is False.
@@ -643,9 +713,7 @@ class Config:
self.max_long_partial_prefills = max_long_partial_prefills
self.long_prefill_token_threshold = long_prefill_token_threshold
self.reasoning_parser = reasoning_parser
self.enable_static_graph_inference = enable_static_graph_inference
self.use_cudagraph = use_cudagraph
self.max_capture_batch_size = max_capture_batch_size
self.graph_optimization_config = graph_optimization_config
self.guided_decoding_backend = guided_decoding_backend
self.disable_any_whitespace = disable_any_whitespace
self.is_master = True