diff --git a/benchmarks/yaml/eb45t_0dot3b-32k-bf16-a30-tp1-static.yaml b/benchmarks/yaml/eb45t_0dot3b-32k-bf16-a30-tp1-static.yaml
index 55a37e029..d69702269 100644
--- a/benchmarks/yaml/eb45t_0dot3b-32k-bf16-a30-tp1-static.yaml
+++ b/benchmarks/yaml/eb45t_0dot3b-32k-bf16-a30-tp1-static.yaml
@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/eb45t_0dot3b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/eb45t_0dot3b-32k-bf16-h800-tp1-static.yaml
index 55a37e029..d69702269 100644
--- a/benchmarks/yaml/eb45t_0dot3b-32k-bf16-h800-tp1-static.yaml
+++ b/benchmarks/yaml/eb45t_0dot3b-32k-bf16-h800-tp1-static.yaml
@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/eb45t_0dot3b-32k-wint8-a30-tp1-static.yaml b/benchmarks/yaml/eb45t_0dot3b-32k-wint8-a30-tp1-static.yaml
index 14024b565..45fdffb7e 100644
--- a/benchmarks/yaml/eb45t_0dot3b-32k-wint8-a30-tp1-static.yaml
+++ b/benchmarks/yaml/eb45t_0dot3b-32k-wint8-a30-tp1-static.yaml
@@ -3,4 +3,5 @@ max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
 quantization: wint8
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/eb45t_0dot3b-32k-wint8-h800-tp1-static.yaml b/benchmarks/yaml/eb45t_0dot3b-32k-wint8-h800-tp1-static.yaml
index 14024b565..45fdffb7e 100644
--- a/benchmarks/yaml/eb45t_0dot3b-32k-wint8-h800-tp1-static.yaml
+++ b/benchmarks/yaml/eb45t_0dot3b-32k-wint8-h800-tp1-static.yaml
@@ -3,4 +3,5 @@ max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
 quantization: wint8
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/eb45t_21b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/eb45t_21b-32k-bf16-h800-tp1-static.yaml
index 55a37e029..d69702269 100644
--- a/benchmarks/yaml/eb45t_21b-32k-bf16-h800-tp1-static.yaml
+++ b/benchmarks/yaml/eb45t_21b-32k-bf16-h800-tp1-static.yaml
@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/eb45t_21b-32k-wint4-h800-tp1-static.yaml b/benchmarks/yaml/eb45t_21b-32k-wint4-h800-tp1-static.yaml
index 010dd3bc3..b18788981 100644
--- a/benchmarks/yaml/eb45t_21b-32k-wint4-h800-tp1-static.yaml
+++ b/benchmarks/yaml/eb45t_21b-32k-wint4-h800-tp1-static.yaml
@@ -3,4 +3,5 @@ max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
 quantization: wint4
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/eb45t_300b-32k-wint4-h800-tp4-static.yaml b/benchmarks/yaml/eb45t_300b-32k-wint4-h800-tp4-static.yaml
index eec95559d..cf1960d1f 100644
--- a/benchmarks/yaml/eb45t_300b-32k-wint4-h800-tp4-static.yaml
+++ b/benchmarks/yaml/eb45t_300b-32k-wint4-h800-tp4-static.yaml
@@ -3,4 +3,5 @@ max_num_seqs: 96
 gpu_memory_utilization: 0.9
 kv_cache_ratio: 0.71
 tensor_parallel_size: 4
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/qwen2_7b-32k-bf16-a30-tp1-static.yaml b/benchmarks/yaml/qwen2_7b-32k-bf16-a30-tp1-static.yaml
index 55a37e029..d69702269 100644
--- a/benchmarks/yaml/qwen2_7b-32k-bf16-a30-tp1-static.yaml
+++ b/benchmarks/yaml/qwen2_7b-32k-bf16-a30-tp1-static.yaml
@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1-static.yaml
index 55a37e029..d69702269 100644
--- a/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1-static.yaml
+++ b/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1-static.yaml
@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1-static.yaml b/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1-static.yaml
index 8cdc10498..64cd60e12 100644
--- a/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1-static.yaml
+++ b/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1-static.yaml
@@ -3,4 +3,5 @@ max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
 quantization: wfp8afp8
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/qwen3_0dot6b-32k-bf16-a30-tp1-static.yaml b/benchmarks/yaml/qwen3_0dot6b-32k-bf16-a30-tp1-static.yaml
index 55a37e029..d69702269 100644
--- a/benchmarks/yaml/qwen3_0dot6b-32k-bf16-a30-tp1-static.yaml
+++ b/benchmarks/yaml/qwen3_0dot6b-32k-bf16-a30-tp1-static.yaml
@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/qwen3_0dot6b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/qwen3_0dot6b-32k-bf16-h800-tp1-static.yaml
index 55a37e029..d69702269 100644
--- a/benchmarks/yaml/qwen3_0dot6b-32k-bf16-h800-tp1-static.yaml
+++ b/benchmarks/yaml/qwen3_0dot6b-32k-bf16-h800-tp1-static.yaml
@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/qwen3_0dot6b-32k-wint8-a30-tp1-static.yaml b/benchmarks/yaml/qwen3_0dot6b-32k-wint8-a30-tp1-static.yaml
index 14024b565..45fdffb7e 100644
--- a/benchmarks/yaml/qwen3_0dot6b-32k-wint8-a30-tp1-static.yaml
+++ b/benchmarks/yaml/qwen3_0dot6b-32k-wint8-a30-tp1-static.yaml
@@ -3,4 +3,5 @@ max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
 quantization: wint8
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/qwen3_0dot6b-32k-wint8-h800-tp1-static.yaml b/benchmarks/yaml/qwen3_0dot6b-32k-wint8-h800-tp1-static.yaml
index 14024b565..45fdffb7e 100644
--- a/benchmarks/yaml/qwen3_0dot6b-32k-wint8-h800-tp1-static.yaml
+++ b/benchmarks/yaml/qwen3_0dot6b-32k-wint8-h800-tp1-static.yaml
@@ -3,4 +3,5 @@ max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
 quantization: wint8
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/qwen3_30b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/qwen3_30b-32k-bf16-h800-tp1-static.yaml
index 55a37e029..d69702269 100644
--- a/benchmarks/yaml/qwen3_30b-32k-bf16-h800-tp1-static.yaml
+++ b/benchmarks/yaml/qwen3_30b-32k-bf16-h800-tp1-static.yaml
@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/benchmarks/yaml/qwen3_30b-32k-wint4-h800-tp1-static.yaml b/benchmarks/yaml/qwen3_30b-32k-wint4-h800-tp1-static.yaml
index 010dd3bc3..b18788981 100644
--- a/benchmarks/yaml/qwen3_30b-32k-wint4-h800-tp1-static.yaml
+++ b/benchmarks/yaml/qwen3_30b-32k-wint4-h800-tp1-static.yaml
@@ -3,4 +3,5 @@ max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
 quantization: wint4
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 006545df2..92fa483ba 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -260,27 +260,84 @@ class DeviceConfig:
             if hasattr(self, key):
                 setattr(self, key, value)
 
