""" # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ from __future__ import annotations import os from dataclasses import dataclass, field from typing import Literal, Optional from paddleformers.transformers.configuration_utils import PretrainedConfig from fastdeploy import envs from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase from fastdeploy.utils import get_logger logger = get_logger("config", "config.log") class MoEPhase: """ The generation phase of the moe. """ def __init__(self, phase="prefill"): self._phase = phase @property def phase(self): return self._phase @phase.setter def phase(self, value): if value not in ["prefill", "decode"]: raise ValueError(f"The moe_phase is invalid, only support prefill and decode, but got {value}") else: self._phase = value class ErnieArchitectures: """Helper class for ERNIE architecture check.""" ARCHITECTURES = { "Ernie4_5_ForCausalLM", "Ernie4_5_MoeForCausalLM", "Ernie4_5_VLMoeForConditionalGeneration", } @classmethod def contains_ernie_arch(cls, architectures): """Check if any ERNIE architecture is present in the given architectures.""" return any(arch in architectures for arch in cls.ARCHITECTURES) @classmethod def is_ernie_arch(cls, architecture): """Check if the given architecture is an ERNIE architecture.""" return architecture in cls.ARCHITECTURES PRETRAINED_INIT_CONFIGURATION = { "rope_theta": 10000.0, "num_key_value_heads": -1, "start_layer_index": 0, "moe_num_shared_experts": 0, "moe_layer_start_index": 0, "num_max_dispatch_tokens_per_rank": 256, "moe_use_aux_free": False, "vocab_size": -1, "hidden_dropout_prob": 0.0, "initializer_range": 0.02, "max_position_embeddings": 512, "quantization_config": None, "tie_word_embeddings": False, "rms_norm_eps": 1e-5, "moe_num_experts": None, "moe_layer_end_index": None, } class ModelConfig: """ The configuration class to store the configuration of a `LLM`. 
""" def __init__( self, args, ): self.max_stop_seqs_num = 5 self.stop_seqs_max_len = 8 # NOTE(gongshaotain): form _load_model_init_val() self.top_p = 1.0 self.temperature = 1.0 self.rope_theta = 10000.0 self.penalty_score = 1.0 self.frequency_score = 0.0 self.presence_score = 0.0 self.min_length = 1 self.model_name_or_path = "" self.is_quantized = False self.max_model_len = 0 self.dtype = "" self.enable_logprob = False self.enable_mm = False self.enable_redundant_experts = False self.redundant_experts_num = 0 for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) assert self.model_name_or_path != "" pretrained_config, _ = PretrainedConfig.get_config_dict(self.model_name_or_path) self.pretrained_config = PretrainedConfig.from_dict(pretrained_config) # set attribute from pretrained_config for key, value in pretrained_config.items(): setattr(self, key, value) # we need set default value when not exist for key, value in PRETRAINED_INIT_CONFIGURATION.items(): if not hasattr(self, key): setattr(self, key, value) if not hasattr(self, "head_dim"): self.head_dim = self.hidden_size // self.num_attention_heads if hasattr(self, "vision_config"): self.vision_config = PretrainedConfig.from_dict(self.vision_config) self.ori_vocab_size = self.vocab_size if ErnieArchitectures.contains_ernie_arch(self.architectures): self.ori_vocab_size = args.get("ori_vocab_size", self.ori_vocab_size) class ParallelConfig: """Configuration for the distributed execution.""" def __init__( self, args, ): self.sequence_parallel = False # Whether to enable sequence parallelism. self.use_ep = False # Whether to enable Expert Parallelism self.moe_phase = MoEPhase("prefill") # Generation phase self.msg_queue_id = 1 # mesage queue id self.tensor_parallel_rank = 0 # TP rank ID self.tensor_parallel_size = 1 # TP degree self.expert_parallel_rank = 0 # EP rank ID self.expert_parallel_size = 1 # EP degree # The embedding weight distributed on your gpu cards is divided by row or column. # Defaults to False means divide by row. When vocab_size can not be divided by world_size # but hidden_size can, we can consider split embedding weight by column. """ From old wersion worker args TODO(gongshaotian): Reclassify """ self.model_name_or_path: str = "./output" self.max_num_seqs: int = 34 # Set default block num for profile run self.total_block_num: int = 2000 # block size self.block_size: int = 64 # Engine worker queue port self.engine_worker_queue_port: int = 9923 # Max model len self.max_model_len: int = 3072 # max_seq_len # cuda visible devices self.device_ids: str = "0" # Input dtype self.dtype: str = "bfloat16" # Encoder's decoder num self.enc_dec_block_num: int = 1 # First token id self.first_token_id: int = 1 # Process ID of engine self.engine_pid: Optional[int] = None # Do profile or not self.do_profile: bool = False # self.pad_token_id: int = -1 # self.eos_tokens_lens: int = 2 self.max_num_batched_tokens: int = 2048 # splitwise role self.splitwise_role: str = "mixed" # guided decoding backend self.guided_decoding_backend: str = None # disable any whitespace for guided decoding self.disable_any_whitespace: bool = True self.pod_ip: str = None # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce). 
class ParallelConfig:
    """Configuration for the distributed execution."""

    def __init__(
        self,
        args,
    ):
        self.sequence_parallel = False  # Whether to enable sequence parallelism.
        self.use_ep = False  # Whether to enable Expert Parallelism
        self.moe_phase = MoEPhase("prefill")  # Generation phase
        self.msg_queue_id = 1  # message queue id

        self.tensor_parallel_rank = 0  # TP rank ID
        self.tensor_parallel_size = 1  # TP degree
        self.expert_parallel_rank = 0  # EP rank ID
        self.expert_parallel_size = 1  # EP degree

        # The embedding weight distributed across GPU cards is split by row or column.
        # The default (False) means split by row. When vocab_size cannot be divided by
        # world_size but hidden_size can, consider splitting the embedding weight by column.
        """
        From old version worker args
        TODO(gongshaotian): Reclassify
        """
        self.model_name_or_path: str = "./output"
        self.max_num_seqs: int = 34
        # Set default block num for profile run
        self.total_block_num: int = 2000
        # block size
        self.block_size: int = 64
        # Engine worker queue port
        self.engine_worker_queue_port: int = 9923
        # Max model len
        self.max_model_len: int = 3072  # max_seq_len
        # CUDA visible devices
        self.device_ids: str = "0"
        # Input dtype
        self.dtype: str = "bfloat16"
        # Number of encoder-decoder blocks
        self.enc_dec_block_num: int = 1
        # First token id
        self.first_token_id: int = 1
        # Process ID of engine
        self.engine_pid: Optional[int] = None
        # Do profile or not
        self.do_profile: bool = False
        # self.pad_token_id: int = -1
        # self.eos_tokens_lens: int = 2

        self.max_num_batched_tokens: int = 2048
        # splitwise role
        self.splitwise_role: str = "mixed"
        # guided decoding backend
        self.guided_decoding_backend: Optional[str] = None
        # disable any whitespace for guided decoding
        self.disable_any_whitespace: bool = True
        self.pod_ip: Optional[str] = None
        # Enable the custom all-reduce kernel; fall back to NCCL (dist.all_reduce) when disabled.
        self.enable_custom_all_reduce: bool = False

        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)

        self.use_ep = args["expert_parallel_size"] > 1
        if self.splitwise_role == "mixed":
            self.moe_phase = MoEPhase(phase="prefill")
        elif self.splitwise_role == "prefill":
            self.moe_phase = MoEPhase(phase="prefill")
        elif self.splitwise_role == "decode":
            self.moe_phase = MoEPhase(phase="decode")
        else:
            raise NotImplementedError

        # pd_disaggregation
        use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
        use_pd_disaggregation_per_chunk: int = int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
        if use_pd_disaggregation_per_chunk:
            self.pd_disaggregation_mode = "per_chunk"
        elif use_pd_disaggregation:
            self.pd_disaggregation_mode = "per_query"
        else:
            self.pd_disaggregation_mode = "None"


class SpeculativeConfig:
    """
    Configuration for speculative decoding.
    """

    def __init__(
        self,
        args,
    ):
        # Speculative method; one of [None, "ngram_match", "mtp"]
        self.method: Optional[str] = None
        # The max length of speculative tokens
        self.num_speculative_tokens: int = 1
        # The max length of candidate tokens for the speculative method
        self.max_candidate_len: int = 5
        # The max length of the verify window for the speculative method
        self.verify_window: int = 2
        # ngram match
        self.max_ngram_size: int = 5
        # Model for mtp/eagle/draft_model
        self.model_name_or_path: Optional[str] = None
        # Quantization of the model
        self.quantization: Optional[str] = None
        # Allocate more blocks to prevent MTP from exhausting its blocks earlier than the main model.
        # Fixed for now.
        self.num_gpu_block_expand_ratio: Optional[float] = 1
        # Distinguishes the main model from the draft model (mtp/eagle/draft model).
        # One of ["main", "mtp"]
        self.model_type: Optional[str] = "main"
        # TODO(liuzichang): To reduce memory usage, MTP shares the main model's lm_head and embedding layers.
        # A trick method is currently used to enable this sharing.
        # This will be replaced with a more standardized solution in the future.
        self.sharing_model = None
        # During benchmarking, we need to enforce that the number of accepted tokens is 1.
        # This means no tokens from MTP are accepted.
        # This ensures that the specified simulation acceptance rate is not affected.
        self.benchmark_mode: bool = False

        # TODO(YuanRisheng): The server argument names differ from the SpeculativeConfig field names.
        # We temporarily add the name map here and will remove it in the future.
        name_map = {
            "speculative_method": "method",
            "speculative_max_draft_token_num": "num_speculative_tokens",
            "speculative_model_name_or_path": "model_name_or_path",
            "speculative_model_quantization": "quantization",
            "speculative_benchmark_mode": "benchmark_mode",
        }

        for key, value in args.items():
            if key in name_map.keys() and hasattr(self, name_map[key]):
                if key == "speculative_benchmark_mode":
                    value = True if value.lower() == "true" else False
                setattr(self, name_map[key], value)


class DeviceConfig:
    """
    Configuration for device settings.
    """

    def __init__(
        self,
        args,
    ):
        self.device_type = "cuda"
        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)
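# Sketch of how server-style arguments map onto SpeculativeConfig fields through
# `name_map` (illustrative values; only the listed keys are translated):
#
#     spec_cfg = SpeculativeConfig({
#         "speculative_method": "mtp",
#         "speculative_max_draft_token_num": 2,
#         "speculative_model_name_or_path": "/path/to/mtp-draft-model",
#         "speculative_benchmark_mode": "false",
#     })
#     spec_cfg.method                  # "mtp"
#     spec_cfg.num_speculative_tokens  # 2
#     spec_cfg.benchmark_mode          # False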
@dataclass
class GraphOptimizationConfig:
    """
    Configuration for compute graph level optimization.
    """

    """The top-level graph optimization control corresponds to different backends.
    - 0: dynamic graph
    - 1: static graph
    - 2: static graph + CINN compilation backend
    """
    graph_opt_level: int = 0

    """ Number of warmup runs for SOT warmup. """
    sot_warmup_sizes: Optional[list[int]] = field(default_factory=list)

    # CUDA Graph Config
    """Whether to use cudagraph.
    - False: cudagraph is not used.
    - True: cudagraph is used.
        It requires that all input buffers have fixed addresses, and all
        splitting ops write their outputs to input buffers.
    - With dynamic graph backend: ...
    - With static graph backend: WIP
    """
    use_cudagraph: bool = False
    """Sizes to capture cudagraph.
    - None (default): capture sizes are inferred from llm config.
    - list[int]: capture sizes are specified as given."""
    cudagraph_capture_sizes: Optional[list[int]] = None
    """ Number of warmup runs for cudagraph. """
    cudagraph_num_of_warmups: int = 2
    """Whether to copy input tensors for cudagraph.
    If the caller can guarantee that the same input buffers are always used,
    it can set this to False. Otherwise, it should set this to True."""
    cudagraph_copy_inputs: bool = False
    """ In static graph, this is an operation list that does not need to be captured by the CUDA graph.
    CudaGraphBackend will split these operations from the static graph.
    Example usage:
        cudagraph_splitting_ops = ["paddle.unified_attention"]
    Note: If you want to use subgraph capture functionality in a dynamic graph, manually split the model
    into multiple layers and apply the @support_graph_optimization decorator only to the layer where
    CUDA graph functionality is required.
    """
    cudagraph_splitting_ops: list[str] = field(default_factory=list)
    """Whether to use a full cuda graph for the entire forward pass rather than
    splitting certain operations such as attention into subgraphs.
    Thus this flag cannot be used together with splitting_ops."""
    full_cuda_graph: bool = True

    max_capture_size: int = field(default=None, init=False)  # type: ignore
    batch_size_to_captured_size: dict[int, int] = field(default=None, init=False)  # type: ignore

    # CINN Config ...

    def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
        """
        Initialize cuda graph capture sizes and
        pre-compute the mapping from batch size to padded graph size.
        """
        # Regular capture sizes
        self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_num_seqs]
        dedup_sizes = list(set(self.cudagraph_capture_sizes))
        if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
            logger.info(
                "cudagraph sizes specified by model runner %s is overridden by config %s",
                self.cudagraph_capture_sizes,
                dedup_sizes,
            )
        self.cudagraph_capture_sizes = dedup_sizes

        # Sort to make sure cudagraph capture sizes are in descending order
        self.cudagraph_capture_sizes.sort(reverse=True)
        self.max_capture_size = self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0

        # Pre-compute the mapping from batch size to padded graph size
        self.batch_size_to_captured_size = {}
        for end, start in zip(self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0]):
            for bs in range(start, end):
                if bs == start:
                    self.batch_size_to_captured_size[bs] = start
                else:
                    self.batch_size_to_captured_size[bs] = end
        self.batch_size_to_captured_size[self.max_capture_size] = self.max_capture_size

    def _set_cudagraph_sizes(self, max_num_seqs: int = 0):
        """
        Calculate a series of candidate capture batch sizes,
        and then extract a portion of them as the capture list for the CUDA graph based on user input.
        """
        # Batch sizes [1, 2, 4, 8, 16, ..., 120, 128]
        draft_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
        # Batch sizes [144, 160, ..., 240, 256]
        draft_capture_sizes += [16 * i for i in range(9, 17)]
        # Batch sizes [544, 576, ..., 992, 1024]
        draft_capture_sizes += [32 * i for i in range(17, 33)]

        draft_capture_sizes.append(max_num_seqs)
        self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
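# Sketch of the padding map built by init_with_cudagrpah_size (illustrative sizes):
# with capture sizes [1, 2, 4, 8] and max_num_seqs=8, a runtime batch size is
# rounded up to the nearest captured size, e.g. 3 -> 4 and 5 -> 8.
#
#     graph_cfg = GraphOptimizationConfig(use_cudagraph=True,
#                                         cudagraph_capture_sizes=[1, 2, 4, 8])
#     graph_cfg.init_with_cudagrpah_size(max_num_seqs=8)
#     graph_cfg.max_capture_size                 # 8
#     graph_cfg.batch_size_to_captured_size[3]   # 4
#     graph_cfg.batch_size_to_captured_size[5]   # 8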
class LoadConfig:
    """
    Configuration for dynamic weight loading strategies.

    Attributes:
        dynamic_load_weight: Whether to enable dynamic weight loading
        load_strategy: Specifies the weight loading method when enabled:
            - 'ipc': Real-time IPC streaming with automatic resharding
            - 'ipc_snapshot': Load from disk snapshot of IPC weights
            - None: No dynamic loading
    """

    def __init__(
        self,
        args,
    ):
        self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
        self.dynamic_load_weight: bool = False
        self.load_strategy: Optional[Literal["ipc", "ipc_snapshot"]] = None
        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)


class LoRAConfig:
    """LoRA Config"""

    pass
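# Usage sketch for LoadConfig (illustrative): dynamic loading is opt-in, and the
# strategy is restricted to "ipc" or "ipc_snapshot".
#
#     load_cfg = LoadConfig({
#         "dynamic_load_weight": True,
#         "load_strategy": "ipc_snapshot",  # load weights from an on-disk IPC snapshot
#     })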
""" self.block_size = 64 self.gpu_memory_utilization = 0.9 self.num_gpu_blocks_override = None self.kv_cache_ratio = 0.75 self.enc_dec_block_num = 2 self.prealloc_dec_block_slot_num_threshold = 5 self.cache_dtype = "bfloat16" self.model_cfg = None self.enable_chunked_prefill = False self.rdma_comm_ports = None self.cache_transfer_protocol = None self.pd_comm_port = None self.enable_prefix_caching = False self.enable_ssd_cache = False self.cache_queue_port = None self.swap_space = None for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) if self.rdma_comm_ports is not None and isinstance(self.rdma_comm_ports, str): self.rdma_comm_ports = self.rdma_comm_ports.split(",") if self.pd_comm_port is not None and isinstance(self.pd_comm_port, str): self.pd_comm_port = [int(port) for port in self.pd_comm_port.split(",")] if self.swap_space is None: self.enable_hierarchical_cache = False else: self.enable_hierarchical_cache = True if self.model_cfg is not None: if hasattr(self.model_cfg, "quantization_config"): self.cache_dtype = self.model_cfg.quantization_config.get("kv_cache_quant_type", self.cache_dtype) if ( hasattr(self.model_cfg, "num_key_value_heads") and hasattr(self.model_cfg, "num_key_value_heads") and self.model_cfg.num_key_value_heads is not None and int(self.model_cfg.num_key_value_heads) > 0 ): kv_num_head = int(self.model_cfg.num_key_value_heads) else: kv_num_head = self.model_cfg.num_attention_heads self.model_cfg.kv_num_head = kv_num_head # TODO check name if "int4" in self.cache_dtype.lower() or "float4" in self.cache_dtype.lower(): byte_size = 0.5 self.cache_dtype = "uint8" elif "int8" in self.cache_dtype.lower() or "float8" in self.cache_dtype.lower(): self.cache_dtype = "uint8" byte_size = 1 else: byte_size = 2 self.each_token_cache_space = int( self.model_cfg.num_layers * kv_num_head * self.model_cfg.head_dim * byte_size ) self.bytes_per_block = int(self.each_token_cache_space * self.block_size) self.bytes_per_layer_per_block = int( self.block_size * self.model_cfg.kv_num_head * self.model_cfg.head_dim // args["tensor_parallel_size"] * byte_size ) if self.swap_space is None: self.num_cpu_blocks = 0 else: self.num_cpu_blocks = int(self.swap_space * 1024**3 / self.bytes_per_block) self._verify_args() def metrics_info(self): """Convert cache_config to dict(key: str, value: str) for prometheus metrics info.""" return {key: str(value) for key, value in self.__dict__.items()} def _verify_args(self): if self.gpu_memory_utilization > 1.0: raise ValueError("GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") if self.kv_cache_ratio > 1.0: raise ValueError("KV cache ratio must be less than 1.0. 
Got " f"{self.kv_cache_ratio}.") def postprocess(self, num_total_tokens, number_of_tasks): """ calculate block num """ self.dec_token_num = self.enc_dec_block_num * self.block_size if self.num_gpu_blocks_override is not None: self.total_block_num = self.num_gpu_blocks_override self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio) else: length = num_total_tokens // number_of_tasks block_num = (length + self.block_size - 1 + self.dec_token_num) // self.block_size self.total_block_num = block_num * number_of_tasks self.prefill_kvcache_block_num = self.total_block_num logger.info(f"Doing profile, the total_block_num:{self.total_block_num}") def reset(self, num_gpu_blocks): """ reset gpu block number """ self.total_block_num = num_gpu_blocks self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio) logger.info( f"Reset block num, the total_block_num:{self.total_block_num}," f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}" ) def print(self): """ print all config """ logger.info("Cache Configuration Information :") for k, v in self.__dict__.items(): logger.info("{:<20}:{:<6}{}".format(k, "", v)) logger.info("=============================================================") class DecodingConfig: """ Configuration for decoding """ def __init__( self, args, ): self.pad_token_id = None for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) @dataclass class FDConfig: """ The configuration class which contains all fastdeploy-related configuration. This simplifies passing around the distinct configurations in the codebase. """ model_config: ModelConfig = field(default=None, init=True) # type: ignore parallel_config: ParallelConfig = field(default=None, init=True) speculative_config: SpeculativeConfig = field(default=None, init=True) # type: ignore device_config: DeviceConfig = field(default=None, init=True) # type: ignore load_config: LoadConfig = field(default=None, init=True) quant_config: Optional[QuantConfigBase] = None graph_opt_config: Optional[GraphOptimizationConfig] = None decoding_config: DecodingConfig = field(default=None, init=True) # type: ignore cache_config: CacheConfig = field(default=None, init=True) # type: ignore def __post_init__(self): # Initialize cuda graph capture list if self.graph_opt_config.cudagraph_capture_sizes is None: self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs) self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs) # TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn if self.graph_opt_config.graph_opt_level == 2: self.graph_opt_config.graph_opt_level = 1