mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
Sync v2.0 version of code to github repo
@@ -1,21 +1,28 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import json
import os
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional

from fastdeploy import envs
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import (check_unified_ckpt, get_host_ip,
from fastdeploy.utils import (ceil_div, check_unified_ckpt, get_host_ip,
                              is_port_available, llm_logger)

TaskOption = Literal["generate"]
@@ -23,37 +30,37 @@ TaskOption = Literal["generate"]

class ModelConfig:
    """
    Configuration class for model settings and parameters.
    Configuration class for the model.

    Attributes:
        model_dir (str): Path to the model directory
        is_unified_ckpt (bool): Whether the checkpoint uses unified format
        model_name_or_path (str): Model identifier or path
        dynamic_load_weight (int): Dynamic weight loading flag
    """
    Attributes:
        model_dir (str): Directory path to the model.
        is_unified_ckpt (bool): Flag indicating if the checkpoint is unified.
        model_name_or_path (str): Name or path of the model.
    """

    def __init__(self,
                 model_name_or_path: str,
                 config_json_file: str = "config.json",
                 dynamic_load_weight: int = 0,
                 quantization: str = None,
                 download_dir: Optional[str] = None):
        """
        Initialize model configuration.
        Initialize the ModelConfig class.

        Args:
            model_name_or_path (str): Model identifier or path
            config_json_file (str): Model config file name (default: 'config.json')
            dynamic_load_weight (int): Dynamic weight loading mode (default: 0)
            download_dir (Optional[str]): Directory for downloaded models (default: None)
            model_name_or_path (str): Name or path of the model.
            config_json_file (str): Path to the configuration JSON file. Default is 'config.json'.
            download_dir (Optional[str]): Directory to download model files. Default is None.
        """
        self.model_dir = model_name_or_path
        self.is_unified_ckpt = check_unified_ckpt(self.model_dir)
        self.dynamic_load_weight = dynamic_load_weight
        self.quantization = quantization

        config_file = os.path.join(model_name_or_path, config_json_file)
        if os.path.isfile(model_name_or_path):
            try:
                from paddlenlp.transformers import AutoConfig
                from paddleformers.transformers import AutoConfig
                config = AutoConfig.from_pretrained(model_name_or_path)
                config_dict = {
                    k: v
@@ -63,10 +70,10 @@ Attributes:
                    setattr(self, key, value)
            except Exception:
                llm_logger.error(
                    "Don't support the current model, you can use `paddlenlp` to register your model."
                    "Don't support the current model, you can use `paddleformers` to register your model."
                )
                raise ValueError(
                    "Don't support the current model, you can use `paddlenlp` to register your model."
                    "Don't support the current model, you can use `paddleformers` to register your model."
                )
        else:
            with open(config_file, "r", encoding="utf-8") as f:
@@ -85,12 +92,9 @@ Attributes:

    def override_name_from_config(self):
        """
        Update attribute names from model configuration.
        Handles special cases like:
        - Renaming infer_model_mp_num to tensor_parallel_size
        - Adjusting num_hidden_layers based on remove_tail_layer
        - Setting default mla_use_absorb value
        Override attribute names from the exported model's configuration.
        """

        if not self.is_unified_ckpt and hasattr(self, "infer_model_mp_num"):
            self.tensor_parallel_size = self.infer_model_mp_num
            del self.infer_model_mp_num
@@ -107,27 +111,19 @@ Attributes:

        if not hasattr(self, "mla_use_absorb"):
            self.mla_use_absorb = False
        if not hasattr(self, "head_dim"):
            assert hasattr(self, "hidden_size") and hasattr(
                self, "num_attention_heads")
            self.head_dim = self.hidden_size // self.num_attention_heads
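            # Illustrative arithmetic only, with assumed example values (the
            # real numbers come from the model's config.json):
            #   hidden_size = 4096
            #   num_attention_heads = 32
            #   head_dim = hidden_size // num_attention_heads   # -> 128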

    def read_from_env(self):
        """
        Load configuration from environment variables.
        Sets default values if env vars not found.
        Reads:
        - MAX_STOP_SEQS_NUM (default: 5)
        - STOP_SEQS_MAX_LEN (default: 8)
        - ELLM_DYNAMIC_QUANT_TYPE (default: 'default')
        - ELLM_DYNAMIC_USE_STOP_SEQS (default: 0)
        - COMPRESSION_RATIO (default: 1.0)
        - ROPE_THETA (default: 10000)
        """
        self.max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", "5"))
        self.stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", "8"))
        Read configuration information from environment variables and update the object's attributes.

        self.ellm_dynamic_quant_type = os.getenv("ELLM_DYNAMIC_QUANT_TYPE",
                                                 "default")
        # Whether to use stop sequences in dynamic graph inference
        self.ellm_dynamic_use_stop_seqs = int(
            os.getenv("ELLM_DYNAMIC_USE_STOP_SEQS", "0")) == 1
        If an attribute is not present or is an empty string in the environment variables, use the default value.
        """
        self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
        self.stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)

        def reset_config_value(key, value):
            if not hasattr(self, key.lower()):
@@ -144,13 +140,12 @@ Attributes:
        reset_config_value("ROPE_THETA", 10000)
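        # Sketch of the env-override pattern described above; the helper body
        # is truncated in this diff, so the logic below is an assumption-based
        # illustration rather than the actual FastDeploy implementation:
        #   def reset_config_value(key, value):
        #       if not hasattr(self, key.lower()):
        #           raw = os.getenv(key, "")
        #           setattr(self, key.lower(), raw if raw != "" else value)
        #   reset_config_value("COMPRESSION_RATIO", 1.0)   # -> self.compression_ratio
        #   reset_config_value("ROPE_THETA", 10000)        # -> self.rope_theta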

    def _get_download_model(self, model_name, model_type="default"):
        # TODO: Implement dynamic graph for self-downloading models
        # TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
        pass

    def print(self):
        """
        Print current model configuration.
        Logs all attributes and their values.
        Print all configuration information.
        """
        llm_logger.info("Model Configuration Information :")
        for k, v in self.__dict__.items():
@@ -161,19 +156,18 @@ Attributes:

class CacheConfig:
    """
    Configuration for key-value cache management.
    Configuration for the KV cache.

    Attributes:
        block_size (int): Tokens per cache block
        gpu_memory_utilization (float): GPU memory usage fraction (0-1)
        cache_dtype (str): Data type for cache (default: 'bfloat16')
        num_gpu_blocks_override (Optional[int]): Manual GPU blocks override
        kv_cache_ratio (float): Max blocks ratio (default: 0.75)
        enc_dec_block_num (int): Encoder-decoder blocks count
        enable_prefix_caching (bool): Prefix caching enable flag
        total_block_num (int): Total available blocks
        prefill_kvcache_block_num (int): Blocks allocated for prefill
    """
    Attributes:
        block_size (int): Size of a cache block in number of tokens.
        gpu_memory_utilization (float): Fraction of GPU memory to use for model execution.
        cache_dtype (str): Data type for kv cache storage. Default is 'bfloat16'.
        num_gpu_blocks_override (Optional[int]): Number of GPU blocks to use.
            Overrides profiled num_gpu_blocks if provided.
        kv_cache_ratio (float): Ratio for calculating the maximum block number.
        enc_dec_block_num (int): Number of encoder-decoder blocks.
        enable_prefix_caching (bool): Flag to enable prefix caching.
    """

    def __init__(
        self,
@@ -181,21 +175,31 @@ Attributes:
        gpu_memory_utilization: float,
        cache_dtype: str = "bfloat16",
        num_gpu_blocks_override: Optional[int] = None,
        swap_space: Optional[int] = None,
        kv_cache_ratio: float = 0.75,
        enc_dec_block_num: int = 2,
        enable_prefix_caching: bool = False,
        tensor_parallel_size: int = 1,
        enable_prefix_caching=False,
        enable_ssd_cache=False,
        model_cfg=None,
        cache_queue_port=None,
        enable_chunked_prefill=False,
        rdma_comm_ports=None,
        cache_transfer_protocol=None,
        pd_comm_port=None,
    ):
        """
        Initialize cache configuration.
        Initialize the CacheConfig class.

        Args:
            block_size (int): Tokens per cache block
            gpu_memory_utilization (float): GPU memory usage target (0-1)
            cache_dtype (str): Cache data type (default: 'bfloat16')
            num_gpu_blocks_override (Optional[int]): Manual GPU blocks setting
            kv_cache_ratio (float): Max blocks ratio (default: 0.75)
            enc_dec_block_num (int): Encoder-decoder blocks count
            enable_prefix_caching (bool): Enable prefix sharing
            block_size (int): Size of a cache block in number of tokens.
            gpu_memory_utilization (float): Fraction of GPU memory to use.
            cache_dtype (str): Data type for cache storage. Default is 'bfloat16'.
            num_gpu_blocks_override (Optional[int]): Override for number of GPU blocks.
            num_cpu_blocks (Optional[int]): Number of CPU blocks.
            kv_cache_ratio (float): Ratio for max block calculation.
            enc_dec_block_num (int): Number of encoder-decoder blocks.
            enable_prefix_caching (bool): Enable prefix caching.
        """
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
@@ -203,20 +207,74 @@ Attributes:
        self.kv_cache_ratio = kv_cache_ratio
        self.enc_dec_block_num = enc_dec_block_num
        self.cache_dtype = cache_dtype
        if hasattr(model_cfg, "quantization_config"):
            self.cache_dtype = model_cfg.quantization_config.get(
                "kv_cache_quant_type", cache_dtype)

        self.enable_chunked_prefill = enable_chunked_prefill
        self.rdma_comm_ports = rdma_comm_ports
        self.cache_transfer_protocol = cache_transfer_protocol
        self.pd_comm_port = pd_comm_port

        if rdma_comm_ports is not None and isinstance(rdma_comm_ports, str):
            self.rdma_comm_ports = rdma_comm_ports.split(',')

        if pd_comm_port is not None and isinstance(pd_comm_port, str):
            self.pd_comm_port = [int(port) for port in pd_comm_port.split(",")]

        self.enable_prefix_caching = enable_prefix_caching
        if swap_space is None:
            self.enable_hierarchical_cache = False
        else:
            self.enable_hierarchical_cache = True

        self.enable_ssd_cache = enable_ssd_cache
        self.model_cfg = model_cfg
        self.cache_queue_port = cache_queue_port
        self.swap_space = swap_space

        if (hasattr(self.model_cfg, "num_key_value_heads")
                and hasattr(self.model_cfg, "num_key_value_heads")
                and self.model_cfg.num_key_value_heads is not None
                and int(self.model_cfg.num_key_value_heads) > 0):
            kv_num_head = int(self.model_cfg.num_key_value_heads)
        else:
            kv_num_head = self.model_cfg.num_attention_heads
        self.model_cfg.kv_num_head = kv_num_head

        # TODO check name
        if "int4" in self.cache_dtype.lower(
        ) or "float4" in self.cache_dtype.lower():
            byte_size = 0.5
            self.cache_dtype = "uint8"
        elif "int8" in self.cache_dtype.lower(
        ) or "float8" in self.cache_dtype.lower():
            self.cache_dtype = "uint8"
            byte_size = 1
        else:
            byte_size = 2

        self.each_token_cache_space = int(
            self.model_cfg.num_layers * kv_num_head * self.model_cfg.head_dim *
            byte_size)
        self.bytes_per_block = int(self.each_token_cache_space *
                                   self.block_size)
        self.bytes_per_layer_per_block = int(
            self.block_size * self.model_cfg.kv_num_head *
            self.model_cfg.head_dim // tensor_parallel_size * byte_size)

        if self.swap_space is None:
            self.num_cpu_blocks = 0
        else:
            self.num_cpu_blocks = int(self.swap_space * 1024**3 /
                                      self.bytes_per_block)
        self._verify_args()
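        # Worked example of the sizing arithmetic above, using assumed values
        # (num_layers=32, kv_num_head=8, head_dim=128, bfloat16 cache so
        # byte_size=2, block_size=64, tensor_parallel_size=1, swap_space=4 GiB):
        #   each_token_cache_space = 32 * 8 * 128 * 2           # 65536 bytes per token
        #   bytes_per_block = 65536 * 64                        # 4194304 bytes (4 MiB)
        #   bytes_per_layer_per_block = 64 * 8 * 128 // 1 * 2   # 131072 bytes
        #   num_cpu_blocks = int(4 * 1024**3 / 4194304)         # 1024 CPU cache blocks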

    def metrics_info(self):
        """
        Convert config to metrics dictionary.

        Returns:
            Dict[str, str]: Key-value pairs of all config attributes
        """
        """Convert cache_config to dict(key: str, value: str) for prometheus metrics info."""
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self):
        """Validate configuration arguments."""
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than 1.0. Got "
@@ -227,46 +285,38 @@ Attributes:

    def postprocess(self, num_total_tokens, number_of_tasks):
        """
        Calculate block allocation based on tokens and tasks.

        Args:
            num_total_tokens (int): Total tokens to process
            number_of_tasks (int): Number of parallel tasks

        Sets:
            dec_token_num (int): Decoder tokens per block
            total_block_num (int): Total blocks needed
            prefill_kvcache_block_num (int): Blocks for prefill phase
        calculate block num
        """
        self.dec_token_num = self.enc_dec_block_num * self.block_size
        if self.num_gpu_blocks_override is not None:
            self.total_block_num = self.num_gpu_blocks_override
            self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
            self.prefill_kvcache_block_num = int(self.total_block_num *
                                                 self.kv_cache_ratio)
        else:
            length = num_total_tokens // number_of_tasks
            block_num = (length + self.block_size - 1 + self.enc_dec_block_num) // self.block_size
            self.total_block_num = block_num * number_of_tasks
            self.prefill_kvcache_block_num = self.total_block_num
            llm_logger.info(f"Doing profile, the total_block_num:{self.total_block_num}")
            block_num = (length + self.block_size - 1 +
                         self.dec_token_num) // self.block_size
            self.total_block_num = block_num * number_of_tasks
            self.prefill_kvcache_block_num = self.total_block_num
            llm_logger.info(
                f"Doing profile, the total_block_num:{self.total_block_num}")
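            # Worked example of the profile path above (dec_token_num form of
            # the formula), with assumed inputs num_total_tokens=8192,
            # number_of_tasks=8, block_size=64, enc_dec_block_num=2:
            #   dec_token_num = 2 * 64                    # 128
            #   length = 8192 // 8                        # 1024
            #   block_num = (1024 + 64 - 1 + 128) // 64   # 18
            #   total_block_num = 18 * 8                  # 144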

    def reset(self, num_gpu_blocks):
        """
        Reset GPU block allocation.

        Args:
            num_gpu_blocks (int): New total blocks count

        Updates:
            total_block_num (int)
            prefill_kvcache_block_num (int)
        reset gpu block number
        """
        self.total_block_num = num_gpu_blocks
        self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
        llm_logger.info((f"Reset block num, the total_block_num:{self.total_block_num},"
                         f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"))
        self.total_block_num = num_gpu_blocks
        self.prefill_kvcache_block_num = int(self.total_block_num *
                                             self.kv_cache_ratio)
        llm_logger.info(
            (f"Reset block num, the total_block_num:{self.total_block_num},"
             f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"))

    def print(self):
        """Print current cache configuration."""
        """
        print all config

        """
        llm_logger.info("Cache Configuration Information :")
        for k, v in self.__dict__.items():
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
@@ -274,38 +324,182 @@ Attributes:
            "=============================================================")


class SpeculativeConfig:
    """
    Speculative Decoding Configuration class.

    Attributes:
        method (Optional[str]): Method used for speculative decoding.
        num_speculative_tokens (int): Maximum draft tokens, default is 1.
        model_name_or_path (Optional[str]): Path of the model.
        quantization (str): Quantization method for draft model, default is WINT8.
        max_model_len: Optional[int]: Maximum model length for draft model.
    """

    def __init__(self,
                 method: Optional[str] = None,
                 num_speculative_tokens: Optional[int] = 1,
                 model: Optional[str] = None,
                 quantization: Optional[str] = "WINT8",
                 max_model_len: Optional[int] = None,
                 **kwargs):
        self.model_name_or_path = model
        self.method = method
        self.num_speculative_tokens = num_speculative_tokens
        self.quantization = quantization
        self.max_model_len = max_model_len
        # Fixed now
        self.num_gpu_block_expand_ratio = 1
        self.num_extra_cache_layer = 0

        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except Exception:
                continue

        self.read_model_config()
        self.reset()

    def read_model_config(self):
        """
        Read configuration from file.
        """
        self.model_config = {}
        if not self.enabled_speculative_decoding():
            return

        self.is_unified_ckpt = check_unified_ckpt(self.model_name_or_path)
        if self.model_name_or_path is None:
            return

        self.config_path = os.path.join(self.model_name_or_path, "config.json")
        if os.path.exists(self.config_path):
            self.model_config = json.load(
                open(self.config_path, 'r', encoding='utf-8'))

    def reset(self):
        """
        Reset configuration.
        """

        def reset_value(cls, value_name, key=None, default=None):
            if key is not None and key in cls.model_config:
                setattr(cls, value_name, cls.model_config[key])
            elif getattr(cls, value_name, None) is None:
                setattr(cls, value_name, default)

        if not self.enabled_speculative_decoding():
            return

        # NOTE(liuzichang): We will support multi-layer in future
        if self.method in ["mtp"]:
            self.num_extra_cache_layer = 1

    def enabled_speculative_decoding(self):
        """
        Check if speculative decoding is enabled.
        """
        if self.method is None:
            return False
        return True

    def to_json_string(self):
        """
        Convert speculative_config to json string.
        """
        return json.dumps({
            key: value
            for key, value in self.__dict__.items() if value is not None
        })
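        # Example of the behaviour above with placeholder attribute values:
        # attributes whose value is None are dropped from the JSON, e.g.
        #   {"method": "mtp", "num_speculative_tokens": 1, "model_name_or_path": None}
        #   -> '{"method": "mtp", "num_speculative_tokens": 1}'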

    def print(self):
        """
        print all config

        """
        llm_logger.info("Speculative Decoding Configuration Information :")
        for k, v in self.__dict__.items():
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
        llm_logger.info(
            "=============================================================")


class ParallelConfig:
    """
    Configuration for parallelism.

    Attributes:
        tensor_parallel_size (int): Size of tensor parallelism.
        data_parallel_size (int): Size of data parallelism.
        local_data_parallel_id (int): ID of local data parallel.
        enable_expert_parallel (bool): Whether to enable expert parallel.
    """

    def __init__(
        self,
        tensor_parallel_size: int = 1,
        data_parallel_size: int = 1,
        enable_expert_parallel: bool = False,
    ):
        """
        Initialize the ParallelConfig class.

        Args:
            tensor_parallel_size (int): Size of tensor parallelism.
            data_parallel_size (int): Size of data parallelism.
            local_data_parallel_id (int): ID of local data parallel.
            enable_expert_parallel (bool): Whether to enable expert parallel.
        """
        self.tensor_parallel_size = tensor_parallel_size
        self.data_parallel_size = data_parallel_size
        self.enable_expert_parallel = enable_expert_parallel
        self.expert_parallel_size = data_parallel_size
        self.local_data_parallel_id = 0

    def print(self):
        """
        print all config

        """
        llm_logger.info("Parallel Configuration Information :")
        for k, v in self.__dict__.items():
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
        llm_logger.info("==================")


class Config:
    """
    Main engine configuration class combining all components.
    Initial configuration class.

    Attributes:
        model_config (ModelConfig): Model settings
        cache_config (CacheConfig): Cache management settings
        scheduler_config (SchedulerConfig): Task scheduling settings
        model_name_or_path (str): Model identifier/path
        tokenizer (str): Tokenizer identifier
        tensor_parallel_size (int): Parallelism degree (default: 8)
        nnode (int): Node count (default: 1)
        max_model_len (int): Max sequence length (default: 8192)
        max_num_seqs (int): Max concurrent sequences (default: 8)
        max_num_batched_tokens (Optional[int]): Max batched tokens
        pod_ips (Optional[List[str]]): Cluster node IPs
        mm_processor_kwargs (Optional[Dict]): Multi-modal processor args
        speculative_config (Optional[Dict]): Speculative execution settings
        use_warmup (bool): Warmup enable flag
        enable_mm (bool): Multi-modal enable flag
        enable_chunked_prefill (bool): Chunked prefill enable flag
        device_ids (str): GPU device IDs
        tp_num_per_node (int): Tensor parallelism per node
        host_ip (str): Current host IP
        paddle_commit_id (str): PaddlePaddle version
    """
    Attributes:
        model_config (ModelConfig): Model configuration object.
        cache_config (CacheConfig): Cache configuration object.
        model_name_or_path (str): Directory path to the model or the model name.
        tokenizer (Optional[str]): Default is the model.
        max_num_batched_tokens (Optional[int]): Maximum number of batched tokens.
        tensor_parallel_size (int): Tensor parallel size.
        nnode (int): Number of nodes.
        max_model_len (int): Maximum model length. Default is 8192.
        max_num_seqs (int): Maximum number of sequences. Default is 8.
        mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor.
        speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration.
        use_warmup (bool): Flag to use warmup.
        engine_worker_queue_port (int): Port for engine worker queue.
        enable_mm (bool): Flag to enable multi-modal processing.
        reasoning_parser(str): Flag specifies the reasoning parser to use for
            extracting reasoning content from the model output
        splitwise_role (str): Splitwise role.
        innode_prefill_ports (Optional[List[int]]): Innode prefill ports.
            Temporary configuration, will be removed in the future.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        scheduler_config: SchedulerConfig,
        parallel_config: ParallelConfig,
        model_name_or_path: str = None,
        tokenizer: str = None,
        tensor_parallel_size: int = 8,
@@ -314,38 +508,57 @@ Attributes:
        max_num_seqs: int = 8,
        max_num_batched_tokens: Optional[int] = None,
        pod_ips: Optional[List[str]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        speculative_config: Optional[Dict[str, Any]] = None,
        use_warmup: bool = False,
        engine_worker_queue_port: int = 8002,
        limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        enable_mm: bool = False,
        enable_chunked_prefill: bool = False,
        splitwise_role: str = "mixed",
        innode_prefill_ports: Optional[List[int]] = None,
        max_num_partial_prefills: int = 1,
        max_long_partial_prefills: int = 1,
        long_prefill_token_threshold: int = 0,
        reasoning_parser: str = None,
        enable_static_graph_inference: bool = False,
        use_cudagraph: bool = False,
        max_capture_batch_size: int = 64,
        guided_decoding_backend: Optional[str] = None,
        disable_any_whitespace: bool = False,
    ):
        """
        Initialize engine configuration.
        Initialize the Config class.

        Args:
            model_config (ModelConfig): Model settings
            cache_config (CacheConfig): Cache settings
            scheduler_config (SchedulerConfig): Scheduler settings
            model_name_or_path (str): Model identifier (default: None)
            tokenizer (str): Tokenizer identifier (default: None)
            tensor_parallel_size (int): Parallelism degree (default: 8)
            nnode (int): Node count (default: 1)
            max_model_len (int): Max sequence length (default: 8192)
            max_num_seqs (int): Max concurrent sequences (default: 8)
            max_num_batched_tokens (Optional[int]): Max batched tokens (default: None)
            pod_ips (Optional[List[str]]): Cluster node IPs (default: None)
            mm_processor_kwargs (Optional[Dict]): Multi-modal args (default: None)
            speculative_config (Optional[Dict]): Speculative settings (default: None)
            use_warmup (bool): Warmup flag (default: False)
            engine_worker_queue_port (int): Worker queue port (default: 8002)
            enable_mm (bool): Multi-modal flag (default: False)
            enable_chunked_prefill (bool): Chunked prefill flag (default: False)
            model_config (ModelConfig): Model configuration object.
            cache_config (CacheConfig): Cache configuration object.
            parallel_config (ParallelConfig): Parallel configuration object.
            scheduler_config (SchedulerConfig): Scheduler configuration object.
            model_name_or_path (str): Model directory path or model name.
            tokenizer (str): Default is the model.
            tensor_parallel_size (int): Tensor parallel size. Default is 8.
            nnode (int): Number of nodes. Default is 1.
            max_model_len (int): Maximum model length. Default is 8192.
            max_num_seqs (int): Maximum number of sequences. Default is 8.
            max_num_batched_tokens (Optional[int]): Maximum number of batched tokens. Default is None.
            pod_ips (Optional[List[str]]): List of POD IPs. Default is None.
            mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor. Default is None.
            speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration. Default is None.
            use_warmup (bool): Flag to use warmup. Default is False.
            engine_worker_queue_port (int): Engine worker queue port. Default is 8002.
            enable_mm (bool): Flag to enable multi-modal processing. Default is False.
            splitwise_role (str): Splitwise role. Default is "mixed".
            innode_prefill_ports (Optional[List[int]]): Innode prefill ports. Default is None.
            reasoning_parser (str): Flag specifies the reasoning parser to use for
                extracting reasoning content from the model output. Default is None.
            guided_decoding_backend(str): Guided decoding backend. Default is None.
            disable_any_whitespace(bool): Disable any whitespace when using guided decoding.
                Default is False.
        """
        self.model_config = model_config
        self.cache_config = cache_config
        self.scheduler_config = scheduler_config
        self.parallel_config = parallel_config
        self.model_name_or_path = model_name_or_path
        self.tokenizer = tokenizer
        self.max_num_batched_tokens = max_num_batched_tokens
@@ -354,20 +567,56 @@ Attributes:
        self.pod_ips = pod_ips
        self.max_model_len = max_model_len
        self.max_num_seqs = max_num_seqs
        self.limit_mm_per_prompt = limit_mm_per_prompt
        self.mm_processor_kwargs = mm_processor_kwargs
        self.enable_mm = enable_mm
        self.speculative_config = speculative_config
        self.use_warmup = use_warmup
        self.enable_chunked_prefill = enable_chunked_prefill
        self.splitwise_role = splitwise_role
        self.innode_prefill_ports = innode_prefill_ports
        self.max_num_partial_prefills = max_num_partial_prefills
        self.max_long_partial_prefills = max_long_partial_prefills
        self.long_prefill_token_threshold = long_prefill_token_threshold
        self.reasoning_parser = reasoning_parser
        self.enable_static_graph_inference = enable_static_graph_inference
        self.use_cudagraph = use_cudagraph
        self.max_capture_batch_size = max_capture_batch_size
        self.guided_decoding_backend = guided_decoding_backend
        self.disable_any_whitespace = disable_any_whitespace


        if self.innode_prefill_ports is not None:
            if not isinstance(self.innode_prefill_ports, list):
                ports = str(self.innode_prefill_ports).split(',')
                self.innode_prefill_ports = [int(port) for port in ports]


        assert self.splitwise_role in ["mixed", "prefill", "decode"]
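        # Example of the normalization above with an assumed string input:
        #   innode_prefill_ports = "8021,8022,8023"
        #   -> [8021, 8022, 8023]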

        # TODO
        self.max_prefill_batch = 3
        if current_platform.is_xpu():
            self.max_prefill_batch = 1
        if enable_mm:
            self.max_prefill_batch = 1  # TODO: Currently multi-modal prefill only supports parallelism=1 (needs optimization)
            self.max_prefill_batch = 1  # TODO: The multi-modal prefill stage currently only supports a parallelism of 1; to be optimized

        # TODO(@wufeisheng): TP and EP need to be supported simultaneously.
        assert (self.tensor_parallel_size == 1
                and self.parallel_config.expert_parallel_size
                >= 1) or (self.tensor_parallel_size >= 1
                          and self.parallel_config.expert_parallel_size
                          == 1), "TP and EP cannot be enabled at the same time"

        num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
        if num_ranks > 8:
            local_num_ranks = 8
            self.nnode = ceil_div(num_ranks, local_num_ranks)
        else:
            local_num_ranks = num_ranks

        self.engine_worker_queue_port = engine_worker_queue_port
        self.device_ids = ",".join(
            [str(i) for i in range(self.tensor_parallel_size)])
        self.device_ids = ",".join([str(i) for i in range(min((self.tensor_parallel_size * \
            self.parallel_config.expert_parallel_size), 8))])
        self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
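        # Worked example of the rank/device setup above with assumed sizes
        # tensor_parallel_size=4 and expert_parallel_size=2:
        #   num_ranks = 4 * 2              # 8, not > 8, so nnode keeps its constructor value
        #   local_num_ranks = 8
        #   device_ids = "0,1,2,3,4,5,6,7" # min(num_ranks, 8) devices,
        #                                  # unless CUDA_VISIBLE_DEVICES overrides it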

        self.read_from_config()
@@ -377,45 +626,44 @@ Attributes:

    def postprocess(self):
        """
        Calculate derived parameters:
        - Validates GPU device count matches tensor_parallel_size
        - Computes tensor parallelism per node
        - Gets host IP and Paddle version
        - Sets default max_num_batched_tokens if not provided
        - Initializes cache configuration
        calculate some parameters
        """
        if len(self.device_ids.split(',')) > self.tensor_parallel_size:
            self.device_ids = ",".join(
                self.device_ids.split(',')[:self.tensor_parallel_size:])
        assert len(
            self.device_ids.split(',')
        ) == self.tensor_parallel_size, f"The number of available GPUs is {len(self.device_ids.split(','))}, which is less than the tensor parallel required {self.tensor_parallel_size}."

        assert self.tensor_parallel_size % self.nnode == 0, f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by nnode: {self.nnode}"
        self.tp_num_per_node = self.tensor_parallel_size // self.nnode
        total_rank = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
        assert self.device_ids.split(',').__len__() == min(total_rank, 8), \
            f"invalid CUDA_VISIBLE_DEVICES, should be equal to {min(total_rank, 8)}"
        self.local_device_ids = self.device_ids.split(
            ',')[:self.tensor_parallel_size]
        assert self.tensor_parallel_size % self.nnode == 0, \
            f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by nnode: {self.nnode}"
        self.worker_num_per_node = total_rank // self.nnode
        self.host_ip = get_host_ip()

        import paddle
        self.paddle_commit_id = paddle.version.commit

        if self.max_num_batched_tokens is None:
            if self.enable_chunked_prefill:
            if self.cache_config.enable_chunked_prefill:
                self.max_num_batched_tokens = 2048
            else:
                self.max_num_batched_tokens = self.max_model_len
        self.cache_config.postprocess(self.max_num_batched_tokens, self.max_num_seqs)

        if self.long_prefill_token_threshold == 0:
            self.long_prefill_token_threshold = int(self.max_model_len * 0.04)

        self.cache_config.postprocess(self.max_num_batched_tokens,
                                      self.max_num_seqs)
        self.cache_config.max_block_num_per_seq = int(
            self.max_model_len // self.cache_config.block_size)

        if self.guided_decoding_backend == "auto":
            if self.enable_mm:
                self.guided_decoding_backend = "off"
            else:
                self.guided_decoding_backend = "xgrammar"
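        # Worked example of the defaults resolved above, assuming
        # max_model_len=8192, block_size=64 and chunked prefill disabled:
        #   max_num_batched_tokens = 8192                      # falls back to max_model_len
        #   long_prefill_token_threshold = int(8192 * 0.04)    # 327
        #   max_block_num_per_seq = 8192 // 64                 # 128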

    def check(self):
        """
        Validate configuration values:
        - max_num_seqs <= 256
        - engine_worker_queue_port available
        - 1 <= tensor_parallel_size <= 8
        - nnode >= 1
        - max_model_len >= 16
        - max_num_seqs >= 1
        - Validates scheduler configuration
        check the legality of config
        """
        assert (
            self.max_num_seqs <= 256
@@ -434,15 +682,66 @@ Attributes:
        assert (
            self.max_num_seqs
            >= 1), f"max_num_seqs: {self.max_num_seqs} should be larger than 1"
        assert (
            self.max_num_batched_tokens >= self.max_num_seqs
        ), f"max_num_batched_tokens: {self.max_num_batched_tokens} " \
            f"should be larger than or equal to max_num_seqs: {self.max_num_seqs}"
        assert (self.max_num_batched_tokens <= self.max_model_len * self.max_num_seqs), \
            f"max_num_batched_tokens: {self.max_num_batched_tokens} should be larger" \
            f"than or equal to max_num_seqs: {self.max_num_seqs} * max_model_len: {self.max_model_len}"
        assert (
            self.max_num_partial_prefills >= 1
        ), f"max_num_partial_prefills: {self.max_num_partial_prefills} should be larger than or equal to 1"

        assert (
            self.max_long_partial_prefills >= 1
        ), f"max_long_partial_prefills: {self.max_long_partial_prefills} should be larger than or equal to 1"
        assert (self.max_long_partial_prefills <= self.max_num_partial_prefills), \
            f"max_long_partial_prefills: {self.max_long_partial_prefills} should " \
            f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}"

        if not self.cache_config.enable_chunked_prefill:
            assert (
                self.max_num_batched_tokens >= self.max_model_len
            ), f"max_num_batched_tokens: {self.max_num_batched_tokens} " \
                f"should be larger than or equal to max_model_len: {self.max_model_len}"

        if self.max_num_partial_prefills > 1:
            assert (self.cache_config.enable_chunked_prefill is True), \
                "Chunked prefill must be enabled to set max_num_partial_prefills > 1"
            assert (self.long_prefill_token_threshold < self.max_model_len), \
                f"long_prefill_token_threshold: {self.long_prefill_token_threshold} should be less than"\
                f" max_model_len: {self.max_model_len}"

        if self.guided_decoding_backend is not None:
            assert self.guided_decoding_backend in ["xgrammar", "XGrammar", "auto", "off"], \
                f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."

            if self.guided_decoding_backend != "off":
                # TODO: mm support guided_decoding
                assert self.enable_mm is False, "Multimodal model currently do not support guided_decoding"

                # TODO: speculative decoding support guided_decoding

                # TODO: xpu support guided_decoding
                assert not current_platform.is_xpu(
                ), "XPU currently do not support guided_decoding"

                try:
                    pass
                except Exception as e:
                    raise Exception(
                        f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
                    )

        self.scheduler_config.check()

    def print(self, file=None):
        """
        Print or save current configuration.

        print all config

        Args:
            file (Optional[str]): File path to save config (default: None)
            file (str): the path of file to save config
        """
        llm_logger.info(
            "=================== Configuration Information ===============")
@@ -450,7 +749,7 @@ Attributes:
            if k == "generation_config" and v is not None:
                for gck, gcv in v.to_dict().items():
                    llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
            elif k == "cache_config" or k == "model_config" or k == "scheduler_config":
            elif k == "cache_config" or k == "model_config" or k == "scheduler_config" or k == "parallel_config":
                v.print()
            else:
                llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
@@ -464,13 +763,36 @@ Attributes:
                f.write("{:<20}:{:<6}{}\n".format(k, "", v))
            f.close()

    def init_cache_info(self):
        """
        initialize cache info
        """
        disaggregate_info = {}
        if self.splitwise_role != "mixed":
            disaggregate_info["role"] = self.splitwise_role
            disaggregate_info["cache_info"] = dict()
            current_protocol = self.cache_config.cache_transfer_protocol.split(
                ",")
            disaggregate_info["transfer_protocol"] = current_protocol
            for protocol in current_protocol:
                if protocol == "ipc":
                    disaggregate_info["cache_info"][protocol] = {
                        "ip": self.host_ip,
                        "port": self.engine_worker_queue_port,
                        "device_ids": self.local_device_ids
                    }
                elif protocol == "rdma":
                    disaggregate_info["cache_info"][protocol] = {
                        "ip": self.host_ip,
                        "port": self.cache_config.pd_comm_port[0],
                        "rdma_port": self.cache_config.rdma_comm_ports,
                    }
        self.disaggregate_info = disaggregate_info
        llm_logger.info(f"disaggregate_info: {self.disaggregate_info}")
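        # Shape of the resulting disaggregate_info for a non-"mixed" role, with
        # placeholder values only:
        #   {
        #       "role": "prefill",
        #       "transfer_protocol": ["ipc", "rdma"],
        #       "cache_info": {
        #           "ipc": {"ip": host_ip, "port": engine_worker_queue_port,
        #                   "device_ids": local_device_ids},
        #           "rdma": {"ip": host_ip, "port": pd_comm_port[0],
        #                    "rdma_port": rdma_comm_ports},
        #       },
        #   }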

    def read_from_config(self):
        """
        Update configuration from model JSON file.
        Handles special cases:
        - infer_model_block_size -> block_size
        - return_full_hidden_states
        - infer_model_dtype -> cache_dtype
        reset model config from json file
        """

        def reset_value(cls, value_name, key):
@@ -482,9 +804,9 @@ Attributes:
                )

        reset_value(self.cache_config, "block_size", "infer_model_block_size")
        reset_value(self.model_config, "return_full_hidden_states", "return_full_hidden_states")
        reset_value(self.model_config, "return_full_hidden_states",
                    "return_full_hidden_states")
        reset_value(self.cache_config, "cache_dtype", "infer_model_dtype")
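        # The body of reset_value() is truncated in this diff; a plausible
        # sketch (assumption only, not the actual implementation) is:
        #   def reset_value(cls, value_name, key):
        #       if hasattr(self.model_config, key):
        #           setattr(cls, value_name, getattr(self.model_config, key))
        #           llm_logger.info(
        #               f"Reset parameter {value_name} = {getattr(cls, value_name)} from configuration."
        #           )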

    def __str__(self) -> str:
        return json.dumps(self.__dict__, indent=4)