+@dataclass
 class GraphOptimizationConfig:
-    def init_with_cudagrpah_size(self,
-                                 cudagraph_capture_sizes: list[int]) -> None:
-        """To complete the initialization of config,
-        we need to know the cudagraph sizes"""
-        if self.cudagraph_capture_sizes is None:
-            self.cudagraph_capture_sizes = cudagraph_capture_sizes
-        else:
-            dedup_sizes = list(set(self.cudagraph_capture_sizes))
-            if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
-                logger.info(("cudagraph sizes specified by model runner"
-                             " %s is overridden by config %s"),
-                            cudagraph_capture_sizes, dedup_sizes)
-            self.cudagraph_capture_sizes = dedup_sizes
+    """
+    Configuration for compute graph level optimization.
+    """
 
-        # sort to make sure cudagraph capture sizes are in descending order
+    """The top-level graph optimization control corresponds to different backends.
+    - 0: dynamic graph
+    - 1: static graph
+    - 2: static graph + cinn compilation backend
+    """
+    graph_opt_level: int = 0
+
+    # CUDA Graph Config
+    """ Whether to use cudagraph.
+    - False: cudagraph is not used.
+    - True: cudagraph is used.
+        It requires that all input buffers have fixed addresses, and all
+        splitting ops write their outputs to input buffers.
+    - With dynamic graph backend: ...
+    - With static graph backend: WIP
+    """
+    use_cudagraph: bool = False
+    """Sizes to capture cudagraph.
+    - None (default): capture sizes are inferred from llm config.
+    - list[int]: capture sizes are specified as given."""
+    cudagraph_capture_sizes: Optional[list[int]] = None
+    """ Number of warmup runs for cudagraph. """
+    cudagraph_num_of_warmups: int = 2
+    """Whether to copy input tensors for cudagraph.
+    If the caller can guarantee that the same input buffers
+    are always used, it can set this to False. Otherwise, it should
+    set this to True."""
+    cudagraph_copy_inputs: bool = False
+    """ In static graph, this is an operation list that does not need to be captured by the CUDA graph.
+    CudaGraphBackend will split these operations from the static graph.
+    Example usage:
+        cudagraph_splitting_ops = ["paddle.unified_attention"]
+
+    Note: If you want to use subgraph capture functionality in a dynamic graph,
+    you can manually split the model into multiple layers and apply the @support_cuda_graph decorator
+    only to the layer where CUDA graph functionality is required.
+    """
+    cudagraph_splitting_ops: Optional[list[str]] = None
+    """ Whether to use a full cuda graph for the entire forward pass rather than
+    splitting certain operations such as attention into subgraphs.
+    Thus this flag cannot be used together with splitting_ops."""
+    full_cuda_graph: bool = True
+
+    max_capture_size: int = field(default=None, init=False) # type: ignore
+    batch_size_to_captured_size: dict[int,
+                                      int] = field(default=None,
+                                                   init=False) # type: ignore
+    # CINN Config ...
+
+    def init_with_cudagrpah_size(
+        self,
+        max_num_seqs: int = 0
+    ) -> None:
+        """
+        Initialize cuda graph capture sizes and
+        pre-compute the mapping from batch size to padded graph size
+        """
+        # Regular capture sizes
+        self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size < max_num_seqs]
+        dedup_sizes = list(set(self.cudagraph_capture_sizes))
+        if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
+            logger.info(("cudagraph sizes specified by model runner"
+                         " %s is overridden by config %s"),
+                        self.cudagraph_capture_sizes, dedup_sizes)
+        self.cudagraph_capture_sizes = dedup_sizes
+
+        # Sort to make sure cudagraph capture sizes are in descending order
         self.cudagraph_capture_sizes.sort(reverse=True)
         self.max_capture_size = self.cudagraph_capture_sizes[
             0] if self.cudagraph_capture_sizes else 0
 
