Simplify the Config code (#2770)

* simplify the code

* fix vl

* delete config

* fix

* perfect code

* fix ci

* fix xpu

* fix xpu

* fix server

* resolve conflict

* fix mtp

* resolve conflict

* fix xpu

* fix xpu

* fix vl

* fix log

* fix qwen moe

* fix qwen moe

* fix qwen moe
This commit is contained in:
YuanRisheng
2025-07-14 19:50:05 +08:00
committed by GitHub
parent 2e81792d64
commit 4c7b8bc458
34 changed files with 551 additions and 911 deletions

@@ -21,14 +21,15 @@ from enum import Enum
from typing import Literal, Optional, Union
from paddleformers.transformers.configuration_utils import PretrainedConfig
from paddleformers.trl import llm_utils
from fastdeploy import envs
from fastdeploy.model_executor.layers.quantization.quant_base import \
QuantConfigBase
from fastdeploy.utils import get_logger
logger = get_logger("config", "config.log")
class MoEPhase(Enum):
"""
The generation phase of the MoE.
@@ -37,274 +38,228 @@ class MoEPhase(Enum):
PREFILL = 1
DECODER = 2
PRETRAINED_INIT_CONFIGURATION = {
"rope_theta": 10000.0,
"num_key_value_heads":-1,
"start_layer_index": 0,
"moe_num_shared_experts":0,
"moe_layer_start_index": 0,
"num_max_dispatch_tokens_per_rank":256,
"moe_use_aux_free":False,
"vocab_size": -1,
"use_rope": True,
"hidden_dropout_prob":0.0,
"initializer_range":0.02,
"max_position_embeddings":512,
"quantization_config":None,
"use_recompute_resampler":False,
"use_temporal_conv":True,
"resampler_fuse_rms_norm":False,
"freq_allocation":20,
"tie_word_embeddings":False,
"rms_norm_eps":1e-5,
}
class ModelConfig(PretrainedConfig):
class ModelConfig:
"""
The configuration class to store the configuration of an `LLM`.
"""
max_stop_seqs_num = 5
stop_seqs_max_len = 8
architectures: list[str] = []
# NOTE(gongshaotain): from _load_model_init_val()
top_p = 0.0
temperature = 1.0
rope_theta = 10000.0
penalty_score = 1.0
frequency_score = 0.0
presence_score = 0.0
min_length = 1
def __init__(
self,
vocab_size: int = 100224,
hidden_size: int = 4096,
num_layers: int = 48,
num_attention_heads: int = 32,
num_key_value_heads: Optional[int] = None,
hidden_act: str = "swiglu",
hidden_dropout_prob: float = 0.0,
max_position_embeddings: int = 512,
max_seq_len: int = 512,
initializer_range: float = 0.02,
use_rope=True,
rope_theta: int = 10000,
rope_3d: bool = False,
ori_vocab_size: int | None = None,
moe_layer_start_index: Union[int, list[int], None] = None,
moe_num_experts: Union[int, list[int], None] = None,
moe_layer_end_index: Union[int, list[int], None] = None,
moe_num_shared_experts: int | None = None,
num_hidden_layers: int | None = None,
prefix_name="",
freeze_embedding=False,
rope_head_dim=None,
ffn_hidden_size: Optional[int] = None,
dtype="bfloat16",
start_layer_index: int = 0,
head_dim: Optional[int] = None,
tie_word_embeddings: bool = False,
is_quantized: bool = False,
rms_norm_eps: float = 1e-5,
**kwargs,
args,
):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_layers = num_layers
if num_hidden_layers is not None:
self.num_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
if head_dim is None:
self.max_stop_seqs_num = 5
self.stop_seqs_max_len = 8
# NOTE(gongshaotain): from _load_model_init_val()
self.top_p = 0.0
self.temperature = 1.0
self.rope_theta = 10000.0
self.penalty_score = 1.0
self.frequency_score = 0.0
self.presence_score = 0.0
self.min_length = 1
self.model_name_or_path = ""
self.im_patch_id = (
100295 # multimodality, TODO(liuyuanle): read from config.json
)
self.is_quantized = False
self.max_model_len = 0
self.dtype = ""
self.enable_logprob = False
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
pretrained_config, _ = PretrainedConfig.get_config_dict(self.model_name_or_path)
self.pretrained_config = PretrainedConfig.from_dict(pretrained_config)
# set attribute from pretrained_config
for key, value in pretrained_config.items():
setattr(self, key, value)
# set default values for any keys missing from the pretrained config
for key, value in PRETRAINED_INIT_CONFIGURATION.items():
if not hasattr(self, key):
setattr(self, key, value)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_attention_heads
else:
self.head_dim = head_dim
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.use_rope = use_rope
self.rope_theta = rope_theta
self.ori_vocab_size = ori_vocab_size or vocab_size
self.max_seq_len = max_seq_len
self.prefix_name = prefix_name
self.freeze_embedding = freeze_embedding
self.rope_head_dim = rope_head_dim
self.moe_layer_start_index = moe_layer_start_index
self.moe_num_experts = moe_num_experts
self.moe_num_shared_experts = moe_num_shared_experts
self.moe_layer_end_index = moe_layer_end_index
self.ffn_hidden_size = ffn_hidden_size
self.rope_3d = rope_3d
self.start_layer_index = start_layer_index
self.dtype = dtype
self.tie_word_embeddings = tie_word_embeddings
self.is_quantized = is_quantized
self.rms_norm_eps = rms_norm_eps
if hasattr(self, "vision_config"):
self.vision_config = PretrainedConfig.from_dict(self.vision_config)
@dataclass
class MoEConfig:
"""
Configuration for MoE.
"""
num_experts: Union[int, list[int], None] = None
top_k: int = 8
moe_intermediate_size: int = -1
num_experts_per_rank: int = -1
num_experts_start_offset: int = -1
self.ori_vocab_size = self.vocab_size
if "Ernie4_5_ForCausalLM" in self.architectures or "Ernie4_5_MoeForCausalLM" in self.architectures:
self.ori_vocab_size = args["ori_vocab_size"]
moe_num_shared_experts = (0, )
moe_layer_start_index: Union[int, list[int], None] = None
moe_layer_end_index: Union[int, list[int], None] = None
moe_use_aux_free: bool = False
num_max_dispatch_tokens_per_rank = 256
im_patch_id = (
100295 # multimodality, TODO(liuyuanle): read from config.json
)
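# Illustrative sketch (not part of this diff): how the new dict-driven ModelConfig is
# expected to be constructed. The import path and the checkpoint path are assumptions,
# and a real config.json must exist at that path.
from fastdeploy.config import ModelConfig  # module path assumed

model_args = {
    "model_name_or_path": "/path/to/checkpoint",  # hypothetical local checkpoint
    "max_model_len": 8192,
    "dtype": "bfloat16",
}
model_config = ModelConfig(model_args)
# Keys found in the checkpoint's config.json override the defaults set in __init__;
# anything missing from both falls back to PRETRAINED_INIT_CONFIGURATION.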
@dataclass
class ParallelConfig:
"""Configuration for the distributed execution."""
block_size = 16 # The block size for processing.
sequence_parallel = False # Whether to enable sequence parallelism.
use_ep = False # Whether to enable Expert Parallelism
moe_phase = MoEPhase.PREFILL # Generation phase
msg_queue_id = 1 # message queue id
tensor_parallel_rank = None # TP rank ID
tensor_parallel_degree = None # TP degree
expert_parallel_rank = None # EP rank ID
expert_parallel_degree = None # EP degree
# The embedding weight distributed across your GPU cards is split by row or column.
# Defaults to False, meaning split by row. When vocab_size cannot be divided evenly by
# world_size but hidden_size can, consider splitting the embedding weight by column.
"""
From old version worker args
TODO(gongshaotian): Reclassify
"""
model_name_or_path: str = "./output"
max_num_seqs: int = 34
# Set default block num for profile run
max_block_num: int = 2000
# block size
block_size: int = 64
# Engine worker queue port
engine_worker_queue_port: int = 9923
# Max model len
max_model_len: int = 3072 # max_seq_len
# cuda visible devices
device_ids: str = "0"
# Input dtype
dtype: str = "bfloat16"
# Encoder's decoder num
enc_dec_block_num: int = 1
# KV cache ratio for input
kv_cache_ratio: float = 0.7
# First token id
first_token_id: int = 1
# Gpu memory utilization
gpu_memory_utilization: float = 0.9
# Process ID of engine
engine_pid: Optional[int] = None
# Do profile or not
do_profile: bool = False
#
pad_token_id: int = -1
#
eos_tokens_lens: int = 2
# Enable chunked prefill
enable_chunked_prefill: str = "store_true"
def __init__(
self,
args,
):
self.sequence_parallel = False # Whether to enable sequence parallelism.
self.use_ep = False # Whether to enable Expert Parallelism
self.moe_phase = MoEPhase.PREFILL # Generation phase
self.msg_queue_id = 1 # message queue id
max_num_batched_tokens: int = 2048
# enable prefix cache
enable_prefix_caching = None
# splitwise role
splitwise_role: str = "mixed"
# guided decoding backend
guided_decoding_backend: str = None
# disable any whitespace for guided decoding
disable_any_whitespace: bool = True
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
enable_custom_all_reduce: str = "store_true"
tensor_parallel_rank, tensor_parallel_size = llm_utils.init_dist_env()
self.tensor_parallel_rank = tensor_parallel_rank # TP rank ID
self.tensor_parallel_size = tensor_parallel_size # TP degree
self.expert_parallel_rank = int(tensor_parallel_rank / tensor_parallel_size) # EP rank ID
self.expert_parallel_size = 1 # EP degree
# The embedding weight distributed across your GPU cards is split by row or column.
# Defaults to False, meaning split by row. When vocab_size cannot be divided evenly by
# world_size but hidden_size can, consider splitting the embedding weight by column.
"""
From old version worker args
TODO(gongshaotian): Reclassify
"""
self.model_name_or_path: str = "./output"
self.max_num_seqs: int = 34
# Set default block num for profile run
self.max_block_num: int = 2000
# block size
self.block_size: int = 64
# Engine worker queue port
self.engine_worker_queue_port: int = 9923
# Max model len
self.max_model_len: int = 3072 # max_seq_len
# cuda visible devices
self.device_ids: str = "0"
# Input dtype
self.dtype: str = "bfloat16"
# Encoder's decoder num
self.enc_dec_block_num: int = 1
# KV cache ratio for input
self.kv_cache_ratio: float = 0.7
# First token id
self.first_token_id: int = 1
# Gpu memory utilization
self.gpu_memory_utilization: float = 0.9
# Process ID of engine
self.engine_pid: Optional[int] = None
# Do profile or not
self.do_profile: bool = False
#
self.pad_token_id: int = -1
#
self.eos_tokens_lens: int = 2
# Enable chunked prefill
self.enable_chunked_prefill: bool = False
self.max_num_batched_tokens: int = 2048
# enable prefix cache
self.enable_prefix_caching = None
# splitwise role
self.splitwise_role: str = "mixed"
# guided decoding backend
self.guided_decoding_backend: str = None
# disable any whitespace for guided decoding
self.disable_any_whitespace: bool = True
self.pod_ip: str = None
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
self.use_ep = args["expert_parallel_size"] > 1
if self.splitwise_role == "mixed":
self.moe_phase = MoEPhase.PREFILL
elif self.splitwise_role == "prefill":
self.moe_phase = MoEPhase.PREFILL
elif self.splitwise_role == "decode":
self.moe_phase = MoEPhase.DECODER
else:
raise NotImplementedError
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
self.enable_custom_all_reduce: bool = False
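# Illustrative sketch (not part of this diff): how ParallelConfig derives expert
# parallelism and the MoE phase from its args dict. The import path is an assumption,
# the values are made up, and llm_utils.init_dist_env() must be able to initialize a
# (possibly single-card) distributed environment.
from fastdeploy.config import MoEPhase, ParallelConfig  # module path assumed

parallel_args = {
    "expert_parallel_size": 1,   # > 1 would set use_ep = True
    "splitwise_role": "decode",  # "mixed"/"prefill" -> PREFILL, "decode" -> DECODER
    "block_size": 64,
}
parallel_config = ParallelConfig(parallel_args)
assert parallel_config.use_ep is False
assert parallel_config.moe_phase == MoEPhase.DECODER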
@dataclass
class SpeculativeConfig:
"""
Configuration for speculative decoding.
"""
# speculative method, choose in [None, "ngram_match", "mtp"]
method: Optional[str] = None
# the max length of speculative tokens
num_speculative_tokens: int = 1
# the max length of candidate tokens for speculative method
max_candidate_len: int = 5
# the max length of verify window for speculative method
verify_window: int = 2
# ngram match
max_ngram_size: int = 5
# model for mtp/eagle/draft_model
model_name_or_path: Optional[str] = None
# quantization of model
quantization: Optional[str] = None
# allocate more blocks to prevent mtp from finishing the block earlier than the main model
# Fixed now
num_gpu_block_expand_ratio: Optional[float] = 1
# To distinguish the main model and draft model(mtp/eagle/draftmodel)
# ["main", "mtp"]
model_type: Optional[str] = "main"
# TODO(liuzichang): To reduce memory usage, MTP shares the main model's lm_head and embedding layers.
# A trick method is currently used to enable this sharing.
# This will be replaced with a more standardized solution in the future.
sharing_model = None
# During benchmarking, we need to enforce that the number of accepted tokens is 1.
# This means no tokens from MTP are accepted.
# This ensures that the specified simulation acceptance rate is not affected.
benchmark_mode: bool = False
def __init__(
self,
args,
):
# speculative method, choose in [None, "ngram_match", "mtp"]
self.method: Optional[str] = None
# the max length of speculative tokens
self.num_speculative_tokens: int = 1
# the max length of candidate tokens for speculative method
self.max_candidate_len: int = 5
# the max length of verify window for speculative method
self.verify_window: int = 2
# ngram match
self.max_ngram_size: int = 5
# model for mtp/eagle/draft_model
self.model_name_or_path: Optional[str] = None
# quantization of model
self.quantization: Optional[str] = None
# allocate more blocks to prevent mtp from finishing the block earlier than the main model
# Fixed now
self.num_gpu_block_expand_ratio: Optional[float] = 1
# To distinguish the main model and draft model(mtp/eagle/draftmodel)
# ["main", "mtp"]
self.model_type: Optional[str] = "main"
# TODO(liuzichang): To reduce memory usage, MTP shares the main model's lm_head and embedding layers.
# A trick method is currently used to enable this sharing.
# This will be replaced with a more standardized solution in the future.
self.sharing_model = None
# During benchmarking, we need to enforce that the number of accepted tokens is 1.
# This means no tokens from MTP are accepted.
# This ensures that the specified simulation acceptance rate is not affected.
self.benchmark_mode: bool = False
# TODO(YuanRisheng): The names of the server args differ from the field names of SpeculativeConfig.
# We temporarily add the name map here and will delete it in the future.
name_map = {"speculative_method": "method",
"speculative_max_draft_token_num": "num_speculative_tokens",
"speculative_model_name_or_path": "model_name_or_path",
"speculative_model_quantization": "quantization",
"speculative_benchmark_mode": "benchmark_mode"}
for key, value in args.items():
if key in name_map.keys() and hasattr(self, name_map[key]):
setattr(self, name_map[key], value)
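# Illustrative sketch (not part of this diff): the temporary name_map above translates
# server-side argument names into SpeculativeConfig attributes. Import path assumed;
# values are made up.
from fastdeploy.config import SpeculativeConfig  # module path assumed

spec_args = {
    "speculative_method": "mtp",           # stored as .method
    "speculative_max_draft_token_num": 3,  # stored as .num_speculative_tokens
}
spec_config = SpeculativeConfig(spec_args)
assert spec_config.method == "mtp"
assert spec_config.num_speculative_tokens == 3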
@dataclass
class DeviceConfig:
"""
Configuration for device settings.
"""
device_type = "cuda"
def __init__(
self,
args,
):
self.device_type = "cuda"
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
class GraphOptimizationConfig:
"""The Top-level graph optimization contral corresponds to different backends.
- 0: dyncmic graph
- 1: static graph
- 2: static graph + cinn compilation backend
"""
graph_opt_level: int = 0
# CUDA Graph Config
""" Whether to use cudagraph.
- False: cudagraph is not used.
- True: cudagraph is used.
It requires that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
- With dynamic graph backend: ...
- With static graph backend: WIP
"""
use_cudagraph: bool = False
"""Sizes to capture cudagraph.
- None (default): capture sizes are inferred from llm config.
- list[int]: capture sizes are specified as given."""
cudagraph_capture_sizes: Optional[list[int]] = None
""" Number of warmup runs for cudagraph. """
cudagraph_num_of_warmups: int = 2
"""Whether to copy input tensors for cudagraph.
If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True."""
cudagraph_copy_inputs: bool = False
""" In static graph, this is an operation list that does not need to be captured by the CUDA graph.
CudaGraphBackend will split these operations from the static graph.
Example usage:
cudagraph_splitting_ops = ["paddle.unified_attention"]
Note: If you want to use subgraph capture functionality in a dynamic graph,
you can manually split the model into multiple layers and apply the @support_cuda_graph decorator
only to the layer where CUDA graph functionality is required.
"""
cudagraph_splitting_ops: Optional[list[str]] = None
"""Whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops."""
full_cuda_graph: bool = False
max_capture_size: int = field(default=None, init=False) # type: ignore
batch_size_to_captured_size: dict[int,
int] = field(default=None,
init=False) # type: ignore
# CINN Config ...
def init_with_cudagrpah_size(self,
cudagraph_capture_sizes: list[int]) -> None:
"""To complete the initialization of config,
@@ -338,18 +293,67 @@ class GraphOptimizationConfig:
def __init__(self,
enable_static_graph_inference: bool = False,
use_cudagraph: bool = False,
max_capture_batch_size: int = 64):
""" """
max_capture_batch_size: int = 64,
args = None):
"""The Top-level graph optimization contral corresponds to different backends.
- 0: dyncmic graph
- 1: static graph
- 2: static graph + cinn compilation backend
"""
self.graph_opt_level: int = 0
# CUDA Graph Config
""" Whether to use cudagraph.
- False: cudagraph is not used.
- True: cudagraph is used.
It requires that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
- With dynamic graph backend: ...
- With static graph backend: WIP
"""
self.use_cudagraph: bool = False
"""Sizes to capture cudagraph.
- None (default): capture sizes are inferred from llm config.
- list[int]: capture sizes are specified as given."""
self.cudagraph_capture_sizes: Optional[list[int]] = None
""" Number of warmup runs for cudagraph. """
self.cudagraph_num_of_warmups: int = 2
"""Whether to copy input tensors for cudagraph.
If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True."""
self.cudagraph_copy_inputs: bool = False
""" In static graph, this is an operation list that does not need to be captured by the CUDA graph.
CudaGraphBackend will split these operations from the static graph.
Example usage:
cudagraph_splitting_ops = ["paddle.unified_attention"]
Note: If you want to use subgraph capture functionality in a dynamic graph,
you can manually split the model into multiple layers and apply the @support_cuda_graph decorator
only to the layer where CUDA graph functionality is required.
"""
self.cudagraph_splitting_ops: Optional[list[str]] = None
"""Whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops."""
self.full_cuda_graph: bool = False
self.max_capture_size: int = field(default=None, init=False) # type: ignore
self.batch_size_to_captured_size: dict[int,
int] = field(default=None,
init=False) # type: ignore
# CINN Config ...
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
capture_size = [i for i in range(1, max_capture_batch_size + 1)]
self.init_with_cudagrpah_size(cudagraph_capture_sizes=capture_size)
self.use_cudagraph = use_cudagraph
#TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
if enable_static_graph_inference:
self.graph_opt_level = 1
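# Illustrative sketch (not part of this diff): the reworked constructor keeps the old
# keyword arguments and additionally accepts an args dict. Import path assumed; values
# are made up.
from fastdeploy.config import GraphOptimizationConfig  # module path assumed

graph_opt_config = GraphOptimizationConfig(
    enable_static_graph_inference=False,  # True would raise graph_opt_level to 1
    use_cudagraph=True,
    max_capture_batch_size=8,             # capture sizes become [1, 2, ..., 8]
    args={},                              # extra overrides applied via setattr
)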
@dataclass
class LoadConfig:
"""
Configuration for dynamic weight loading strategies
@@ -363,37 +367,39 @@ class LoadConfig:
- 'meta': provide RL training worker, no_weights_load
- None: No dynamic loading
"""
use_fastsafetensor: bool = False
dynamic_load_weight: bool = False
load_strategy: Optional[Literal['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta']] = None
def __init__(
self,
args,
):
self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
self.dynamic_load_weight: bool = False
self.load_strategy: Optional[Literal['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta']] = None
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
def __post_init__(self):
if self.load_strategy is not None and not self.dynamic_load_weight:
raise ValueError("Load strategy requires dynamic_load_weight=True")
if self.dynamic_load_weight and self.load_strategy is None:
raise ValueError("Must specify load_strategy when dynamic_load_weight is True")
@dataclass
class LoRAConfig:
""" LoRA Config """
pass
@dataclass
class KVCacheConfig:
""" KV Cache Config """
cache_quant_dtype: str = "none"
@dataclass
class DecodingConfig:
"""
Configuration for decoding
"""
pad_token_id = None
def __init__(
self,
args,
):
self.pad_token_id = None
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
@dataclass
class FDConfig:
@@ -411,7 +417,6 @@ class FDConfig:
load_config: LoadConfig = field(default=None, init=True)
quant_config: Optional[QuantConfigBase] = None
graph_opt_config: Optional[GraphOptimizationConfig] = None
moe_config: MoEConfig = field(default=None, init=True) # type: ignore
decoding_config: DecodingConfig = field(default=None,
init=True) # type: ignore
kv_cache_config: KVCacheConfig = field(default=None,