Unify server-side and model-side Config (Part2) (#3035)

* merge speculative and graph opt config

* add attr
YuanRisheng
2025-07-28 15:31:48 +08:00
committed by GitHub
parent 776fb03250
commit bddf403576
5 changed files with 186 additions and 252 deletions

View File

@@ -16,6 +16,7 @@
from __future__ import annotations
import json
import os
from dataclasses import dataclass, field
from typing import Literal, Optional
@@ -24,10 +25,12 @@ from paddleformers.transformers.configuration_utils import PretrainedConfig
from fastdeploy import envs
from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
from fastdeploy.utils import get_logger
from fastdeploy.utils import check_unified_ckpt, get_logger
logger = get_logger("config", "config.log")
TaskOption = Literal["generate"]
class MoEPhase:
"""
@@ -269,6 +272,7 @@ class SpeculativeConfig:
# This ensures that the specified simulation acceptance rate is not affected.
self.benchmark_mode: bool = False
self.num_extra_cache_layer = 0
# TODO(YuanRisheng): The server-side argument names differ from the SpeculativeConfig attribute names.
# We temporarily add the name map here and will delete it in the future.
name_map = {
@@ -284,6 +288,69 @@ class SpeculativeConfig:
if key == "speculative_benchmark_mode":
value = True if value.lower() == "true" else False
setattr(self, name_map[key], value)
self.read_model_config()
self.reset()
def read_model_config(self):
"""
Read configuration from file.
"""
self.model_config = {}
if not self.enabled_speculative_decoding():
return
self.is_unified_ckpt = check_unified_ckpt(self.model_name_or_path)
if self.model_name_or_path is None:
return
self.config_path = os.path.join(self.model_name_or_path, "config.json")
if os.path.exists(self.config_path):
self.model_config = json.load(open(self.config_path, "r", encoding="utf-8"))
def reset(self):
"""
Reset configuration.
"""
def reset_value(cls, value_name, key=None, default=None):
if key is not None and key in cls.model_config:
setattr(cls, value_name, cls.model_config[key])
elif getattr(cls, value_name, None) is None:
setattr(cls, value_name, default)
if not self.enabled_speculative_decoding():
return
# NOTE(liuzichang): We will support multi-layer in future
if self.method in ["mtp"]:
self.num_extra_cache_layer = 1
def enabled_speculative_decoding(self):
"""
Check if speculative decoding is enabled.
"""
if self.method is None:
return False
return True
def to_json_string(self):
"""
Convert speculative_config to json string.
"""
return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
def print(self):
"""
print all config
"""
logger.info("Speculative Decoding Configuration Information :")
for k, v in self.__dict__.items():
logger.info("{:<20}:{:<6}{}".format(k, "", v))
logger.info("=============================================================")
def __str__(self) -> str:
return self.to_json_string()
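
A minimal usage sketch (not part of this diff) of how the unified SpeculativeConfig is now expected to be built from a flat dict of server-side arguments. It assumes the constructor copies known keys from the dict, as the other config classes in this file do; all keys besides "speculative_benchmark_mode" are illustrative, and the checked values follow from reset() above.

from fastdeploy.config import SpeculativeConfig

# Hypothetical server-side args dict; keys other than
# "speculative_benchmark_mode" are assumed, not taken from this diff.
speculative_args = {
    "method": "mtp",
    "num_speculative_tokens": 1,
    "model_name_or_path": "/path/to/draft_model",  # hypothetical path
    "speculative_benchmark_mode": "false",  # remapped to benchmark_mode via name_map, parsed to bool
}

spec_cfg = SpeculativeConfig(speculative_args)
assert spec_cfg.enabled_speculative_decoding()  # method is set, so True
assert spec_cfg.num_extra_cache_layer == 1      # reset() sets 1 for the "mtp" method
print(spec_cfg.to_json_string())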
class DeviceConfig:
@@ -301,60 +368,69 @@ class DeviceConfig:
setattr(self, key, value)
@dataclass
class GraphOptimizationConfig:
"""
Configuration for compute graph level optimization.
"""
"""The Top-level graph optimization contral corresponds to different backends.
- 0: dyncmic graph
- 1: static graph
- 2: static graph + cinn compilation backend
"""
graph_opt_level: int = 0
def __init__(
self,
args,
):
"""The Top-level graph optimization contral corresponds to different backends.
- 0: dyncmic graph
- 1: static graph
- 2: static graph + cinn compilation backend
"""
self.graph_opt_level: int = 0
# CUDA Graph Config
""" Whether to use cudagraph.
- False: cudagraph is not used.
- True: cudagraph is used.
It requires that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
- With dynamic graph backend: ...
- With static graph backend: WIP
"""
sot_warmup_sizes: Optional[list[int]] = field(default_factory=list)
""" Number of warmup runs for SOT warmup. """
use_cudagraph: bool = False
"""Sizes to capture cudagraph.
- None (default): capture sizes are inferred from llm config.
- list[int]: capture sizes are specified as given."""
cudagraph_capture_sizes: Optional[list[int]] = None
""" Number of warmup runs for cudagraph. """
cudagraph_num_of_warmups: int = 2
"""Whether to copy input tensors for cudagraph.
If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True."""
cudagraph_copy_inputs: bool = False
""" In static graph, this is an operation list that does not need to be captured by the CUDA graph.
CudaGraphBackend will split these operations from the static graph.
Example usage:
cudagraph_splitting_ops = ["paddle.unified_attention"]
# CUDA Graph Config
""" Whether to use cudagraph.
- False: cudagraph is not used.
- True: cudagraph is used.
It requires that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
- With dynamic graph backend: ...
- With static graph backend: WIP
"""
self.sot_warmup_sizes: Optional[list[int]] = []
""" Number of warmup runs for SOT warmup. """
self.use_cudagraph: bool = False
"""Sizes to capture cudagraph.
- None (default): capture sizes are inferred from llm config.
- list[int]: capture sizes are specified as given."""
self.cudagraph_capture_sizes: Optional[list[int]] = None
""" Number of warmup runs for cudagraph. """
self.cudagraph_num_of_warmups: int = 2
"""Whether to copy input tensors for cudagraph.
If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True."""
self.cudagraph_copy_inputs: bool = False
""" In static graph, this is an operation list that does not need to be captured by the CUDA graph.
CudaGraphBackend will split these operations from the static graph.
Example usage:
cudagraph_splitting_ops = ["paddle.unified_attention"]
Note: If you want to use the subgraph capture functionality in a dynamic graph,
you can manually split the model into multiple layers and apply the @support_graph_optimization decorator
only to the layer where CUDA graph functionality is required.
"""
cudagraph_splitting_ops: list[str] = field(default_factory=list)
""" Whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops."""
full_cuda_graph: bool = True
Note: If you want to use the subgraph capture functionality in a dynamic graph,
you can manually split the model into multiple layers and apply the @support_graph_optimization decorator
only to the layer where CUDA graph functionality is required.
"""
self.cudagraph_splitting_ops: list[str] = []
""" Whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops."""
self.full_cuda_graph: bool = True
max_capture_size: int = field(default=None, init=False) # type: ignore
batch_size_to_captured_size: dict[int, int] = field(default=None, init=False) # type: ignore
# CINN Config ...
self.max_capture_size: int = None
self.batch_size_to_captured_size: dict[int, int] = None
# CINN Config ...
if args is not None:
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
self.check_legality_parameters()
def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
"""
@@ -401,6 +477,54 @@ class GraphOptimizationConfig:
draft_capture_sizes.append(max_num_seqs)
self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
def to_json_string(self):
"""
Convert graph_optimization_config to json string.
"""
return json.dumps({key: value for key, value in self.__dict__.items()})
def __str__(self) -> str:
return self.to_json_string()
def check_legality_parameters(
self,
) -> None:
"""Check the legality of parameters passed in from the command line"""
if self.graph_opt_level is not None:
assert self.graph_opt_level in [
0,
1,
2,
], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2."
if self.use_cudagraph is not None:
assert (
type(self.use_cudagraph) is bool
), "In graph optimization config, type of use_cudagraph must is bool."
if self.cudagraph_capture_sizes is not None:
assert (
type(self.cudagraph_capture_sizes) is list
), "In graph optimization config, type of cudagraph_capture_sizes must is list."
assert (
len(self.cudagraph_capture_sizes) > 0
), "In graph optimization config, When opening the CUDA graph, it is forbidden to set the capture sizes to an empty list."
def update_use_cudagraph(self, argument: bool):
"""
Unify the use_cudagraph parameter, which the user can specify through two methods:
'--use-cudagraph' and '--graph-optimization-config'
"""
if self.use_cudagraph is None:
# User only set '--use-cudagraph'
self.use_cudagraph = argument
else:
# User both set '--use-cudagraph' and '--graph-optimization-config'
if self.use_cudagraph is False and argument is True:
raise ValueError(
"Invalid parameter: Cannot set --use-cudagraph and --graph-optimization-config '{\"use_cudagraph\":false}' simultaneously."
)
argument = self.use_cudagraph
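
A short sketch (not part of this diff) of how the new dict-based GraphOptimizationConfig is expected to be used; the behavior of ignoring unknown keys follows from the hasattr() loop above, and the update_use_cudagraph call mirrors how the engine reconciles the standalone flag.

from fastdeploy.config import GraphOptimizationConfig

graph_opt_args = {
    "graph_opt_level": 1,
    "use_cudagraph": True,
    "cudagraph_capture_sizes": [1, 2, 4, 8],
    "not_a_known_field": "ignored",  # skipped because hasattr(self, key) is False
}
graph_opt_cfg = GraphOptimizationConfig(graph_opt_args)  # also runs check_legality_parameters()

# Reconcile with the standalone '--use-cudagraph' flag; conflicting values
# (dict says False, flag says True) raise ValueError.
graph_opt_cfg.update_use_cudagraph(True)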
class LoadConfig:
"""

View File

@@ -19,15 +19,13 @@ from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional
from fastdeploy.config import CacheConfig
from fastdeploy.engine.config import (
Config,
from fastdeploy.config import (
CacheConfig,
GraphOptimizationConfig,
ModelConfig,
ParallelConfig,
SpeculativeConfig,
TaskOption,
)
from fastdeploy.engine.config import Config, ModelConfig, ParallelConfig
from fastdeploy.scheduler.config import SchedulerConfig
from fastdeploy.utils import FlexibleArgumentParser
@@ -772,10 +770,12 @@ class EngineArgs:
def create_speculative_config(self) -> SpeculativeConfig:
""" """
speculative_args = asdict(self)
if self.speculative_config is not None:
return SpeculativeConfig(**self.speculative_config)
else:
return SpeculativeConfig()
for k, v in self.speculative_config.items():
speculative_args[k] = v
return SpeculativeConfig(speculative_args)
def create_scheduler_config(self) -> SchedulerConfig:
"""
@@ -816,10 +816,11 @@ class EngineArgs:
"""
Create and return a GraphOptimizationConfig object based on the current settings.
"""
graph_optimization_args = asdict(self)
if self.graph_optimization_config is not None:
return GraphOptimizationConfig(**self.graph_optimization_config)
else:
return GraphOptimizationConfig()
for k, v in self.graph_optimization_config.items():
graph_optimization_args[k] = v
return GraphOptimizationConfig(graph_optimization_args)
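
Both create_*_config methods now follow the same merge pattern; a hypothetical standalone helper (names assumed, not part of this diff) that captures it:

from dataclasses import asdict

def merged_config_args(engine_args, overrides) -> dict:
    # Flatten every EngineArgs field, then let the user-supplied
    # --speculative-config / --graph-optimization-config dict win.
    args = asdict(engine_args)
    for k, v in (overrides or {}).items():
        args[k] = v
    return args

# e.g. SpeculativeConfig(merged_config_args(self, self.speculative_config))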
def create_engine_config(self) -> Config:
"""

View File

@@ -17,7 +17,7 @@ import json
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, List, Optional
from fastdeploy import envs
from fastdeploy.config import CacheConfig
@@ -31,8 +31,6 @@ from fastdeploy.utils import (
llm_logger,
)
TaskOption = Literal["generate"]
class ModelConfig:
"""
@@ -158,188 +156,6 @@ class ModelConfig:
llm_logger.info("=============================================================")
class SpeculativeConfig:
"""
Speculative Decoding Configuration class.
Attributes:
method (Optional[str]): Method used for speculative decoding.
num_speculative_tokens (int): Maximum draft tokens, default is 1.
model_name_or_path (Optional[str]): Path of the model.
quantization (str): Quantization method for draft model, default is WINT8.
max_model_len: Optional[int]: Maximum model length for draft model.
benchmark_mode (bool): Whether to use benchmark mode.
"""
def __init__(
self,
method: Optional[str] = None,
num_speculative_tokens: Optional[int] = 1,
model: Optional[str] = None,
quantization: Optional[str] = "WINT8",
max_model_len: Optional[int] = None,
benchmark_mode: bool = False,
**kwargs,
):
self.model_name_or_path = model
self.method = method
self.num_speculative_tokens = num_speculative_tokens
self.quantization = quantization
self.max_model_len = max_model_len
self.benchmark_mode = benchmark_mode
# Fixed now
self.num_gpu_block_expand_ratio = 1
self.num_extra_cache_layer = 0
for key, value in kwargs.items():
try:
setattr(self, key, value)
except Exception:
continue
self.read_model_config()
self.reset()
def read_model_config(self):
"""
Read configuration from file.
"""
self.model_config = {}
if not self.enabled_speculative_decoding():
return
self.is_unified_ckpt = check_unified_ckpt(self.model_name_or_path)
if self.model_name_or_path is None:
return
self.config_path = os.path.join(self.model_name_or_path, "config.json")
if os.path.exists(self.config_path):
self.model_config = json.load(open(self.config_path, "r", encoding="utf-8"))
def reset(self):
"""
Reset configuration.
"""
def reset_value(cls, value_name, key=None, default=None):
if key is not None and key in cls.model_config:
setattr(cls, value_name, cls.model_config[key])
elif getattr(cls, value_name, None) is None:
setattr(cls, value_name, default)
if not self.enabled_speculative_decoding():
return
# NOTE(liuzichang): We will support multi-layer in future
if self.method in ["mtp"]:
self.num_extra_cache_layer = 1
def enabled_speculative_decoding(self):
"""
Check if speculative decoding is enabled.
"""
if self.method is None:
return False
return True
def to_json_string(self):
"""
Convert speculative_config to json string.
"""
return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
def print(self):
"""
print all config
"""
llm_logger.info("Speculative Decoding Configuration Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info("=============================================================")
def __str__(self) -> str:
return self.to_json_string()
class GraphOptimizationConfig:
def __init__(
self,
graph_opt_level: Optional[int] = 0,
use_cudagraph: Optional[bool] = None,
cudagraph_capture_sizes: Optional[List[int]] = None,
sot_warmup_sizes: Optional[List[int]] = None,
**kwargs,
):
"""
Graph Optimization Configuration class.
Attributes:
graph_opt_level: Compute graph optimization level
use_cudagraph: Use CUDA Graph or not
cudagraph_capture_sizes: Batch size list will be captured by CUDA Graph
"""
self.check_legality_parameters(graph_opt_level, use_cudagraph, cudagraph_capture_sizes, **kwargs)
self.graph_opt_level = graph_opt_level
self.use_cudagraph = use_cudagraph
self.cudagraph_capture_sizes = cudagraph_capture_sizes
self.sot_warmup_sizes = [] if sot_warmup_sizes is None else sot_warmup_sizes
def to_json_string(self):
"""
Convert graph_optimization_config to json string.
"""
return json.dumps({key: value for key, value in self.__dict__.items()})
def __str__(self) -> str:
return self.to_json_string()
def check_legality_parameters(
self,
graph_opt_level: Optional[int] = None,
use_cudagraph: Optional[bool] = None,
cudagraph_capture_sizes: Optional[List[int]] = None,
**kwargs,
) -> None:
"""Check the legality of parameters passed in from the command line"""
if graph_opt_level is not None:
assert graph_opt_level in [
0,
1,
2,
], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2."
if use_cudagraph is not None:
assert type(use_cudagraph) is bool, "In graph optimization config, type of use_cudagraph must be bool."
if cudagraph_capture_sizes is not None:
assert (
type(cudagraph_capture_sizes) is list
), "In graph optimization config, type of cudagraph_capture_sizes must is list."
assert (
len(cudagraph_capture_sizes) > 0
), "In graph optimization config, When opening the CUDA graph, it is forbidden to set the capture sizes to an empty list."
for key, value in kwargs.items():
raise ValueError(f"Invalid --graph-optimization-config parameter {key}")
def update_use_cudagraph(self, argument: bool):
"""
Unify the use_cudagraph parameter, which the user can specify through two methods:
'--use-cudagraph' and '--graph-optimization-config'
"""
if self.use_cudagraph is None:
# User only set '--use-cudagraph'
self.use_cudagraph = argument
else:
# User both set '--use-cudagraph' and '--graph-optimization-config'
if self.use_cudagraph is False and argument is True:
raise ValueError(
"Invalid parameter: Cannot set --use-cudagraph and --graph-optimization-config '{\"use_cudagraph\":false}' simultaneously."
)
argument = self.use_cudagraph
class ParallelConfig:
"""
Configuration for parallelism.

View File

@@ -19,7 +19,7 @@ from typing import Dict, Optional
import paddle
from fastdeploy import envs
from fastdeploy.engine.config import SpeculativeConfig
from fastdeploy.config import SpeculativeConfig
from fastdeploy.platforms import current_platform
if current_platform.is_iluvatar():

View File

@@ -630,14 +630,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
load_config = LoadConfig(vars(args))
graph_opt_config = GraphOptimizationConfig()
if args.graph_optimization_config is not None:
graph_opt_config = GraphOptimizationConfig(
use_cudagraph=args.graph_optimization_config["use_cudagraph"],
graph_opt_level=args.graph_optimization_config["graph_opt_level"],
cudagraph_capture_sizes=args.graph_optimization_config["cudagraph_capture_sizes"],
sot_warmup_sizes=args.graph_optimization_config["sot_warmup_sizes"],
)
graph_opt_config = GraphOptimizationConfig(args.graph_optimization_config)
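
An illustrative sketch of what the simplified worker-side construction amounts to; how args.graph_optimization_config is parsed into a dict is an assumption here, not shown in this diff.

import json
from fastdeploy.config import GraphOptimizationConfig

# Assumed: --graph-optimization-config arrives as a JSON string on the worker CLI.
raw = '{"use_cudagraph": true, "graph_opt_level": 1, "cudagraph_capture_sizes": [1, 2, 4]}'
graph_optimization_config = json.loads(raw)

# After this change the whole dict is forwarded instead of unpacking each key by hand.
graph_opt_config = GraphOptimizationConfig(graph_optimization_config)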
# Note(tangbinhan): used for load_checkpoint
model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank