Files
FastDeploy/fastdeploy/config.py
Jiang-Jia-Jun 05c670e593 [Sync] Update to latest code (#2679)
* [Sync] Update to latest code

* Add new code files

* Add new code files

* update code

* Try to fix build.sh

* Try to fix build.sh

* Update code

* Update requirements.txt

* Update code

---------

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
2025-07-03 15:43:53 +08:00

415 lines
15 KiB
Python

"""
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Literal
from paddleformers.transformers.configuration_utils import PretrainedConfig
from fastdeploy.model_executor.layers.quantization.quant_base import \
QuantConfigBase
from fastdeploy.utils import get_logger
logger = get_logger("config", "config.log")
class MoEPhase(Enum):
    """Enumerates the generation phases of the MoE."""

    PREFILL = 1
    DECODER = 2
class ModelConfig(PretrainedConfig):
    """
    Stores the configuration of an `LLM`.

    The class-level attributes below act as defaults; ``__init__`` copies its
    explicit arguments onto the instance and forwards every remaining keyword
    argument to ``PretrainedConfig``.
    """
    max_stop_seqs_num = 5
    stop_seqs_max_len = 8

    architectures: list[str] = []

    # NOTE(gongshaotain): from _load_model_init_val()
    top_p = 0.0
    temperature = 1.0
    rope_theta = 10000.0
    penalty_score = 1.0
    frequency_score = 0.0
    presence_score = 0.0
    min_length = 1

    def __init__(
        self,
        vocab_size: int = 100224,
        hidden_size: int = 4096,
        num_layers: int = 48,
        num_attention_heads: int = 32,
        num_key_value_heads: Optional[int] = None,
        hidden_act: str = "swiglu",
        hidden_dropout_prob: float = 0.0,
        max_position_embeddings: int = 512,
        max_seq_len: int = 512,
        initializer_range: float = 0.02,
        use_rope=True,
        use_fast_ffn: bool = False,
        rope_theta: int = 10000,
        rope_3d: bool = False,
        ori_vocab_size: int | None = None,
        moe_layer_start_index: int | None = None,
        moe_layer_end_index: int | None = None,
        num_hidden_layers: int | None = None,
        prefix_name="",
        freeze_embedding=False,
        rope_head_dim=None,
        ffn_hidden_size: Optional[int] = None,
        dtype="bfloat16",
        start_layer_index: int = 0,
        head_dim: Optional[int] = None,
        tie_word_embeddings: bool = False,
        is_quantized: bool = False,
        **kwargs,
    ):
        """
        Build the model configuration.

        Unless noted below, every explicit argument is stored unchanged on the
        instance under the same name; ``**kwargs`` is forwarded to
        ``PretrainedConfig.__init__``.

        Args:
            num_layers: Layer count, used unless ``num_hidden_layers`` is given.
            num_hidden_layers: When not None, overrides ``num_layers``.
            head_dim: When None, derived as
                ``hidden_size // num_attention_heads``.
            ori_vocab_size: Falls back to ``vocab_size`` when falsy.
            moe_layer_start_index: Stored when not None. When None and
                ``kwargs.get("moe_num_experts", 0) == 0``, it is set to
                ``num_layers`` and ``moe_num_experts`` is set to 0; in the
                remaining case this constructor does not assign it.
            moe_layer_end_index: Only assigned when not None.
        """
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        # An explicit `num_hidden_layers` wins over `num_layers`.
        self.num_layers = (num_layers
                           if num_hidden_layers is None else num_hidden_layers)
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        # Derive the per-head width when the caller did not supply one.
        self.head_dim = (self.hidden_size // self.num_attention_heads
                         if head_dim is None else head_dim)

        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.use_rope = use_rope
        self.use_fast_ffn = use_fast_ffn
        self.rope_theta = rope_theta
        self.ori_vocab_size = ori_vocab_size or vocab_size
        self.max_seq_len = max_seq_len
        self.prefix_name = prefix_name
        self.freeze_embedding = freeze_embedding
        self.rope_head_dim = rope_head_dim

        # MoE layer range.
        if moe_layer_start_index is None:
            if kwargs.get("moe_num_experts", 0) == 0:
                # No experts configured: the MoE range starts past the last layer.
                self.moe_layer_start_index = self.num_layers
                self.moe_num_experts = 0
        else:
            self.moe_layer_start_index = moe_layer_start_index
        if moe_layer_end_index is not None:
            self.moe_layer_end_index = moe_layer_end_index

        self.ffn_hidden_size = ffn_hidden_size
        self.rope_3d = rope_3d
        self.start_layer_index = start_layer_index
        self.dtype = dtype
        self.tie_word_embeddings = tie_word_embeddings
        self.is_quantized = is_quantized
@dataclass
class MoEConfig:
    """
    Configuration for MoE.

    Only the annotated names are dataclass fields (constructor arguments).
    The unannotated names further down are plain class-level attributes that
    instances read through normal attribute lookup.
    """
    # --- dataclass fields (order defines the positional constructor order) ---
    num_experts: int = -1
    top_k: int = 8
    moe_intermediate_size: int = -1
    num_experts_per_rank: int = -1
    num_experts_start_offset: int = -1
    moe_use_aux_free: bool = False

    # --- class-level attributes (not constructor arguments) ---
    moe_num_shared_experts = (0, )
    moe_layer_start_index = 0
    moe_layer_end_index = None
    num_max_dispatch_tokens_per_rank = 256
    # multimodality, TODO(liuyuanle): read from config.json
    im_patch_id = 100295
@dataclass
class ParallelConfig:
    """
    Configuration for the distributed execution.

    Annotated names are dataclass fields (constructor arguments); the
    unannotated names are plain class-level attributes that are expected to
    be overwritten on the instance after construction.
    """
    # NOTE: a former unannotated `block_size = 16` used to be declared here.
    # It was dead code: the annotated `block_size: int = 64` field below
    # rebinds the name, so the effective default has always been 64.
    sequence_parallel = False  # Whether to enable sequence parallelism.
    use_ep = False  # Whether to enable Expert Parallelism
    moe_phase = MoEPhase.PREFILL  # Generation phase
    msg_queue_id = 1  # message queue id
    tensor_parallel_rank = None  # TP rank ID
    tensor_parallel_degree = None  # TP degree
    expert_parallel_rank = None  # EP rank ID
    expert_parallel_degree = None  # EP degree
    # The embedding weight distributed on your gpu cards is divided by row or column.
    # Defaults to False means divide by row. When vocab_size can not be divided by world_size
    # but hidden_size can, we can consider split embedding weight by column.
    # NOTE(review): no attribute follows the comment above in this class; it
    # appears to be left over from a removed option - confirm and drop.

    # From old version worker args
    # TODO(gongshaotian): Reclassify
    model_name_or_path: str = "./output"
    max_num_seqs: int = 34
    # Set default block num for profile run
    max_block_num: int = 2000
    # block size
    block_size: int = 64
    # Engine worker queue port
    engine_worker_queue_port: int = 9923
    # Max model len
    max_model_len: int = 3072  # max_seq_len
    # cuda visible devices
    device_ids: str = "0"
    # Input dtype
    dtype: str = "bfloat16"
    # Encoder's decoder num
    enc_dec_block_num: int = 1
    # KV cache ratio for input
    kv_cache_ratio: float = 0.7
    # First token id
    first_token_id: int = 1
    # Gpu memory utilization
    gpu_memory_utilization: float = 0.9
    # Process ID of engine
    engine_pid: Optional[int] = None
    # Do profile or not
    do_profile: bool = False
    # Pad token id
    pad_token_id: int = -1
    # NOTE(review): purpose not visible from this file - confirm
    eos_tokens_lens: int = 2
    # Enable chunked prefill
    # NOTE(review): the default is the (truthy) string "store_true", which
    # looks like an argparse action copied in as a value. It is kept as-is so
    # existing behaviour does not change - confirm whether `bool = False` was
    # intended.
    enable_chunked_prefill: str = "store_true"
    # NOTE(review): presumably the per-step token budget - confirm
    max_num_batched_tokens: int = 2048
    # enable prefix cache
    enable_prefix_caching = None
    # splitwise role
    splitwise_role: str = "mixed"
    # guided decoding backend
    guided_decoding_backend: Optional[str] = None
    # disable any whitespace for guided decoding
    disable_any_whitespace: bool = True
@dataclass
class SpeculativeConfig:
    """
    Settings that control speculative decoding.
    """
    # Speculative method; one of [None, "ngram_match", "mtp"].
    method: Optional[str] = None
    # Maximum length of speculative tokens.
    num_speculative_tokens: int = 1
    # Maximum length of candidate tokens for the speculative method.
    max_candidate_len: int = 5
    # Maximum length of the verify window for the speculative method.
    verify_window: int = 2
    # Used by ngram match.
    max_ngram_size: int = 5
    # Model for mtp/eagle/draft_model.
    model_name_or_path: Optional[str] = None
    # Quantization of the model.
    quantization: Optional[str] = None
    # Allocate more blocks to prevent mtp from finishing the block earlier
    # than the main model. Fixed for now.
    num_gpu_block_expand_ratio: Optional[float] = 1
    # Distinguishes the main model from the draft model (mtp/eagle/draftmodel).
    # One of ["main", "mtp"].
    model_type: Optional[str] = "main"

    # TODO(liuzichang): To reduce memory usage, MTP shares the main model's
    # lm_head and embedding layers. A trick method is currently used to enable
    # this sharing; it will be replaced with a more standardized solution in
    # the future.
    # Unannotated, so this is a class-level attribute rather than a field.
    sharing_model = None
@dataclass
class DeviceConfig:
    """
    Configuration for device settings.
    """
    # Annotated so that it is a real dataclass field: it can be passed to the
    # constructor and takes part in the generated __repr__/__eq__. Without the
    # annotation it was only a class-level attribute.
    device_type: str = "cuda"
class GraphOptimizationConfig:
    """
    Graph optimization and CUDA Graph settings.

    ``graph_opt_level`` is the top-level graph optimization control and
    corresponds to different backends:
        - 0: dynamic graph
        - 1: static graph
        - 2: static graph + cinn compilation backend
    """
    graph_opt_level: int = 0

    # CUDA Graph Config
    # Whether to use cudagraph.
    # - False: cudagraph is not used.
    # - True: cudagraph is used.
    #     It requires that all input buffers have fixed addresses, and all
    #     splitting ops write their outputs to input buffers.
    #     - With dynamic graph backend: ...
    #     - With static graph backend: WIP
    use_cudagraph: bool = False

    # Sizes to capture cudagraph.
    # - None (default): capture sizes are inferred from llm config.
    # - list[int]: capture sizes are specified as given.
    cudagraph_capture_sizes: Optional[list[int]] = None

    # Number of warmup runs for cudagraph.
    cudagraph_num_of_warmups: int = 2

    # Whether to copy input tensors for cudagraph.
    # If the caller can guarantee that the same input buffers are always used,
    # it can set this to False. Otherwise, it should set this to True.
    cudagraph_copy_inputs: bool = False

    # In static graph, this is an operation list that does not need to be
    # captured by the CUDA graph. CudaGraphBackend will split these operations
    # from the static graph.
    # Example usage:
    #     cudagraph_splitting_ops = ["paddle.unified_attention"]
    # Note: to use subgraph capture functionality in a dynamic graph, manually
    # split the model into multiple layers and apply the @support_cuda_graph
    # decorator only to the layer where CUDA graph functionality is required.
    #
    # This used to be `cudagraph_splitting_ops = Optional[list[str]]`, which
    # stored the typing construct itself as the value; the type is now an
    # annotation and the default is None.
    cudagraph_splitting_ops: Optional[list[str]] = None

    # Whether to use a full cuda graph for the entire forward pass rather than
    # splitting certain operations such as attention into subgraphs.
    # Thus this flag cannot be used together with splitting_ops.
    full_cuda_graph: bool = False

    # This class is not a dataclass, so the two `field(...)` values below are
    # only class-level placeholders; the real values are assigned per instance
    # in `init_with_cudagrpah_size`.
    max_capture_size: int = field(default=None, init=False)  # type: ignore
    batch_size_to_captured_size: dict[int, int] = field(
        default=None, init=False)  # type: ignore

    # CINN Config ...

    def __init__(self,
                 enable_static_graph_inference: bool = False,
                 use_cudagraph: bool = False,
                 max_capture_batch_size: int = 64):
        """
        Args:
            enable_static_graph_inference: When true, ``graph_opt_level`` is
                set to 1 on the instance.
            use_cudagraph: Stored on the instance as ``use_cudagraph``.
            max_capture_batch_size: Capture sizes are initialised to every
                integer from 1 to this value, inclusive.
        """
        capture_size = list(range(1, max_capture_batch_size + 1))
        self.init_with_cudagrpah_size(cudagraph_capture_sizes=capture_size)
        self.use_cudagraph = use_cudagraph
        # TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
        if enable_static_graph_inference:
            self.graph_opt_level = 1

    def init_with_cudagrpah_size(self,
                                 cudagraph_capture_sizes: list[int]) -> None:
        """
        Complete the initialization of the config once the cudagraph sizes
        are known.

        If ``self.cudagraph_capture_sizes`` is still None the given sizes are
        adopted; otherwise the sizes already on the config are kept (after
        de-duplication) and the argument is only used for logging.

        Afterwards the method sets:
            - ``cudagraph_capture_sizes``: a new list sorted in descending
              order (the caller's list is not modified),
            - ``max_capture_size``: the largest capture size, or 0 when there
              are none,
            - ``batch_size_to_captured_size``: mapping from batch size to the
              padded graph size.

        Args:
            cudagraph_capture_sizes: Candidate batch sizes to capture.
        """
        if self.cudagraph_capture_sizes is None:
            capture_sizes = cudagraph_capture_sizes
        else:
            # De-duplicate the sizes already provided by the config.
            capture_sizes = list(set(self.cudagraph_capture_sizes))
            if len(capture_sizes) < len(self.cudagraph_capture_sizes):
                logger.info(("cudagraph sizes specified by model runner"
                             " %s is overridden by config %s"),
                            cudagraph_capture_sizes, capture_sizes)

        # sorted() builds a fresh list, so a list passed in by the caller is
        # not reordered in place.
        self.cudagraph_capture_sizes = sorted(capture_sizes, reverse=True)
        self.max_capture_size = (self.cudagraph_capture_sizes[0]
                                 if self.cudagraph_capture_sizes else 0)

        # Pre-compute the mapping from batch size to padded graph size. For
        # each pair of neighbouring capture sizes (start < end), `start` maps
        # to itself and every batch size strictly between them maps to `end`.
        self.batch_size_to_captured_size = {}
        for end, start in zip(self.cudagraph_capture_sizes,
                              self.cudagraph_capture_sizes[1:] + [0]):
            for bs in range(start, end):
                self.batch_size_to_captured_size[bs] = (start if bs == start
                                                        else end)
        self.batch_size_to_captured_size[
            self.max_capture_size] = self.max_capture_size
@dataclass
class LoadConfig:
    """
    Settings for dynamic weight loading.

    Attributes:
        use_fastsafetensor: Boolean flag, False by default.
        dynamic_load_weight: Whether to enable dynamic weight loading.
        load_strategy: Specifies the weight loading method when enabled:
            - 'ipc': Real-time IPC streaming with automatic resharding
            - 'ipc_no_reshard': Real-time IPC streaming without weight process
            - 'ipc_snapshot': Load from disk snapshot of IPC weights
            - 'meta': provide RL training worker, no_weights_load
            - None: No dynamic loading

    Raises:
        ValueError: If exactly one of ``dynamic_load_weight`` and
            ``load_strategy`` is set.
    """
    use_fastsafetensor: bool = False
    dynamic_load_weight: bool = False
    load_strategy: Optional[Literal['ipc', 'ipc_no_reshard', 'ipc_snapshot',
                                    'meta']] = None

    def __post_init__(self):
        """Check that the flag and the strategy are set together."""
        has_strategy = self.load_strategy is not None
        if has_strategy and not self.dynamic_load_weight:
            raise ValueError("Load strategy requires dynamic_load_weight=True")
        if self.dynamic_load_weight and not has_strategy:
            raise ValueError(
                "Must specify load_strategy when dynamic_load_weight is True")
@dataclass
class LoRAConfig:
    """LoRA Config (currently declares no fields)."""
@dataclass
class KVCacheConfig:
    """
    KV Cache Config
    """
    # The string "none" (not the None object) is the default value.
    cache_quant_dtype: str = "none"
@dataclass
class DecodingConfig:
    """
    Configuration for decoding
    """
    # Annotated so that it is a real dataclass field: it can be passed to the
    # constructor and takes part in the generated __repr__/__eq__. Without the
    # annotation it was only a class-level attribute.
    pad_token_id: Optional[int] = None
@dataclass
class FDConfig:
    """
    Aggregates every fastdeploy-related configuration object in one place, so
    that a single value can be passed around the codebase instead of the
    distinct configurations.

    Every field is an optional constructor argument that defaults to None.
    """
    model_config: ModelConfig = None  # type: ignore
    parallel_config: ParallelConfig = None  # type: ignore
    speculative_config: SpeculativeConfig = None  # type: ignore
    device_config: DeviceConfig = None  # type: ignore
    load_config: LoadConfig = None  # type: ignore
    quant_config: Optional[QuantConfigBase] = None
    graph_opt_config: Optional[GraphOptimizationConfig] = None
    moe_config: MoEConfig = None  # type: ignore
    decoding_config: DecodingConfig = None  # type: ignore
    kv_cache_config: KVCacheConfig = None  # type: ignore