-        # pre-compute the mapping from batch size to padded graph size
+        # Pre-compute the mapping from batch size to padded graph size
         self.batch_size_to_captured_size = {}
         for end, start in zip(self.cudagraph_capture_sizes,
                               self.cudagraph_capture_sizes[1:] + [0]):
@@ -292,68 +349,24 @@ class GraphOptimizationConfig:
             self.batch_size_to_captured_size[
                 self.max_capture_size] = self.max_capture_size
 
-    def __init__(self,
-                 enable_static_graph_inference: bool = False,
-                 max_capture_batch_size: int = 64,
-                 args = None):
-        """The Top-level graph optimization contral corresponds to different backends.
-        - 0: dyncmic graph
-        - 1: static graph
-        - 2: static graph + cinn compilation backend
+    def _set_cudagraph_sizes(
+        self,
+        max_num_seqs: int = 0
+    ):
         """
-        self.graph_opt_level: int = 0
-
-        # CUDA Graph Config
-        """ Whether to use cudagraph.
-        - False: cudagraph is not used.
-        - True: cudagraph is used.
-            It requires that all input buffers have fixed addresses, and all
-            splitting ops write their outputs to input buffers.
-        - With dyncmic graph backend: ...
-        - With static grpah backend: WIP
+        Calculate a series of candidate capture batch sizes,
+        and then extract a portion of them as the capture list for the CUDA graph based on user input.
         """
-        self.use_cudagraph: bool = False
-        """Sizes to capture cudagraph.
-        - None (default): capture sizes are inferred from llm config.
-        - list[int]: capture sizes are specified as given."""
-        self.cudagraph_capture_sizes: Optional[list[int]] = None
-        """ Number of warmup runs for cudagraph. """
-        self.cudagraph_num_of_warmups: int = 2
-        """Whether to copy input tensors for cudagraph.
-        If the caller can guarantee that the same input buffers
-        are always used, it can set this to False. Otherwise, it should
-        set this to True."""
-        self.cudagraph_copy_inputs: bool = False
-        """ In static graph, this is an operation list that does not need to be captured by the CUDA graph.
-        CudaGraphBackend will split these operations from the static graph.
-        Example usage:
-            cudagraph_splitting_ops = ["paddle.unified_attention"]
+        # Batch Size [1, 2, 4, 8, 16, ... 120, 128]
+        draft_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
+        # Batch Size [128, 144, ... 240, 256]
+        draft_capture_sizes += [16 * i for i in range(9, 17)]
+        # Batch Size [256, 288, ... 992, 1024]
+        draft_capture_sizes += [32 * i for i in range(17, 33)]
 
-        Note: If want to use subgraph capture functionality in a dynamic graph,
-        can manually split the model into multiple layers and apply the @support_cuda_graph decorator
-        only to the layer where CUDA graph functionality is required.
-        """
-        self.cudagraph_splitting_ops = Optional[list[str]]
-        """"whether to use a full cuda graph for the entire forward pass rather than
-        splitting certain operations such as attention into subgraphs.
-        Thus this flag cannot be used together with splitting_ops."""
-        self.full_cuda_graph: bool = False
+        draft_capture_sizes.append(max_num_seqs)
+        self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
 
-        self.max_capture_size: int = field(default=None, init=False) # type: ignore
-        self.batch_size_to_captured_size: dict[int,
-                                               int] = field(default=None,
-                                                            init=False) # type: ignore
-
-        # CINN Config ...
-
-        for key, value in args.items():
-            if hasattr(self, key):
-                setattr(self, key, value)
-        capture_size = [i for i in range(1, max_capture_batch_size + 1)]
-        self.init_with_cudagrpah_size(cudagraph_capture_sizes=capture_size)
-        #TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
-        if enable_static_graph_inference:
-            self.graph_opt_level = 1
 
 class LoadConfig:
     """
@@ -422,3 +435,13 @@ class FDConfig:
                                           init=True) # type: ignore
     kv_cache_config: KVCacheConfig = field(default=None,
                                            init=True) # type: ignore
+
+    def __post_init__(self):
+        # Initialize cuda graph capture list
+        if self.graph_opt_config.cudagraph_capture_sizes is None:
+            self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
+        self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs)
+
+        #TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
+        if self.graph_opt_config.graph_opt_level == 2:
+            self.graph_opt_config.graph_opt_level = 1
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 3a09f147c..59d0daf32 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -18,7 +18,8 @@ from dataclasses import asdict, dataclass
 from dataclasses import fields as dataclass_fields
 from typing import Any, Dict, List, Optional
 
-from fastdeploy.engine.config import (CacheConfig, Config, ModelConfig,
+from fastdeploy.engine.config import (CacheConfig, Config,
+                                      GraphOptimizationConfig, ModelConfig,
                                       ParallelConfig, SpeculativeConfig,
                                       TaskOption)
 from fastdeploy.scheduler.config import SchedulerConfig
@@ -283,20 +284,13 @@ class EngineArgs:
     """
     SplitWise Use, Results Writer Batch Size
     """
-    enable_static_graph_inference: bool = False
-    """
-    Whether to use static mode
-    """
     use_cudagraph: bool = False
     """
     Flags to enable Cuda Graph
     """
-    max_capture_batch_size: int = 64
+    graph_optimization_config: Optional[Dict[str, Any]] = None
     """
