mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Unify server-side and model-side Config (Part2) (#3035)
* merge speculative and graph opt conifg * add attr
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
@@ -24,10 +25,12 @@ from paddleformers.transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
|
||||
from fastdeploy.utils import get_logger
|
||||
from fastdeploy.utils import check_unified_ckpt, get_logger
|
||||
|
||||
logger = get_logger("config", "config.log")
|
||||
|
||||
TaskOption = Literal["generate"]
|
||||
|
||||
|
||||
class MoEPhase:
|
||||
"""
|
||||
@@ -269,6 +272,7 @@ class SpeculativeConfig:
|
||||
# This ensures that the specified simulation acceptance rate is not affected.
|
||||
self.benchmark_mode: bool = False
|
||||
|
||||
self.num_extra_cache_layer = 0
|
||||
# TODO(YuanRisheng): The name of the server args is different from the name of the SpeculativeConfig.
|
||||
# We temperately add the name map here and will delete it in future.
|
||||
name_map = {
|
||||
@@ -284,6 +288,69 @@ class SpeculativeConfig:
|
||||
if key == "speculative_benchmark_mode":
|
||||
value = True if value.lower() == "true" else False
|
||||
setattr(self, name_map[key], value)
|
||||
self.read_model_config()
|
||||
self.reset()
|
||||
|
||||
def read_model_config(self):
|
||||
"""
|
||||
Read configuration from file.
|
||||
"""
|
||||
self.model_config = {}
|
||||
if not self.enabled_speculative_decoding():
|
||||
return
|
||||
|
||||
self.is_unified_ckpt = check_unified_ckpt(self.model_name_or_path)
|
||||
if self.model_name_or_path is None:
|
||||
return
|
||||
|
||||
self.config_path = os.path.join(self.model_name_or_path, "config.json")
|
||||
if os.path.exists(self.config_path):
|
||||
self.model_config = json.load(open(self.config_path, "r", encoding="utf-8"))
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset configuration.
|
||||
"""
|
||||
|
||||
def reset_value(cls, value_name, key=None, default=None):
|
||||
if key is not None and key in cls.model_config:
|
||||
setattr(cls, value_name, cls.model_config[key])
|
||||
elif getattr(cls, value_name, None) is None:
|
||||
setattr(cls, value_name, default)
|
||||
|
||||
if not self.enabled_speculative_decoding():
|
||||
return
|
||||
|
||||
# NOTE(liuzichang): We will support multi-layer in future
|
||||
if self.method in ["mtp"]:
|
||||
self.num_extra_cache_layer = 1
|
||||
|
||||
def enabled_speculative_decoding(self):
|
||||
"""
|
||||
Check if speculative decoding is enabled.
|
||||
"""
|
||||
if self.method is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def to_json_string(self):
|
||||
"""
|
||||
Convert speculative_config to json string.
|
||||
"""
|
||||
return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
|
||||
|
||||
def print(self):
|
||||
"""
|
||||
print all config
|
||||
|
||||
"""
|
||||
logger.info("Speculative Decoding Configuration Information :")
|
||||
for k, v in self.__dict__.items():
|
||||
logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
logger.info("=============================================================")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.to_json_string()
|
||||
|
||||
|
||||
class DeviceConfig:
|
||||
@@ -301,18 +368,21 @@ class DeviceConfig:
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphOptimizationConfig:
|
||||
"""
|
||||
Configuration for compute graph level optimization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args,
|
||||
):
|
||||
"""The Top-level graph optimization contral corresponds to different backends.
|
||||
- 0: dyncmic graph
|
||||
- 1: static graph
|
||||
- 2: static graph + cinn compilation backend
|
||||
"""
|
||||
graph_opt_level: int = 0
|
||||
self.graph_opt_level: int = 0
|
||||
|
||||
# CUDA Graph Config
|
||||
""" Whether to use cudagraph.
|
||||
@@ -323,20 +393,20 @@ class GraphOptimizationConfig:
|
||||
- With dyncmic graph backend: ...
|
||||
- With static grpah backend: WIP
|
||||
"""
|
||||
sot_warmup_sizes: Optional[list[int]] = field(default_factory=list)
|
||||
self.sot_warmup_sizes: Optional[list[int]] = []
|
||||
""" Number of warmup runs for SOT warmup. """
|
||||
use_cudagraph: bool = False
|
||||
self.use_cudagraph: bool = False
|
||||
"""Sizes to capture cudagraph.
|
||||
- None (default): capture sizes are inferred from llm config.
|
||||
- list[int]: capture sizes are specified as given."""
|
||||
cudagraph_capture_sizes: Optional[list[int]] = None
|
||||
self.cudagraph_capture_sizes: Optional[list[int]] = None
|
||||
""" Number of warmup runs for cudagraph. """
|
||||
cudagraph_num_of_warmups: int = 2
|
||||
self.cudagraph_num_of_warmups: int = 2
|
||||
"""Whether to copy input tensors for cudagraph.
|
||||
If the caller can guarantee that the same input buffers
|
||||
are always used, it can set this to False. Otherwise, it should
|
||||
set this to True."""
|
||||
cudagraph_copy_inputs: bool = False
|
||||
self.cudagraph_copy_inputs: bool = False
|
||||
""" In static graph, this is an operation list that does not need to be captured by the CUDA graph.
|
||||
CudaGraphBackend will split these operations from the static graph.
|
||||
Example usage:
|
||||
@@ -346,15 +416,21 @@ class GraphOptimizationConfig:
|
||||
can manually split the model into multiple layers and apply the @support_graph_optimization decorator
|
||||
only to the layer where CUDA graph functionality is required.
|
||||
"""
|
||||
cudagraph_splitting_ops: list[str] = field(default_factory=list)
|
||||
self.cudagraph_splitting_ops: list[str] = []
|
||||
""" Whether to use a full cuda graph for the entire forward pass rather than
|
||||
splitting certain operations such as attention into subgraphs.
|
||||
Thus this flag cannot be used together with splitting_ops."""
|
||||
full_cuda_graph: bool = True
|
||||
self.full_cuda_graph: bool = True
|
||||
|
||||
max_capture_size: int = field(default=None, init=False) # type: ignore
|
||||
batch_size_to_captured_size: dict[int, int] = field(default=None, init=False) # type: ignore
|
||||
self.max_capture_size: int = None
|
||||
self.batch_size_to_captured_size: dict[int, int] = None
|
||||
# CINN Config ...
|
||||
if args is not None:
|
||||
for key, value in args.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
self.check_legality_parameters()
|
||||
|
||||
def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
|
||||
"""
|
||||
@@ -401,6 +477,54 @@ class GraphOptimizationConfig:
|
||||
draft_capture_sizes.append(max_num_seqs)
|
||||
self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
|
||||
|
||||
def to_json_string(self):
|
||||
"""
|
||||
Convert speculative_config to json string.
|
||||
"""
|
||||
return json.dumps({key: value for key, value in self.__dict__.items()})
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.to_json_string()
|
||||
|
||||
def check_legality_parameters(
|
||||
self,
|
||||
) -> None:
|
||||
"""Check the legality of parameters passed in from the command line"""
|
||||
|
||||
if self.graph_opt_level is not None:
|
||||
assert self.graph_opt_level in [
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2."
|
||||
if self.use_cudagraph is not None:
|
||||
assert (
|
||||
type(self.use_cudagraph) is bool
|
||||
), "In graph optimization config, type of use_cudagraph must is bool."
|
||||
if self.cudagraph_capture_sizes is not None:
|
||||
assert (
|
||||
type(self.cudagraph_capture_sizes) is list
|
||||
), "In graph optimization config, type of cudagraph_capture_sizes must is list."
|
||||
assert (
|
||||
len(self.cudagraph_capture_sizes) > 0
|
||||
), "In graph optimization config, When opening the CUDA graph, it is forbidden to set the capture sizes to an empty list."
|
||||
|
||||
def update_use_cudagraph(self, argument: bool):
|
||||
"""
|
||||
Unified user specifies the use_cudagraph parameter through two methods,
|
||||
'--use-cudagraph' and '--graph-optimization-config'
|
||||
"""
|
||||
if self.use_cudagraph is None:
|
||||
# User only set '--use-cudagraph'
|
||||
self.use_cudagraph = argument
|
||||
else:
|
||||
# User both set '--use-cudagraph' and '--graph-optimization-config'
|
||||
if self.use_cudagraph is False and argument is True:
|
||||
raise ValueError(
|
||||
"Invalid parameter: Cannot set --use-cudagraph and --graph-optimization-config '{\"use_cudagraph\":false}' simultaneously."
|
||||
)
|
||||
argument = self.use_cudagraph
|
||||
|
||||
|
||||
class LoadConfig:
|
||||
"""
|
||||
|
@@ -19,15 +19,13 @@ from dataclasses import asdict, dataclass
|
||||
from dataclasses import fields as dataclass_fields
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastdeploy.config import CacheConfig
|
||||
from fastdeploy.engine.config import (
|
||||
Config,
|
||||
from fastdeploy.config import (
|
||||
CacheConfig,
|
||||
GraphOptimizationConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
SpeculativeConfig,
|
||||
TaskOption,
|
||||
)
|
||||
from fastdeploy.engine.config import Config, ModelConfig, ParallelConfig
|
||||
from fastdeploy.scheduler.config import SchedulerConfig
|
||||
from fastdeploy.utils import FlexibleArgumentParser
|
||||
|
||||
@@ -772,10 +770,12 @@ class EngineArgs:
|
||||
|
||||
def create_speculative_config(self) -> SpeculativeConfig:
|
||||
""" """
|
||||
speculative_args = asdict(self)
|
||||
if self.speculative_config is not None:
|
||||
return SpeculativeConfig(**self.speculative_config)
|
||||
else:
|
||||
return SpeculativeConfig()
|
||||
for k, v in self.speculative_config.items():
|
||||
speculative_args[k] = v
|
||||
|
||||
return SpeculativeConfig(speculative_args)
|
||||
|
||||
def create_scheduler_config(self) -> SchedulerConfig:
|
||||
"""
|
||||
@@ -816,10 +816,11 @@ class EngineArgs:
|
||||
"""
|
||||
Create and retuan a GraphOptimizationConfig object based on the current settings.
|
||||
"""
|
||||
graph_optimization_args = asdict(self)
|
||||
if self.graph_optimization_config is not None:
|
||||
return GraphOptimizationConfig(**self.graph_optimization_config)
|
||||
else:
|
||||
return GraphOptimizationConfig()
|
||||
for k, v in self.graph_optimization_config.items():
|
||||
graph_optimization_args[k] = v
|
||||
return GraphOptimizationConfig(graph_optimization_args)
|
||||
|
||||
def create_engine_config(self) -> Config:
|
||||
"""
|
||||
|
@@ -17,7 +17,7 @@ import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.config import CacheConfig
|
||||
@@ -31,8 +31,6 @@ from fastdeploy.utils import (
|
||||
llm_logger,
|
||||
)
|
||||
|
||||
TaskOption = Literal["generate"]
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
"""
|
||||
@@ -158,188 +156,6 @@ class ModelConfig:
|
||||
llm_logger.info("=============================================================")
|
||||
|
||||
|
||||
class SpeculativeConfig:
|
||||
"""
|
||||
Speculative Decoding Configuration class.
|
||||
|
||||
Attributes:
|
||||
method (Optional[str]): Method used for speculative decoding.
|
||||
num_speculative_tokens (int): Maximum draft tokens, default is 1.
|
||||
model_name_or_path (Optional[str]): Path of the model.
|
||||
quantization (str): Quantization method for draft model, default is WINT8.
|
||||
max_model_len: Optional[int]: Maximum model length for draft model.
|
||||
benchmark_mode (bool): Whether to use benchmark mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: Optional[str] = None,
|
||||
num_speculative_tokens: Optional[int] = 1,
|
||||
model: Optional[str] = None,
|
||||
quantization: Optional[str] = "WINT8",
|
||||
max_model_len: Optional[int] = None,
|
||||
benchmark_mode: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.model_name_or_path = model
|
||||
self.method = method
|
||||
self.num_speculative_tokens = num_speculative_tokens
|
||||
self.quantization = quantization
|
||||
self.max_model_len = max_model_len
|
||||
self.benchmark_mode = benchmark_mode
|
||||
# Fixed now
|
||||
self.num_gpu_block_expand_ratio = 1
|
||||
self.num_extra_cache_layer = 0
|
||||
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
setattr(self, key, value)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
self.read_model_config()
|
||||
self.reset()
|
||||
|
||||
def read_model_config(self):
|
||||
"""
|
||||
Read configuration from file.
|
||||
"""
|
||||
self.model_config = {}
|
||||
if not self.enabled_speculative_decoding():
|
||||
return
|
||||
|
||||
self.is_unified_ckpt = check_unified_ckpt(self.model_name_or_path)
|
||||
if self.model_name_or_path is None:
|
||||
return
|
||||
|
||||
self.config_path = os.path.join(self.model_name_or_path, "config.json")
|
||||
if os.path.exists(self.config_path):
|
||||
self.model_config = json.load(open(self.config_path, "r", encoding="utf-8"))
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset configuration.
|
||||
"""
|
||||
|
||||
def reset_value(cls, value_name, key=None, default=None):
|
||||
if key is not None and key in cls.model_config:
|
||||
setattr(cls, value_name, cls.model_config[key])
|
||||
elif getattr(cls, value_name, None) is None:
|
||||
setattr(cls, value_name, default)
|
||||
|
||||
if not self.enabled_speculative_decoding():
|
||||
return
|
||||
|
||||
# NOTE(liuzichang): We will support multi-layer in future
|
||||
if self.method in ["mtp"]:
|
||||
self.num_extra_cache_layer = 1
|
||||
|
||||
def enabled_speculative_decoding(self):
|
||||
"""
|
||||
Check if speculative decoding is enabled.
|
||||
"""
|
||||
if self.method is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def to_json_string(self):
|
||||
"""
|
||||
Convert speculative_config to json string.
|
||||
"""
|
||||
return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
|
||||
|
||||
def print(self):
|
||||
"""
|
||||
print all config
|
||||
|
||||
"""
|
||||
llm_logger.info("Speculative Decoding Configuration Information :")
|
||||
for k, v in self.__dict__.items():
|
||||
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
llm_logger.info("=============================================================")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.to_json_string()
|
||||
|
||||
|
||||
class GraphOptimizationConfig:
|
||||
def __init__(
|
||||
self,
|
||||
graph_opt_level: Optional[int] = 0,
|
||||
use_cudagraph: Optional[bool] = None,
|
||||
cudagraph_capture_sizes: Optional[List[int]] = None,
|
||||
sot_warmup_sizes: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Graph Optimization Configuration class.
|
||||
|
||||
Attributes:
|
||||
graph_opt_level: Compute graph optimization level
|
||||
use_cudagraph: Use CUDA Graph or not
|
||||
cudagraph_capture_sizes: Batch size list will be captured by CUDA Graph
|
||||
"""
|
||||
self.check_legality_parameters(graph_opt_level, use_cudagraph, cudagraph_capture_sizes, **kwargs)
|
||||
|
||||
self.graph_opt_level = graph_opt_level
|
||||
self.use_cudagraph = use_cudagraph
|
||||
self.cudagraph_capture_sizes = cudagraph_capture_sizes
|
||||
self.sot_warmup_sizes = [] if sot_warmup_sizes is None else sot_warmup_sizes
|
||||
|
||||
def to_json_string(self):
|
||||
"""
|
||||
Convert speculative_config to json string.
|
||||
"""
|
||||
return json.dumps({key: value for key, value in self.__dict__.items()})
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.to_json_string()
|
||||
|
||||
def check_legality_parameters(
|
||||
self,
|
||||
graph_opt_level: Optional[int] = None,
|
||||
use_cudagraph: Optional[bool] = None,
|
||||
cudagraph_capture_sizes: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Check the legality of parameters passed in from the command line"""
|
||||
|
||||
if graph_opt_level is not None:
|
||||
assert graph_opt_level in [
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2."
|
||||
if use_cudagraph is not None:
|
||||
assert type(use_cudagraph) is bool, "In graph optimization config, type of use_cudagraph must is bool."
|
||||
if cudagraph_capture_sizes is not None:
|
||||
assert (
|
||||
type(cudagraph_capture_sizes) is list
|
||||
), "In graph optimization config, type of cudagraph_capture_sizes must is list."
|
||||
assert (
|
||||
len(cudagraph_capture_sizes) > 0
|
||||
), "In graph optimization config, When opening the CUDA graph, it is forbidden to set the capture sizes to an empty list."
|
||||
|
||||
for key, value in kwargs.items():
|
||||
raise ValueError(f"Invalid --graph-optimization-config parameter {key}")
|
||||
|
||||
def update_use_cudagraph(self, argument: bool):
|
||||
"""
|
||||
Unified user specifies the use_cudagraph parameter through two methods,
|
||||
'--use-cudagraph' and '--graph-optimization-config'
|
||||
"""
|
||||
if self.use_cudagraph is None:
|
||||
# User only set '--use-cudagraph'
|
||||
self.use_cudagraph = argument
|
||||
else:
|
||||
# User both set '--use-cudagraph' and '--graph-optimization-config'
|
||||
if self.use_cudagraph is False and argument is True:
|
||||
raise ValueError(
|
||||
"Invalid parameter: Cannot set --use-cudagraph and --graph-optimization-config '{\"use_cudagraph\":false}' simultaneously."
|
||||
)
|
||||
argument = self.use_cudagraph
|
||||
|
||||
|
||||
class ParallelConfig:
|
||||
"""
|
||||
Configuration for parallelism.
|
||||
|
@@ -19,7 +19,7 @@ from typing import Dict, Optional
|
||||
import paddle
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.engine.config import SpeculativeConfig
|
||||
from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_iluvatar():
|
||||
|
@@ -630,14 +630,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
|
||||
load_config = LoadConfig(vars(args))
|
||||
|
||||
graph_opt_config = GraphOptimizationConfig()
|
||||
if args.graph_optimization_config is not None:
|
||||
graph_opt_config = GraphOptimizationConfig(
|
||||
use_cudagraph=args.graph_optimization_config["use_cudagraph"],
|
||||
graph_opt_level=args.graph_optimization_config["graph_opt_level"],
|
||||
cudagraph_capture_sizes=args.graph_optimization_config["cudagraph_capture_sizes"],
|
||||
sot_warmup_sizes=args.graph_optimization_config["sot_warmup_sizes"],
|
||||
)
|
||||
graph_opt_config = GraphOptimizationConfig(args.graph_optimization_config)
|
||||
|
||||
# Note(tangbinhan): used for load_checkpoint
|
||||
model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
|
||||
|
Reference in New Issue
Block a user