diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index a7744b3ba..48a45f41e 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
+import json
 import os
 from dataclasses import dataclass, field
 from typing import Literal, Optional
@@ -24,10 +25,12 @@ from paddleformers.transformers.configuration_utils import PretrainedConfig
 
 from fastdeploy import envs
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
-from fastdeploy.utils import get_logger
+from fastdeploy.utils import check_unified_ckpt, get_logger
 
 logger = get_logger("config", "config.log")
 
+TaskOption = Literal["generate"]
+
 
 class MoEPhase:
     """
@@ -269,6 +272,7 @@ class SpeculativeConfig:
         # This ensures that the specified simulation acceptance rate is not affected.
         self.benchmark_mode: bool = False
 
+        self.num_extra_cache_layer = 0
         # TODO(YuanRisheng): The name of the server args is different from the name of the SpeculativeConfig.
         # We temperately add the name map here and will delete it in future.
         name_map = {
@@ -284,6 +288,69 @@ class SpeculativeConfig:
                 if key == "speculative_benchmark_mode":
                     value = True if value.lower() == "true" else False
                 setattr(self, name_map[key], value)
+        self.read_model_config()
+        self.reset()
+
+    def read_model_config(self):
+        """
+        Read configuration from file.
+        """
+        self.model_config = {}
+        if not self.enabled_speculative_decoding():
+            return
+
+        self.is_unified_ckpt = check_unified_ckpt(self.model_name_or_path)
+        if self.model_name_or_path is None:
+            return
+
+        self.config_path = os.path.join(self.model_name_or_path, "config.json")
+        if os.path.exists(self.config_path):
+            self.model_config = json.load(open(self.config_path, "r", encoding="utf-8"))
+
+    def reset(self):
+        """
+        Reset configuration.
+        """
+
+        def reset_value(cls, value_name, key=None, default=None):
+            if key is not None and key in cls.model_config:
+                setattr(cls, value_name, cls.model_config[key])
+            elif getattr(cls, value_name, None) is None:
+                setattr(cls, value_name, default)
+
+        if not self.enabled_speculative_decoding():
+            return
+
+        # NOTE(liuzichang): We will support multi-layer in future
+        if self.method in ["mtp"]:
+            self.num_extra_cache_layer = 1
+
+    def enabled_speculative_decoding(self):
+        """
+        Check if speculative decoding is enabled.
+        """
+        if self.method is None:
+            return False
+        return True
+
+    def to_json_string(self):
+        """
+        Convert speculative_config to json string.
+        """
+        return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
+
+    def print(self):
+        """
+        Print all config.
+        """
+        logger.info("Speculative Decoding Configuration Information :")
+        for k, v in self.__dict__.items():
+            logger.info("{:<20}:{:<6}{}".format(k, "", v))
+        logger.info("=============================================================")
+
+    def __str__(self) -> str:
+        return self.to_json_string()
 
 
 class DeviceConfig:
     """
@@ -301,60 +368,69 @@ class DeviceConfig:
             setattr(self, key, value)
 
 
-@dataclass
 class GraphOptimizationConfig:
     """
    Configuration for compute graph level optimization.
    """
 
-    """The Top-level graph optimization contral corresponds to different backends.
-    - 0: dyncmic graph
-    - 1: static graph
-    - 2: static graph + cinn compilation backend
-    """
-    graph_opt_level: int = 0
+    def __init__(
+        self,
+        args,
+    ):
+        """The Top-level graph optimization control corresponds to different backends.
+        - 0: dynamic graph
+        - 1: static graph
+        - 2: static graph + cinn compilation backend
+        """
+        self.graph_opt_level: int = 0
 
-    # CUDA Graph Config
-    """ Whether to use cudagraph.
-    - False: cudagraph is not used.
-    - True: cudagraph is used.
-        It requires that all input buffers have fixed addresses, and all
-        splitting ops write their outputs to input buffers.
-    - With dyncmic graph backend: ...
-    - With static grpah backend: WIP
-    """
-    sot_warmup_sizes: Optional[list[int]] = field(default_factory=list)
-    """ Number of warmup runs for SOT warmup. """
-    use_cudagraph: bool = False
-    """Sizes to capture cudagraph.
-    - None (default): capture sizes are inferred from llm config.
-    - list[int]: capture sizes are specified as given."""
-    cudagraph_capture_sizes: Optional[list[int]] = None
-    """ Number of warmup runs for cudagraph. """
-    cudagraph_num_of_warmups: int = 2
-    """Whether to copy input tensors for cudagraph.
-    If the caller can guarantee that the same input buffers
-    are always used, it can set this to False. Otherwise, it should
-    set this to True."""
-    cudagraph_copy_inputs: bool = False
-    """ In static graph, this is an operation list that does not need to be captured by the CUDA graph.
-    CudaGraphBackend will split these operations from the static graph.
-    Example usage:
-        cudagraph_splitting_ops = ["paddle.unified_attention"]
+        # CUDA Graph Config
+        """ Whether to use cudagraph.
+        - False: cudagraph is not used.
+        - True: cudagraph is used.
+            It requires that all input buffers have fixed addresses, and all
+            splitting ops write their outputs to input buffers.
+        - With dynamic graph backend: ...
+        - With static graph backend: WIP
+        """
+        self.sot_warmup_sizes: Optional[list[int]] = []
+        """ Number of warmup runs for SOT warmup. """
+        self.use_cudagraph: bool = False
+        """Sizes to capture cudagraph.
+        - None (default): capture sizes are inferred from llm config.
+        - list[int]: capture sizes are specified as given."""
+        self.cudagraph_capture_sizes: Optional[list[int]] = None
+        """ Number of warmup runs for cudagraph. """
+        self.cudagraph_num_of_warmups: int = 2
+        """Whether to copy input tensors for cudagraph.
+        If the caller can guarantee that the same input buffers
+        are always used, it can set this to False. Otherwise, it should
+        set this to True."""
+        self.cudagraph_copy_inputs: bool = False
+        """ In static graph, this is an operation list that does not need to be captured by the CUDA graph.
+        CudaGraphBackend will split these operations from the static graph.
+        Example usage:
+            cudagraph_splitting_ops = ["paddle.unified_attention"]
 
-    Note: If want to use subgraph capture functionality in a dynamic graph,
-    can manually split the model into multiple layers and apply the @support_graph_optimization decorator
-    only to the layer where CUDA graph functionality is required.
+        Note: If you want to use subgraph capture functionality in a dynamic graph,
+        you can manually split the model into multiple layers and apply the @support_graph_optimization decorator
+        only to the layer where CUDA graph functionality is required.
+        """
+        self.cudagraph_splitting_ops: list[str] = []
+        """ Whether to use a full cuda graph for the entire forward pass rather than
+        splitting certain operations such as attention into subgraphs.
+        Thus this flag cannot be used together with splitting_ops."""
+        self.full_cuda_graph: bool = True
 
-    max_capture_size: int = field(default=None, init=False)  # type: ignore
-    batch_size_to_captured_size: dict[int, int] = field(default=None, init=False)  # type: ignore
-    # CINN Config ...
+        self.max_capture_size: int = None
+        self.batch_size_to_captured_size: dict[int, int] = None
+        # CINN Config ...
+        if args is not None:
+            for key, value in args.items():
+                if hasattr(self, key):
+                    setattr(self, key, value)
+
+        self.check_legality_parameters()
 
     def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
         """
@@ -401,6 +477,54 @@ class GraphOptimizationConfig:
             draft_capture_sizes.append(max_num_seqs)
         self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
 
+    def to_json_string(self):
+        """
+        Convert graph_optimization_config to json string.
+        """
+        return json.dumps({key: value for key, value in self.__dict__.items()})
+
+    def __str__(self) -> str:
+        return self.to_json_string()
+
+    def check_legality_parameters(
+        self,
+    ) -> None:
+        """Check the legality of parameters passed in from the command line"""
+
+        if self.graph_opt_level is not None:
+            assert self.graph_opt_level in [
+                0,
+                1,
+                2,
+            ], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2."
+        if self.use_cudagraph is not None:
+            assert (
+                type(self.use_cudagraph) is bool
+            ), "In graph optimization config, type of use_cudagraph must be bool."
+        if self.cudagraph_capture_sizes is not None:
+            assert (
+                type(self.cudagraph_capture_sizes) is list
+            ), "In graph optimization config, type of cudagraph_capture_sizes must be list."
+            assert (
+                len(self.cudagraph_capture_sizes) > 0
+            ), "In graph optimization config, when enabling the CUDA graph it is forbidden to set the capture sizes to an empty list."
+
+    def update_use_cudagraph(self, argument: bool):
+        """
+        Unify the two ways a user can specify the use_cudagraph parameter,
+        '--use-cudagraph' and '--graph-optimization-config'
+        """
+        if self.use_cudagraph is None:
+            # User only set '--use-cudagraph'
+            self.use_cudagraph = argument
+        else:
+            # User both set '--use-cudagraph' and '--graph-optimization-config'
+            if self.use_cudagraph is False and argument is True:
+                raise ValueError(
+                    "Invalid parameter: Cannot set --use-cudagraph and --graph-optimization-config '{\"use_cudagraph\":false}' simultaneously."
+                )
+            argument = self.use_cudagraph
+
 
 class LoadConfig:
     """
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 7dd8fb1d7..ee74a8a6b 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -19,15 +19,13 @@ from dataclasses import asdict, dataclass
 from dataclasses import fields as dataclass_fields
 from typing import Any, Dict, List, Optional
 
-from fastdeploy.config import CacheConfig
-from fastdeploy.engine.config import (
-    Config,
+from fastdeploy.config import (
+    CacheConfig,
     GraphOptimizationConfig,
-    ModelConfig,
-    ParallelConfig,
     SpeculativeConfig,
     TaskOption,
 )
+from fastdeploy.engine.config import Config, ModelConfig, ParallelConfig
 from fastdeploy.scheduler.config import SchedulerConfig
 from fastdeploy.utils import FlexibleArgumentParser
 
@@ -772,10 +770,12 @@ class EngineArgs:
 
     def create_speculative_config(self) -> SpeculativeConfig:
         """ """
+        speculative_args = asdict(self)
         if self.speculative_config is not None:
-            return SpeculativeConfig(**self.speculative_config)
-        else:
-            return SpeculativeConfig()
+            for k, v in self.speculative_config.items():
+                speculative_args[k] = v
+
+        return SpeculativeConfig(speculative_args)
 
     def create_scheduler_config(self) -> SchedulerConfig:
         """
@@ -816,10 +816,11 @@
         """
         Create and retuan a GraphOptimizationConfig object based on the current settings.
         """
+        graph_optimization_args = asdict(self)
         if self.graph_optimization_config is not None:
-            return GraphOptimizationConfig(**self.graph_optimization_config)
-        else:
-            return GraphOptimizationConfig()
+            for k, v in self.graph_optimization_config.items():
+                graph_optimization_args[k] = v
+        return GraphOptimizationConfig(graph_optimization_args)
 
     def create_engine_config(self) -> Config:
         """
diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py
index 4ce417b39..a5f241a27 100644
--- a/fastdeploy/engine/config.py
+++ b/fastdeploy/engine/config.py
@@ -17,7 +17,7 @@ import json
 import os
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Dict, List, Optional
 
 from fastdeploy import envs
 from fastdeploy.config import CacheConfig
@@ -31,8 +31,6 @@ from fastdeploy.utils import (
     llm_logger,
 )
 
-TaskOption = Literal["generate"]
-
 
 class ModelConfig:
     """
@@ -158,188 +156,6 @@ class ModelConfig:
         llm_logger.info("=============================================================")
 
 
-class SpeculativeConfig:
-    """
-    Speculative Decoding Configuration class.
-
-    Attributes:
-        method (Optional[str]): Method used for speculative decoding.
-        num_speculative_tokens (int): Maximum draft tokens, default is 1.
-        model_name_or_path (Optional[str]): Path of the model.
-        quantization (str): Quantization method for draft model, default is WINT8.
-        max_model_len: Optional[int]: Maximum model length for draft model.
-        benchmark_mode (bool): Whether to use benchmark mode.
-    """
-
-    def __init__(
-        self,
-        method: Optional[str] = None,
-        num_speculative_tokens: Optional[int] = 1,
-        model: Optional[str] = None,
-        quantization: Optional[str] = "WINT8",
-        max_model_len: Optional[int] = None,
-        benchmark_mode: bool = False,
-        **kwargs,
-    ):
-        self.model_name_or_path = model
-        self.method = method
-        self.num_speculative_tokens = num_speculative_tokens
-        self.quantization = quantization
-        self.max_model_len = max_model_len
-        self.benchmark_mode = benchmark_mode
-        # Fixed now
-        self.num_gpu_block_expand_ratio = 1
-        self.num_extra_cache_layer = 0
-
-        for key, value in kwargs.items():
-            try:
-                setattr(self, key, value)
-            except Exception:
-                continue
-
-        self.read_model_config()
-        self.reset()
-
-    def read_model_config(self):
-        """
-        Read configuration from file.
-        """
-        self.model_config = {}
-        if not self.enabled_speculative_decoding():
-            return
-
-        self.is_unified_ckpt = check_unified_ckpt(self.model_name_or_path)
-        if self.model_name_or_path is None:
-            return
-
-        self.config_path = os.path.join(self.model_name_or_path, "config.json")
-        if os.path.exists(self.config_path):
-            self.model_config = json.load(open(self.config_path, "r", encoding="utf-8"))
-
-    def reset(self):
-        """
-        Reset configuration.
-        """
-
-        def reset_value(cls, value_name, key=None, default=None):
-            if key is not None and key in cls.model_config:
-                setattr(cls, value_name, cls.model_config[key])
-            elif getattr(cls, value_name, None) is None:
-                setattr(cls, value_name, default)
-
-        if not self.enabled_speculative_decoding():
-            return
-
-        # NOTE(liuzichang): We will support multi-layer in future
-        if self.method in ["mtp"]:
-            self.num_extra_cache_layer = 1
-
-    def enabled_speculative_decoding(self):
-        """
-        Check if speculative decoding is enabled.
-        """
-        if self.method is None:
-            return False
-        return True
-
-    def to_json_string(self):
-        """
-        Convert speculative_config to json string.
-        """
-        return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
-
-    def print(self):
-        """
-        print all config
-
-        """
-        llm_logger.info("Speculative Decoding Configuration Information :")
-        for k, v in self.__dict__.items():
-            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
-        llm_logger.info("=============================================================")
-
-    def __str__(self) -> str:
-        return self.to_json_string()
-
-
-class GraphOptimizationConfig:
-    def __init__(
-        self,
-        graph_opt_level: Optional[int] = 0,
-        use_cudagraph: Optional[bool] = None,
-        cudagraph_capture_sizes: Optional[List[int]] = None,
-        sot_warmup_sizes: Optional[List[int]] = None,
-        **kwargs,
-    ):
-        """
-        Graph Optimization Configuration class.
-
-        Attributes:
-            graph_opt_level: Compute graph optimization level
-            use_cudagraph: Use CUDA Graph or not
-            cudagraph_capture_sizes: Batch size list will be captured by CUDA Graph
-        """
-        self.check_legality_parameters(graph_opt_level, use_cudagraph, cudagraph_capture_sizes, **kwargs)
-
-        self.graph_opt_level = graph_opt_level
-        self.use_cudagraph = use_cudagraph
-        self.cudagraph_capture_sizes = cudagraph_capture_sizes
-        self.sot_warmup_sizes = [] if sot_warmup_sizes is None else sot_warmup_sizes
-
-    def to_json_string(self):
-        """
-        Convert speculative_config to json string.
-        """
-        return json.dumps({key: value for key, value in self.__dict__.items()})
-
-    def __str__(self) -> str:
-        return self.to_json_string()
-
-    def check_legality_parameters(
-        self,
-        graph_opt_level: Optional[int] = None,
-        use_cudagraph: Optional[bool] = None,
-        cudagraph_capture_sizes: Optional[List[int]] = None,
-        **kwargs,
-    ) -> None:
-        """Check the legality of parameters passed in from the command line"""
-
-        if graph_opt_level is not None:
-            assert graph_opt_level in [
-                0,
-                1,
-                2,
-            ], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2."
-        if use_cudagraph is not None:
-            assert type(use_cudagraph) is bool, "In graph optimization config, type of use_cudagraph must is bool."
-        if cudagraph_capture_sizes is not None:
-            assert (
-                type(cudagraph_capture_sizes) is list
-            ), "In graph optimization config, type of cudagraph_capture_sizes must is list."
-            assert (
-                len(cudagraph_capture_sizes) > 0
-            ), "In graph optimization config, When opening the CUDA graph, it is forbidden to set the capture sizes to an empty list."
-
-        for key, value in kwargs.items():
-            raise ValueError(f"Invalid --graph-optimization-config parameter {key}")
-
-    def update_use_cudagraph(self, argument: bool):
-        """
-        Unified user specifies the use_cudagraph parameter through two methods,
-        '--use-cudagraph' and '--graph-optimization-config'
-        """
-        if self.use_cudagraph is None:
-            # User only set '--use-cudagraph'
-            self.use_cudagraph = argument
-        else:
-            # User both set '--use-cudagraph' and '--graph-optimization-config'
-            if self.use_cudagraph is False and argument is True:
-                raise ValueError(
-                    "Invalid parameter: Cannot set --use-cudagraph and --graph-optimization-config '{\"use_cudagraph\":false}' simultaneously."
-                )
-            argument = self.use_cudagraph
-
-
 class ParallelConfig:
     """
     Configuration for parallelism.
diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py
index 7d0a2aef7..68d168899 100644
--- a/fastdeploy/model_executor/pre_and_post_process.py
+++ b/fastdeploy/model_executor/pre_and_post_process.py
@@ -19,7 +19,7 @@ from typing import Dict, Optional
 import paddle
 
 from fastdeploy import envs
-from fastdeploy.engine.config import SpeculativeConfig
+from fastdeploy.config import SpeculativeConfig
 from fastdeploy.platforms import current_platform
 
 if current_platform.is_iluvatar():
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 5a295bb92..108a0b8eb 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -630,14 +630,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
 
     load_config = LoadConfig(vars(args))
 
-    graph_opt_config = GraphOptimizationConfig()
-    if args.graph_optimization_config is not None:
-        graph_opt_config = GraphOptimizationConfig(
-            use_cudagraph=args.graph_optimization_config["use_cudagraph"],
-            graph_opt_level=args.graph_optimization_config["graph_opt_level"],
-            cudagraph_capture_sizes=args.graph_optimization_config["cudagraph_capture_sizes"],
-            sot_warmup_sizes=args.graph_optimization_config["sot_warmup_sizes"],
-        )
+    graph_opt_config = GraphOptimizationConfig(args.graph_optimization_config)
 
     # Note(tangbinhan): used for load_checkpoint
     model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
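
Usage note (editorial addition, not part of the patch): a minimal sketch of the dict-based construction pattern introduced above. The option values are illustrative; the accepted keys for GraphOptimizationConfig are its own attributes (unknown keys are ignored by the hasattr check), and the constraints shown come from check_legality_parameters().

# Illustrative sketch; only the construction pattern matters, values are made up.
from fastdeploy.config import GraphOptimizationConfig

graph_opt_config = GraphOptimizationConfig(
    {
        "use_cudagraph": True,                     # must be a bool
        "graph_opt_level": 1,                      # must be 0, 1 or 2
        "cudagraph_capture_sizes": [1, 2, 4, 8],   # must be a non-empty list
    }
)
# Reconcile with a separately passed --use-cudagraph flag, as the engine does.
graph_opt_config.update_use_cudagraph(True)
print(graph_opt_config)  # __str__ delegates to to_json_string()

# SpeculativeConfig now takes the same kind of flattened dict: EngineArgs builds it
# from asdict(self) with any --speculative-config overrides merged on top, and
# worker_process.py simply forwards args.graph_optimization_config as shown above.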