-    Maximum Batch Size for Cuda Graph Capture
-    NOTE: Now only support to capture continuous batch size,
-    Example:
-        max_capture_batch_size=64, FastDeploy will capture graphs for batches [1,64].
+    Configuration for graph optimization backend execution.
     """
 
     enable_logprob: bool = False
@@ -399,21 +393,14 @@
             "default is None. The priority of this configuration "\
            "is lower than that of the config file. " \
            "More complex quantization methods need to be configured via the config file.")
-
-        model_group.add_argument(
-            "--enable-static-graph-inference",
-            action='store_true',
-            default=EngineArgs.enable_static_graph_inference,
-            help="Whether to use static mode; if enabled, " \
-                "'paddle.to_static' will be used to convert dynamic to static.")
         model_group.add_argument("--use-cudagraph",
                                  action='store_true',
                                  default=EngineArgs.use_cudagraph,
                                  help="Flags to enable cuda graph.")
-        model_group.add_argument("--max-capture-batch-size",
-                                 type=int,
-                                 default=EngineArgs.max_capture_batch_size,
-                                 help="Maximum of Batch Size for Warm Up.")
+        model_group.add_argument("--graph-optimization-config",
+                                 type=json.loads,
+                                 default=EngineArgs.graph_optimization_config,
+                                 help="Configuration for graph optimization backend execution.")
         model_group.add_argument("--guided-decoding-backend",
                                  type=str,
                                  default=EngineArgs.guided_decoding_backend,
@@ -757,6 +744,15 @@ class EngineArgs:
             enable_custom_all_reduce=self.enable_custom_all_reduce
         )
 
+    def create_graph_optimization_config(self) -> GraphOptimizationConfig:
+        """
+        Create and return a GraphOptimizationConfig object based on the current settings.
+        """
+        if self.graph_optimization_config is not None:
+            return GraphOptimizationConfig(**self.graph_optimization_config)
+        else:
+            return GraphOptimizationConfig()
+
     def create_engine_config(self) -> Config:
         """
         Create and return a Config object based on the current settings.
@@ -771,8 +767,9 @@ class EngineArgs:
         else:
             self.max_num_batched_tokens = self.max_model_len
         scheduler_cfg = self.create_scheduler_config()
-        speculative_cfg = self.create_speculative_config()
+        graph_opt_cfg = self.create_graph_optimization_config()
+        graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
 
         assert not (self.use_cudagraph and self.enable_prefix_caching), \
             "Prefix caching cannot be used with CUDA graph"
@@ -804,9 +801,7 @@ class EngineArgs:
             max_num_partial_prefills=self.max_num_partial_prefills,
             max_long_partial_prefills=self.max_long_partial_prefills,
             long_prefill_token_threshold=self.long_prefill_token_threshold,
-            enable_static_graph_inference=self.enable_static_graph_inference,
-            use_cudagraph=self.use_cudagraph,
-            max_capture_batch_size=self.max_capture_batch_size,
+            graph_optimization_config=graph_opt_cfg,
             guided_decoding_backend=self.guided_decoding_backend,
             disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
             enable_custom_all_reduce=self.enable_custom_all_reduce,
diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py
index ee7a8c367..01ede8587 100644
--- a/fastdeploy/engine/config.py
+++ b/fastdeploy/engine/config.py
@@ -430,6 +430,77 @@ class SpeculativeConfig:
         llm_logger.info(
             "=============================================================")
 
