Sync v2.0 version of code to GitHub repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions


@@ -16,13 +16,11 @@
from __future__ import annotations
import json
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional
import paddle
from paddlenlp.transformers.configuration_utils import PretrainedConfig
from paddleformers.transformers.configuration_utils import PretrainedConfig
from fastdeploy.model_executor.layers.quantization.quant_base import \
QuantConfigBase
@@ -30,15 +28,10 @@ from fastdeploy.utils import get_logger
logger = get_logger("config", "config.log")
__all__ = [
"ModelConfig",
]
class GenerationPhase(Enum):
class MoEPhase(Enum):
"""
The generation phase of the model.
The generation phase of the moe.
"""
PREFILL = 1
@@ -49,14 +42,25 @@ class ModelConfig(PretrainedConfig):
"""
The configuration class to store the configuration of an `LLM`.
"""
max_stop_seqs_num = 5
stop_seqs_max_len = 8
model_type = ""
architectures: list[str] = []
# NOTE(gongshaotain): from _load_model_init_val()
top_p = 0.0
temperature = 1.0
rope_theta = 10000.0
rope_scaling = None
penalty_score = 1.0
frequency_score = 0.0
presence_score = 0.0
min_length = 1
def __init__(
self,
vocab_size: int = 100224,
hidden_size: int = 4096,
intermediate_size: Optional[int] = None,
num_layers: int = 48,
num_attention_heads: int = 32,
num_key_value_heads: Optional[int] = None,
@@ -65,90 +69,63 @@ class ModelConfig(PretrainedConfig):
max_position_embeddings: int = 512,
max_seq_len: int = 512,
initializer_range: float = 0.02,
type_vocab_size: int = 4,
use_rope=True,
use_rmsnorm=False,
weight_sharing=True,
weight_sharing_add_bias=False,
sequence_parallel=False,
use_flash_attention=False,
use_fast_ffn: bool = False,
tensor_parallel_output: bool = True,
fused_linear=False,
compression_ratio: float = 1.0,
rope_theta: int = 10000,
rope_3d: bool = False,
ori_vocab_size: int | None = None,
smooth: bool = False,
group_size: int = -1,
tools_version="4.10.0.dev",
system_prompt_version="V1",
moe_layer_start_index: int | None = None,
moe_use_gate_correction_bias: bool | None = None,
moe_layer_end_index: int | None = None,
num_hidden_layers: int | None = None,
prefix_name="",
freeze_embedding=False,
rope_head_dim=None,
base_model_prefix=None,
use_moe=False,
ffn_hidden_size: Optional[int] = None,
dtype=None,
export_model_type: str = "default",
use_stop_seqs: bool = False,
return_all_hidden_states: bool = False,
dtype="bfloat16",
start_layer_index: int = 0,
output_via_mq: bool = True,
generation_phase: GenerationPhase = GenerationPhase.PREFILL,
head_dim: Optional[int] = None,
tie_word_embeddings: bool = False,
is_quantized: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_layers = num_layers
if num_hidden_layers is not None:
self.num_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = hidden_size // num_attention_heads
if head_dim is None:
self.head_dim = self.hidden_size // self.num_attention_heads
else:
self.head_dim = head_dim
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.type_vocab_size = type_vocab_size
self.use_rope = use_rope
self.use_rmsnorm = use_rmsnorm
self.weight_sharing = weight_sharing
self.weight_sharing_add_bias = weight_sharing_add_bias
self.use_flash_attention = use_flash_attention
self.use_fast_ffn = use_fast_ffn
self.tensor_parallel_output = tensor_parallel_output
self.skip_recompute_ops = dict()
self.fused_linear = fused_linear
self.compression_ratio = compression_ratio
self.rope_theta = rope_theta
self.ori_vocab_size = ori_vocab_size or vocab_size
self.smooth = smooth
self.group_size = group_size
self.max_seq_len = max_seq_len
self.tools_version = tools_version
self.system_prompt_version = system_prompt_version
self.prefix_name = prefix_name
self.freeze_embedding = freeze_embedding
self.rope_head_dim = rope_head_dim
self.use_moe = use_moe
self.base_model_prefix = base_model_prefix
moe_num_experts = kwargs.get("moe_num_experts", 0)
if moe_layer_start_index is not None:
self.moe_layer_start_index = moe_layer_start_index
elif moe_use_gate_correction_bias is not None:
self.moe_use_gate_correction_bias = moe_use_gate_correction_bias
elif moe_num_experts == 0:
self.moe_layer_start_index = self.num_layers
self.moe_num_experts = 0
if moe_layer_end_index is not None:
self.moe_layer_end_index = moe_layer_end_index
self.ffn_hidden_size = ffn_hidden_size
self.rope_3d = rope_3d
self.export_model_type = export_model_type
self.use_stop_seqs = use_stop_seqs
self.return_all_hidden_states = return_all_hidden_states
self.start_layer_index = start_layer_index
self.output_via_mq = output_via_mq
self.dtype = dtype
self.tie_word_embeddings = tie_word_embeddings
self.is_quantized = is_quantized
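A minimal usage sketch of the head_dim handling above (illustrative values, assuming the elided constructor parameters keep their defaults and paddleformers is installed):

# Illustrative only: when head_dim is omitted it is derived from hidden_size // num_attention_heads.
cfg = ModelConfig(vocab_size=100224, hidden_size=4096, num_attention_heads=32)
assert cfg.head_dim == 4096 // 32  # 128
# Passing head_dim explicitly overrides the derived value.
cfg = ModelConfig(hidden_size=4096, num_attention_heads=32, head_dim=64)
assert cfg.head_dim == 64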
@dataclass
@@ -156,29 +133,19 @@ class MoEConfig:
"""
Configuration for MoE.
"""
use_moe: bool = False
num_experts: int = -1
top_k = 8
top_k: int = 8
moe_intermediate_size: int = -1
num_experts_per_rank: int = -1
num_experts_start_offset: int = -1
activation = "swiglu"
moe_use_gate_correction_bias = False
moe_every2 = (False, )
moe_num_shared_experts = (0, )
moe_layer_start_index = 0
moe_use_ffn_shared_weight_and_bias = (False, )
moe_group = (False, )
moe_quant_type = "default"
moe_layer_end_index = None
num_max_dispatch_tokens_per_rank = 256
has_multimodality: bool = False
im_patch_id = (
100295 # multimodality, TODO(liuyuanle): read from config.json
)
moe_tag = ""
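The per-rank fields above (num_experts_per_rank, num_experts_start_offset) suggest a contiguous split of experts across expert-parallel ranks; a hedged sketch of that arithmetic, not code from the repository:

# Hypothetical helper; assumes experts are split contiguously and evenly across EP ranks.
def split_experts(num_experts, ep_size, ep_rank):
    assert num_experts % ep_size == 0, "assumes an even split"
    per_rank = num_experts // ep_size
    return per_rank, ep_rank * per_rank  # (num_experts_per_rank, num_experts_start_offset)

print(split_experts(64, 8, 3))  # (8, 24): rank 3 owns experts 24..31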
@dataclass
@@ -187,27 +154,98 @@ class ParallelConfig:
block_size = 16 # The block size for processing.
sequence_parallel = False # Whether to enable sequence parallelism.
use_ep = False # Whether to enable Expert Parallelism
moe_group = False # Whether to enable moe group
msg_queue_id = None # message queue id
use_micro_batch = False # Whether to enable micro batch
tensor_parallel_rank = None # TP rank ID
moe_phase = MoEPhase.PREFILL # Generation phase
msg_queue_id = 1 # message queue id
tensor_parallel_rank = None # TP rank ID
tensor_parallel_degree = None # TP degree
mp_size = 1 # mp size
ep_size = 1 # ep size
column_cut = False # (bool, optional): The embedding weight distributed on your gpu cards is divided by row or column. Defaults to False means divide by row. When vocab_size can not be divided by world_size but hidden_size can, we can consider split embedding weight by column.
lm_head_column_cut = False
expert_parallel_rank = None # EP rank ID
expert_parallel_degree = None # EP degree
# Whether the embedding weight distributed across your GPU cards is split by row or by column.
# Defaults to False, which means splitting by row. When vocab_size cannot be divided evenly by
# world_size but hidden_size can, consider splitting the embedding weight by column.
column_cut = False # (bool, optional)
"""
From old version worker args
TODO(gongshaotian): Reclassify
"""
model_name_or_path: str = "./output"
max_num_seqs: int = 34
# Set default block num for profile run
max_block_num: int = 2000
# block size
block_size: int = 64
# Engine worker queue port
engine_worker_queue_port: int = 9923
# Max model len
max_model_len: int = 3072 # max_seq_len
# cuda visible devices
device_ids: str = "0"
# Input dtype
dtype: str = "bfloat16"
# Number of encoder/decoder blocks
enc_dec_block_num: int = 1
# KV cache ratio for input
kv_cache_ratio: float = 0.7
# First token id
first_token_id: int = 1
# Gpu memory utilization
gpu_memory_utilization: float = 0.9
# Process ID of engine
engine_pid: Optional[int] = None
# Do profile or not
do_profile: bool = False
# Dynamic load weight or not
dynamic_load_weight: bool = False
# Pad token id
pad_token_id: int = -1
# Number of EOS token ids
eos_tokens_lens: int = 2
# Enable chunked prefill
enable_chunked_prefill: str = "store_true"
"""
- APPEND_ATTN:
"""
attention_backend: str = "APPEND_ATTN"
max_num_batched_tokens: int = 2048
# enable prefix cache
enable_prefix_caching = None
# splitwise role
splitwise_role: str = "mixed"
# guided decoding backend
guided_decoding_backend: str = None
# disable any whitespace for guided decoding
disable_any_whitespace: bool = True
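A short illustration of the column_cut comment above, with made-up shapes (not taken from the repository): split the embedding by column when vocab_size does not divide evenly across the tensor-parallel ranks but hidden_size does.

# Illustrative check only; the values are examples.
vocab_size, hidden_size, world_size = 50257, 4096, 8
column_cut = (vocab_size % world_size != 0) and (hidden_size % world_size == 0)
print(column_cut)  # True: 50257 % 8 != 0, but 4096 % 8 == 0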
@dataclass
class SpeculativeConfig:
"""
Configuration for speculative decoding.
"""
speculate_method = None # speculate method
speculate_max_draft_token_num = 1 # the max length of draft tokens for speculate method
draft_type = "None" # draft type
is_mtp = False # is mtp
speculate_max_candidate_len = 5 # the max length of candidate tokens for speculate method
speculate_verify_window = 2 # the max length of verify window for speculate method
# speculative method, choose in [None, "ngram_match", "mtp"]
method: Optional[str] = None
# the max length of speculative tokens
num_speculative_tokens: int = 1
# the max length of candidate tokens for speculative method
max_candidate_len: int = 5
# the max length of verify window for speculative method
verify_window: int = 2
# ngram match
max_ngram_size: int = 5
# model for mtp/eagle/draft_model
model_name_or_path: Optional[str] = None
# quantization of model
quantization: Optional[str] = None
# allocate extra blocks so mtp does not run out of blocks earlier than the main model
# Fixed now
num_gpu_block_expand_ratio: Optional[float] = 1
# To distinguish the main model from the draft model (mtp/eagle/draft model)
# ["main", "mtp"]
model_type: Optional[str] = "main"
# TODO(liuzichang): To reduce memory usage, MTP shares the main model's lm_head and embedding layers.
# A trick method is currently used to enable this sharing.
# This will be replaced with a more standardized solution in the future.
sharing_model = None
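A minimal sketch of filling in the new SpeculativeConfig fields for the two documented methods; the values are illustrative, and the draft-model path is hypothetical.

# N-gram matching: draft tokens come from matching against previously seen text.
ngram_cfg = SpeculativeConfig(method="ngram_match", num_speculative_tokens=3, max_ngram_size=5)
# MTP: a separate draft model proposes tokens; model_type marks it as the draft side.
mtp_cfg = SpeculativeConfig(method="mtp",
                            num_speculative_tokens=1,
                            model_name_or_path="./mtp_draft_model",  # hypothetical path
                            model_type="mtp")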
@dataclass
@@ -215,58 +253,7 @@ class DeviceConfig:
"""
Configuration for device settings.
"""
@dataclass
class AdditionalConfig:
"""
Configuration for testing, debugging or others
"""
use_fake_parameter = False # use fake parameter for test
ep_just_for_test = True # whether to use ep just for test
fake_server_p = False # whether to use fake server
class WeightKeys:
"""
The parameter keys stored in your model_state.pdparams.
"""
def __init__(self, num_layers):
"""
Initialize the keys used to retrieve weights from model_state.pdparams.
Args:
num_layers (int): Number of layers in the Transformer model.
Returns:
None
"""
self.norm_before_qkv_weight_keys = [None for i in range(num_layers)]
self.norm_before_qkv_bias_keys = [None for i in range(num_layers)]
self.qkv_linear_weight_keys = [None for i in range(num_layers)]
self.qkv_linear_bias_keys = [None for i in range(num_layers)]
self.out_linear_weight_keys = [None for i in range(num_layers)]
self.out_linear_bias_keys = [None for i in range(num_layers)]
self.ffn_layernorm_weight_keys = [None for i in range(num_layers)]
self.ffn_layernorm_bias_keys = [None for i in range(num_layers)]
self.ffn1_weight_keys = [None for i in range(num_layers)]
self.ffn1_bias_keys = [None for i in range(num_layers)]
self.ffn2_weight_keys = [None for i in range(num_layers)]
self.ffn2_bias_keys = [None for i in range(num_layers)]
self.moe_gate_weight_keys = None
self.moe_gate_correction_bias_keys = None
self.moe_ffn1_weight_keys = None
self.moe_ffn2_weight_keys = None
self.moe_ffn1_bias_keys = None
self.moe_ffn2_bias_keys = None
self.moe_ffn1_weight_scale_key = None
self.moe_ffn2_weight_scale_key = None
self.moe_ffn1_expert_in_scale_key = None
self.moe_ffn2_expert_in_scale_key = None
device_type = "cuda"
class GraphOptimizationConfig:
@@ -279,7 +266,7 @@ class GraphOptimizationConfig:
# CUDA Graph Config
""" Whether to use cudagraph.
- Fasle: cudagraph is not used.
- False: cudagraph is not used.
- True: cudagraph is used.
It requires that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
@@ -340,9 +327,7 @@ class GraphOptimizationConfig:
0] if self.cudagraph_capture_sizes else 0
# pre-compute the mapping from batch size to padded graph size
self.batch_size_to_captured_size = [
0 for i in range(self.max_capture_size + 1)
]
self.batch_size_to_captured_size = {}
for end, start in zip(self.cudagraph_capture_sizes,
self.cudagraph_capture_sizes[1:] + [0]):
for bs in range(start, end):
@@ -353,91 +338,25 @@ class GraphOptimizationConfig:
self.batch_size_to_captured_size[
self.max_capture_size] = self.max_capture_size
def __init__(self,
enable_static_graph_inference: bool = False,
use_cudagraph: bool = False,
max_capture_batch_size: int = 64):
""" """
capture_size = [i for i in range(1, max_capture_batch_size + 1)]
self.init_with_cudagrpah_size(cudagraph_capture_sizes=capture_size)
self.use_cudagraph = use_cudagraph
# TODO(wangmingkai02): change graph_opt_level=2 when using static mode with CINN
if enable_static_graph_inference:
self.graph_opt_level = 1
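A hedged usage sketch exercising only the constructor arguments visible above:

# Capture CUDA graphs for batch sizes 1..32 and enable cudagraph replay.
graph_opt_config = GraphOptimizationConfig(use_cudagraph=True, max_capture_batch_size=32)
# Static-graph inference sets graph_opt_level to 1, per the __init__ above.
static_cfg = GraphOptimizationConfig(enable_static_graph_inference=True)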
@dataclass
class LoadConfig:
"""
Configuration for loading parameters
"""
model_path: str = None # The path to the model file.
weight_keys: Optional[
WeightKeys] = None # Keys stored in your model, which is used to retrieve weights from the state dict.
scale_dir: str = None # The directory where the scale file is located.
act_scales = None
bias_keys = None
def _post_init(self, model_config):
if self.weight_keys:
self.weight_keys_mapping = self._create_weight_key_by_layer_name(
model_config)
else:
self.weight_keys_mapping = {}
self.quant_scale_mapping = self._create_quant_scale_mapping(
model_config)
def _create_weight_key_by_layer_name(self, model_config) -> dict:
mapping = {}
weight_keys = self.weight_keys
num_layers = model_config.num_layers
for i in range(num_layers):
if i == 0:
layer_name = f"{model_config.base_model_prefix}.decoder.layers.0.norm1"
mapping[layer_name] = weight_keys.norm_before_qkv_weight_keys[
0]
if i < num_layers:
layer_name = f"{model_config.base_model_prefix}.decoder.layers.{i}.norm2"
mapping[layer_name] = weight_keys.ffn_layernorm_weight_keys[i]
for i in range(num_layers - 1):
layer_name = f"{model_config.base_model_prefix}.decoder.layers.{i+1}.norm1"
mapping[layer_name] = weight_keys.norm_before_qkv_weight_keys[i +
1]
layer_name = f"{model_config.base_model_prefix}.decoder.norm"
if not model_config.use_moe:
mapping[
layer_name] = f"{model_config.base_model_prefix}.decoder.norm.weight"
else:
mapping[layer_name] = "ernie.norm.weight"
layer_name = f"{model_config.base_model_prefix}.e_norm"
mapping[layer_name] = f"{model_config.base_model_prefix}.e_norm.weight"
layer_name = f"{model_config.base_model_prefix}.h_norm"
mapping[layer_name] = f"{model_config.base_model_prefix}.h_norm.weight"
return mapping
def _create_quant_scale_mapping(self, model_config) -> dict:
mapping = {}
act_scales = self.act_scales
num_layers = model_config.num_layers
for i in range(num_layers):
if i == 0:
layer_name = f"{model_config.base_model_prefix}.decoder.layers.0.norm1"
mapping[layer_name] = act_scales.get(
f"{model_config.base_model_prefix}.decoder.layers.0.self_attn.qkv_proj.activation_quanter",
-1)
if i < num_layers:
layer_name = f"{model_config.base_model_prefix}.decoder.layers.{i}.norm2"
mapping[layer_name] = act_scales.get(
f"{model_config.base_model_prefix}.decoder.layers.{i}.linear1.activation_quanter",
-1)
for i in range(num_layers - 1):
layer_name = f"{model_config.base_model_prefix}.decoder.layers.{i+1}.norm1"
mapping[layer_name] = act_scales.get(
f"{model_config.base_model_prefix}.decoder.layers.{i + 1}.self_attn.qkv_proj.activation_quanter",
-1)
return mapping
def get_weight_key_by_layer_name(self, layer_name: str) -> Optional[str]:
return self.weight_keys_mapping.get(layer_name)
def get_quant_scale_by_layer_name(self, layer_name: str) -> Optional[int]:
return self.quant_scale_mapping.get(layer_name)
pass
@dataclass
@@ -446,52 +365,26 @@ class LoRAConfig:
pass
@dataclass
class SchedulerConfig:
""" Scheduler Config """
pass
@dataclass
class KVCacheConfig:
""" KV Cache Config """
block_size: int = 0
enc_dec_block_num: int = 2
kv_cache_ratio: float = 0.75
dtype: str = 'bfloat16'
kvcache_quant_config: Optional[QuantConfigBase] = None
cache_quant_dtype: str = "none"
class TmpConfig:
"""
TODO(yuanrisheng): TmpConfig will be moved to another config class when the refactor work is relatively complete.
"""
cache_quant_dtype: str = "default"
has_zero_point: bool = False
is_channel_wise: bool = False
weight_block_size: int = 16
use_offline_quant: bool = False
@dataclass
class DecodingConfig:
"""
Configuration for decoding
"""
max_dec_len = 20
min_dec_len = 0
decode_strategy = "sampling"
bos_token_id = None
pad_token_id = None
num_return_sequences: int = 1
@dataclass
class LLMConfig:
class FDConfig:
"""
The configuration class which contains all fastdeploy-related configuration. This
simplifies passing around the distinct configurations in the codebase.
"""
model_config: ModelConfig = field(default=None, init=True) # type: ignore
parallel_config: ParallelConfig = field(default=None, init=True)
@@ -499,14 +392,11 @@ class LLMConfig:
init=True) # type: ignore
device_config: DeviceConfig = field(default=None,
init=True) # type: ignore
additional_config: AdditionalConfig = field(default=None,
init=True) # type: ignore
load_config: LoadConfig = field(default=None, init=True) # type: ignore
quant_config: Optional[QuantConfigBase] = None
graph_opt_config: Optional[GraphOptimizationConfig] = None
tmp_config: TmpConfig = field(default=None, init=True)
moe_config: MoEConfig = field(default=None, init=True) # type: ignore
decoding_config: DecodingConfig = field(default=None,
init=True) # type: ignore
kvcache_config: KVCacheConfig = field(default=None,
init=True) # type: ignore
kv_cache_config: KVCacheConfig = field(default=None,
init=True) # type: ignore
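A hedged sketch of assembling the renamed FDConfig from the sub-configs in this file; only fields visible in this diff are set, and every value is illustrative.

# Illustrative wiring only; unspecified fields keep their default of None.
fd_config = FDConfig(
    model_config=ModelConfig(vocab_size=100224, hidden_size=4096, num_layers=48),
    parallel_config=ParallelConfig(),
    device_config=DeviceConfig(),
    load_config=LoadConfig(),
    moe_config=MoEConfig(),
    decoding_config=DecodingConfig(),
    kv_cache_config=KVCacheConfig(block_size=64),
    graph_opt_config=GraphOptimizationConfig(use_cudagraph=False),
)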