+    def __str__(self) -> str:
+        return self.to_json_string()
+
+
+class GraphOptimizationConfig:
+    def __init__(
+        self,
+        graph_opt_level: Optional[int] = 0,
+        use_cudagraph: Optional[bool] = None,
+        cudagraph_capture_sizes: Optional[List[int]] = None,
+        **kwargs
+    ):
+        """
+        Graph Optimization Configuration class.
+
+        Attributes:
+            graph_opt_level: Compute graph optimization level
+            use_cudagraph: Use CUDA Graph or not
+            cudagraph_capture_sizes: Batch size list will be captured by CUDA Graph
+        """
+        self.check_legality_parameters(graph_opt_level, use_cudagraph, cudagraph_capture_sizes, **kwargs)
+
+        self.graph_opt_level = graph_opt_level
+        self.use_cudagraph = use_cudagraph
+        self.cudagraph_capture_sizes = cudagraph_capture_sizes
+
+    def to_json_string(self):
+        """
+        Convert graph_optimization_config to json string.
+        """
+        return json.dumps({
+            key: value
+            for key, value in self.__dict__.items()
+        })
+
+    def __str__(self) -> str:
+        return self.to_json_string()
+
+    def check_legality_parameters(
+        self,
+        graph_opt_level: Optional[int] = None,
+        use_cudagraph: Optional[bool] = None,
+        cudagraph_capture_sizes: Optional[List[int]] = None,
+        **kwargs
+    ) -> None:
+        """ Check the legality of parameters passed in from the command line """
+
+        if graph_opt_level is not None:
+            assert graph_opt_level in [0, 1, 2], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2."
+        if use_cudagraph is not None:
+            assert type(use_cudagraph) is bool, "In graph optimization config, use_cudagraph must be a bool."
+        if cudagraph_capture_sizes is not None:
+            assert type(cudagraph_capture_sizes) is list, "In graph optimization config, cudagraph_capture_sizes must be a list."
+            assert len(cudagraph_capture_sizes) > 0, "In graph optimization config, cudagraph_capture_sizes cannot be an empty list when CUDA graph is enabled."
+
+        for key, value in kwargs.items():
+            raise ValueError(f"Invalid --graph-optimization-config parameter {key}")
+
+    def update_use_cudagraph(self, argument: bool):
+        """
+        Unify the two ways the user can specify use_cudagraph:
+        '--use-cudagraph' and '--graph-optimization-config'
+        """
+        if self.use_cudagraph is None:
+            # User only set '--use-cudagraph'
+            self.use_cudagraph = argument
+        else:
+            # User set both '--use-cudagraph' and '--graph-optimization-config'
+            if self.use_cudagraph is False and argument is True:
+                raise ValueError("Invalid parameter: Cannot set --use-cudagraph and --graph-optimization-config '{\"use_cudagraph\":false}' simultaneously.")
+            argument = self.use_cudagraph
+
 
 class ParallelConfig:
     """
@@ -573,6 +644,7 @@ class Config:
         max_num_batched_tokens: Optional[int] = None,
         pod_ips: Optional[List[str]] = None,
         speculative_config: Optional[Dict[str, Any]] = None,
+        graph_optimization_config: Optional[Dict[str, Any]] = None,
         use_warmup: bool = False,
         engine_worker_queue_port: int = 8002,
         limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
@@ -584,9 +656,6 @@ class Config:
         max_long_partial_prefills: int = 1,
         long_prefill_token_threshold: int = 0,
         reasoning_parser: str = None,
-        enable_static_graph_inference: bool = False,
-        use_cudagraph: bool = False,
-        max_capture_batch_size: int = 64,
         guided_decoding_backend: Optional[str] = None,
         disable_any_whitespace: bool = False,
         enable_custom_all_reduce: bool = False,
@@ -609,6 +678,7 @@ class Config:
             pod_ips (Optional[List[str]]): List of POD IPs. Default is None.
             mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor. Default is None.
             speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration. Default is None.
+            graph_optimization_config (Optional[Dict[str, Any]]): Graph optimization backend execution configuration. Default is None.
             use_warmup (bool): Flag to use warmup. Default is False.
             engine_worker_queue_port (int): Engine worker queue port. Default is 8002.
             enable_mm (bool): Flag to enable multi-modal processing. Default is False.
@@ -643,9 +713,7 @@ class Config:
         self.max_long_partial_prefills = max_long_partial_prefills
         self.long_prefill_token_threshold = long_prefill_token_threshold
         self.reasoning_parser = reasoning_parser
-        self.enable_static_graph_inference = enable_static_graph_inference
-        self.use_cudagraph = use_cudagraph
-        self.max_capture_batch_size = max_capture_batch_size
+        self.graph_optimization_config = graph_optimization_config
         self.guided_decoding_backend = guided_decoding_backend
         self.disable_any_whitespace = disable_any_whitespace
         self.is_master = True
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index ae6203aa6..1f9dc9278 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -1026,7 +1026,7 @@ class LLMEngine(object):
             f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}"
             f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
             f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
-            f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
+            f" --graph_optimiaztion_config '{self.cfg.graph_optimization_config.to_json_string()}'"
             f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
             f" --load_strategy {self.cfg.model_config.load_strategy}"
             f" --enable_mm {self.cfg.enable_mm}")
@@ -1041,9 +1041,6 @@ class LLMEngine(object):
             self.cfg.cache_config.enable_chunked_prefill,
             "do_profile": self.do_profile,
             "dynamic_load_weight": self.cfg.model_config.dynamic_load_weight,
-            "enable_static_graph_inference":
-            self.cfg.enable_static_graph_inference,
-            "use_cudagraph": self.cfg.use_cudagraph,
             "disable_any_whitespace": self.cfg.disable_any_whitespace,
             "enable-custom-all-reduce": self.cfg.parallel_config.enable_custom_all_reduce,
             "enable_logprob": self.cfg.enable_logprob,
diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py
index ae3f092e4..17ab2e9ad 100644
--- a/fastdeploy/model_executor/forward_meta.py
+++ b/fastdeploy/model_executor/forward_meta.py
@@ -17,11 +17,11 @@ import logging
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, Optional
-from fastdeploy.model_executor.layers.attention import AttentionBackend
+from typing import Optional
 
 import paddle
-
+
+from fastdeploy.model_executor.layers.attention import AttentionBackend
 
 logger = logging.getLogger(__name__)
 
@@ -64,8 +64,6 @@ class ForwardMeta():
     # Use cuda graph in this step or not. Used to avoid run cuda graph when in dummy run or prefill stage.
     step_use_cudagraph: bool = False
 
-    # Batch type flag
-    is_decode_batch: bool = False
-
     # Attention backend object
     attn_backend: AttentionBackend = None
diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
index 53a0e52d6..730a05807 100644
--- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
+++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
@@ -68,16 +68,20 @@ class CudaGraphPiecewiseBackend:
             self.concrete_size_entries[shape] = ConcreteSizeEntry(
                 runtime_bs=shape)
 
-        logger.debug("[CUDA GRAPH] Created all batch size entry ")
+        logger.info(
+            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, "
+            "Created all batch sizes entry."
+        )
 
     def __call__(self, **kwargs):
         # Get batch size
         ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
         batch_size = ids_remove_padding.shape[0]
         padding_batch_size = self.batch_size_to_captured_size[batch_size]
-        logger.debug((
-            f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, ",
-            f"The padded batch size is :{padding_batch_size}"))
+        logger.debug(
+            f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, "
+            f"The padded batch size is :{padding_batch_size}"
+        )
 
         entry = self.concrete_size_entries.get(padding_batch_size)
         assert entry is not None, f"Batch size:{padding_batch_size} is not in cuda graph capture list."
@@ -96,10 +100,10 @@ class CudaGraphPiecewiseBackend:
             for n in range(entry.num_finished_warmup, self.warm_up_size):
                 entry.num_finished_warmup += 1
                 entry.runnable(**kwargs)
-                logger.debug((
-                    "[CUDA GRAPH] Warm up for batch size ",
-                    f"{padding_batch_size}, finished ({n+1}/{entry.num_finished_warmup}) times"
-                ))
+                logger.debug(
+                    f"[CUDA GRAPH] Warm up for batch size {padding_batch_size}, "
+                    f"finished ({n+1}/{entry.num_finished_warmup}) times"
+                )
 
         # Store input addresses for debug
         input_addresses = [
diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
index a8d2124ae..9f57f4179 100644
--- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
@@ -33,7 +33,8 @@ from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend, AttentionMetadata)
-from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
+from fastdeploy.model_executor.layers.attention.utils import \
+    init_rank_and_device_id
 
 
 @dataclass
@@ -106,7 +107,7 @@ class AppendAttentionBackend(AttentionBackend):
 
         if fd_config.parallel_config.expert_parallel_rank is None:
             fd_config.parallel_config.expert_parallel_rank = 0
-
+
         self.rank, self.device_id = init_rank_and_device_id(fd_config)
 
     def init_attention_metadata(self, forward_meta: ForwardMeta):
@@ -134,8 +135,8 @@ class AppendAttentionBackend(AttentionBackend):
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
-            metadata.decoder_batch_ids,
-            metadata.decoder_tile_ids_per_batch,
+            metadata.decoder_batch_ids, # will copy to buffer
+            metadata.decoder_tile_ids_per_batch, # will copy to buffer
             metadata.decoder_num_blocks,
             metadata.max_len_kv,
             metadata.set_max_lengths,
diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py
index 7176d21d4..a662d130a 100644
--- a/fastdeploy/rl/rollout_config.py
+++ b/fastdeploy/rl/rollout_config.py
@@ -53,9 +53,6 @@ class RolloutModelConfig:
                  enable_expert_parallell: bool = False,
                  ori_vocab_size: int = None,
                  quantization: str = "None",
-                 enable_static_graph_inference: bool = False,
-                 use_cudagraph: bool = False,
-                 max_capture_batch_size: int = 64,
                  guided_decoding_backend: str = "off",
                  disable_any_whitespace: bool = True,
                  enable_logprob: bool = False,
@@ -95,9 +92,6 @@ class RolloutModelConfig:
         self.enable_expert_parallell = enable_expert_parallell
         self.ori_vocab_size = ori_vocab_size
         self.quantization = quantization
-        self.enable_static_graph_inference = enable_static_graph_inference
-        self.use_cudagraph = use_cudagraph
-        self.max_capture_batch_size = max_capture_batch_size
         self.guided_decoding_backend = guided_decoding_backend
         self.disable_any_whitespace = disable_any_whitespace
         self.enable_logprob = enable_logprob
diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py
index 7a81f9600..0316779b4 100644
--- a/fastdeploy/utils.py
+++ b/fastdeploy/utils.py
@@ -335,66 +335,44 @@ def download_model(url, output_dir, temp_tar):
 
 class FlexibleArgumentParser(argparse.ArgumentParser):
     """
-    扩展 argparse.ArgumentParser,支持从 YAML 文件加载参数。
+    Extend argparse.ArgumentParser to support loading parameters from YAML files.
     """
 
     def __init__(self, *args, config_arg='--config', sep='_', **kwargs):
         super().__init__(*args, **kwargs)
-        self.sep = sep  # 用于展平嵌套字典的分隔符
-        # 创建临时解析器,仅用于解析 --config 参数
+        self.sep = sep
+
+        # Create a parser to parse the yaml file
         self.tmp_parser = argparse.ArgumentParser(add_help=False)
         self.tmp_parser.add_argument(config_arg,
                                      type=str,
                                      help='Path to YAML config file')
 
     def parse_args(self, args=None, namespace=None):
-        # 使用临时解析器解析出 --config 参数
        tmp_ns, remaining_args = self.tmp_parser.parse_known_args(args=args)
         config_path = tmp_ns.config
 
-        # 加载 YAML 文件并展平嵌套结构
         config = {}
         if config_path:
             with open(config_path, 'r') as f:
                 loaded_config = yaml.safe_load(f)
-            config = self._flatten_dict(loaded_config)
+            config = loaded_config
 
-        # 获取所有已定义参数的 dest 名称
+        # Get declared parameters
         defined_dests = {action.dest for action in self._actions}
-
-        # 过滤出已定义的参数
         filtered_config = {
             k: v
             for k, v in config.items() if k in defined_dests
         }
 
-        # 创建或使用现有的命名空间对象
+        # Set parameters
         if namespace is None:
             namespace = argparse.Namespace()
-
-        # 将配置参数设置到命名空间
         for key, value in filtered_config.items():
             setattr(namespace, key, value)
 
-        # 解析剩余参数并覆盖默认值
         return super().parse_args(args=remaining_args, namespace=namespace)
 
-    def _flatten_dict(self, d):
-        """将嵌套字典展平为单层字典,键由分隔符连接"""
-
-        def _flatten(d, parent_key=''):
-            items = []
-            for k, v in d.items():
-                new_key = f"{parent_key}{self.sep}{k}" if parent_key else k
-                if isinstance(v, dict):
-                    items.extend(_flatten(v, new_key).items())
-                else:
-                    items.append((new_key, v))
-            return dict(items)
-
-        return _flatten(d)
-
 
 def resolve_obj_from_strname(strname: str):
     module_name, obj_name = strname.rsplit(".", 1)
     module = importlib.import_module(module_name)
diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py
index 5756bdbe3..29c6f189c 100644
--- a/fastdeploy/worker/gcu_model_runner.py
+++ b/fastdeploy/worker/gcu_model_runner.py
@@ -748,10 +748,6 @@ class GCUModelRunner(ModelRunnerBase):
         # 3. Prepare lora
 
         # 4. Run model
-        is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
-                                > 1).sum() > 0)
-        self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
-        self.forward_meta.is_decode_batch = is_decode_batch
         model_output = self.model(
             ids_remove_padding=self.share_inputs["ids_remove_padding"],
             forward_meta=self.forward_meta)
@@ -979,10 +975,6 @@ class GCUModelRunner(ModelRunnerBase):
 
         # 2. Padding inputs for cuda grph
 
         # 3. Execute model
-        is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
-                                > 1).sum() > 0)
-        self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
-        self.forward_meta.is_decode_batch = is_decode_batch
         model_output = self.model(
             ids_remove_padding=self.share_inputs["ids_remove_padding"],
             forward_meta=self.forward_meta)
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 658863906..22336f28d 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -417,7 +417,10 @@ class GPUModelRunner(ModelRunnerBase):
             self.share_inputs["seq_lens_decoder"][idx:idx + 1] = 0
             self.share_inputs["step_idx"][idx:idx + 1] = 0
             self.share_inputs["max_dec_len"][idx:idx + 1] = max_dec_len
+            self.share_inputs["min_dec_len"][idx:idx + 1] = max_dec_len
             self.share_inputs["stop_flags"][idx:idx + 1] = False
+            self.share_inputs["top_p"][idx:idx + 1] = 0.0
+            self.share_inputs["temperature"][idx:idx + 1] = 1
             self.share_inputs["first_token_ids"][
                 idx:idx + 1] = self.share_inputs["input_ids"][idx:idx + 1, :1]
@@ -759,6 +762,11 @@ class GPUModelRunner(ModelRunnerBase):
             caches=self.share_inputs["caches"]
         )
 
+        # Update Batch type for cuda graph
+        # TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
+        is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0)
+        self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
+
         # Initialzie attention meta data
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)
@@ -850,6 +858,7 @@ class GPUModelRunner(ModelRunnerBase):
         Args:
             num_tokens:
             expected_decode_len: Expected number of tokens generated
+            in_capturing: Is cuda graph in capturing state
         """
         self._dummy_prefill_inputs(num_tokens=num_tokens,
                                    batch_size=batch_size,
@@ -864,17 +873,16 @@ class GPUModelRunner(ModelRunnerBase):
                 # 1. Initialize forward meta and attention meta data
                 self._prepare_inputs()
 
-                # 2. Prepare lora
+                # 2. Padding inputs for cuda graph
+                self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
+                self.padding_cudagraph_inputs()
 
                 # 3. Run model
-                is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
-                                        > 1).sum() > 0)
-                self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
-                self.forward_meta.is_decode_batch = is_decode_batch
                 if self.enable_mm:
-                    hidden_states = model_output = self.model(self.share_inputs["ids_remove_padding"],
+                    model_output = self.model(self.share_inputs["ids_remove_padding"],
                                               self.share_inputs["image_features"],
                                               self.forward_meta)
+                    hidden_states = model_output
                 else:
                     model_output = self.model(
                         ids_remove_padding=self.share_inputs["ids_remove_padding"],
@@ -1113,9 +1121,7 @@ class GPUModelRunner(ModelRunnerBase):
             We plan to replace it with 'ModelForwardBatch'.
             intermediate_tensors:
         """
-        # NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
-        # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
-        # when there is data on other runner, the current runner is required to execute part of the model.
+        # NOTE(wufeisheng): For Expert Parallelism
         if not self.not_need_stop():
             self._execute_empty_input()
             return None
@@ -1126,18 +1132,14 @@ class GPUModelRunner(ModelRunnerBase):
         self.sampler.pre_process(skip_idx_list)
 
         # 2. Padding inputs for cuda graph
+        self.padding_cudagraph_inputs()
 
         # 3. Execute model
-        # TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
-        is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
-                                > 1).sum() > 0)
-        self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
-        self.forward_meta.is_decode_batch = is_decode_batch
-
         if self.enable_mm:
-            hidden_states = model_output = self.model(self.share_inputs["ids_remove_padding"],
+            model_output = self.model(self.share_inputs["ids_remove_padding"],
                                       self.share_inputs["image_features"],
                                       self.forward_meta)
+            hidden_states = model_output
         else:
             model_output = self.model(
                 ids_remove_padding=self.share_inputs["ids_remove_padding"],
@@ -1399,6 +1401,18 @@ class GPUModelRunner(ModelRunnerBase):
             self.dynamic_weight_manager._log_memory(
                 "dynamic weight manager update all memory")
 
+    def padding_cudagraph_inputs(self) -> None:
+        """
+        Clean buffers used for the CUDA graph when replaying it with a padded batch.
+        In FastDeploy, almost all input tensors have a buffer, so it is enough to keep
+        those buffers clean when replaying the CUDA graph with the padded batch.
+        """
+        # TODO(gongshaotian): Use more efficient implementation
+        if self.forward_meta.step_use_cudagraph:
+            num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum()
+            for i in range(1, num_empty_batch + 1):
+                self.forward_meta.decoder_batch_ids[-i] = 0
+                self.forward_meta.decoder_tile_ids_per_batch[-i] = 0
+
     def _init_image_preprocess(self) -> None:
         processor = DataProcessor(
             tokenizer_name=self.tokenizer_path,
diff --git a/fastdeploy/worker/iluvatar_model_runner.py b/fastdeploy/worker/iluvatar_model_runner.py
index b0caa7d3b..cd31e65ad 100644
--- a/fastdeploy/worker/iluvatar_model_runner.py
+++ b/fastdeploy/worker/iluvatar_model_runner.py
@@ -715,10 +715,6 @@ class IluvatarModelRunner(ModelRunnerBase):
         # 3. Prepare lora
 
         # 4. Run model
-        is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
-                                > 1).sum() > 0)
-        self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
-        self.forward_meta.is_decode_batch = is_decode_batch
         model_output = self.model(
             ids_remove_padding=self.share_inputs["ids_remove_padding"],
             forward_meta=self.forward_meta)
@@ -939,11 +935,6 @@ class IluvatarModelRunner(ModelRunnerBase):
 
         # 2. Padding inputs for cuda grph
 
         # 3. Execute model
-        # TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
-        is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
-                                > 1).sum() > 0)
-        self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
-        self.forward_meta.is_decode_batch = is_decode_batch
         model_output = self.model(
             ids_remove_padding=self.share_inputs["ids_remove_padding"],
             forward_meta=self.forward_meta)
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 1a946ecea..d180b1133 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 import argparse
+import json
 import time
 from typing import List
 
@@ -516,18 +517,11 @@ def parse_args():
         "default is None. The priority of this configuration "\
        "is lower than that of the config file. " \
        "More complex quantization methods need to be configured via the config file.")
-    parser.add_argument("--enable_static_graph_inference",
-                        action='store_true',
-                        help="Whether to use static mode; if enabled, " \
-                            "'paddle.to_static' will be used to convert dynamic to static.")
-    parser.add_argument("--use_cudagraph",
-                        action='store_true',
-                        help="Flags to enable cuda graph.")
-    parser.add_argument("--max_capture_batch_size",
-                        type=int,
-                        default=64,
-                        help="Maximum Batch Size for Cuda Graph Capture. " \
-                            "If max_capture_batch_size set 64, FastDeploy will capture batch size in [1, 64]")
+    parser.add_argument("--graph_optimiaztion_config",
+                        type=json.loads,
+                        default=None,
+                        help="Configuration of graph optimization backend."
+                        )
    parser.add_argument("--guided_decoding_backend",
                        type=str,
                        default="off",
@@ -579,9 +573,10 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     load_config = LoadConfig(vars(args))
 
     graph_opt_config = GraphOptimizationConfig(
-        args.enable_static_graph_inference,
-        args.max_capture_batch_size,
-        vars(args))
+        use_cudagraph=args.graph_optimiaztion_config["use_cudagraph"],
+        graph_opt_level=args.graph_optimiaztion_config["graph_opt_level"],
+        cudagraph_capture_sizes=args.graph_optimiaztion_config["cudagraph_capture_sizes"]
+    )
 
     # Note(tangbinhan): used for load_checkpoint
     model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
diff --git a/test/ci_use/EB_Lite/test_EB_Lite_serving.py b/test/ci_use/EB_Lite/test_EB_Lite_serving.py
index d0b9e6dd6..8f659fb77 100644
--- a/test/ci_use/EB_Lite/test_EB_Lite_serving.py
+++ b/test/ci_use/EB_Lite/test_EB_Lite_serving.py
@@ -91,7 +91,7 @@ def setup_and_run_server():
         "--max-num-seqs", "128",
         "--quantization", "wint4",
         "--use-cudagraph",
-        "--max-capture-batch-size", "1"
+        "--graph-optimization-config", '{"cudagraph_capture_sizes": [1]}'
     ]
 
     # Start subprocess in new process group