# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional

from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import (check_unified_ckpt, get_host_ip,
                              is_port_available, llm_logger)

TaskOption = Literal["generate"]


class ModelConfig:
    """
    Configuration class for model settings and parameters.

    Attributes:
        model_dir (str): Path to the model directory
        is_unified_ckpt (bool): Whether the checkpoint uses unified format
        model_name_or_path (str): Model identifier or path
        dynamic_load_weight (int): Dynamic weight loading flag
    """

    def __init__(self,
                 model_name_or_path: str,
                 config_json_file: str = "config.json",
                 dynamic_load_weight: int = 0,
                 download_dir: Optional[str] = None):
        """
        Initialize model configuration.

        Args:
            model_name_or_path (str): Model identifier or path
            config_json_file (str): Model config file name (default: 'config.json')
            dynamic_load_weight (int): Dynamic weight loading mode (default: 0)
            download_dir (Optional[str]): Directory for downloaded models (default: None)
        """
        self.model_dir = model_name_or_path
        self.is_unified_ckpt = check_unified_ckpt(self.model_dir)
        self.dynamic_load_weight = dynamic_load_weight

        config_file = os.path.join(model_name_or_path, config_json_file)
        if os.path.isfile(model_name_or_path):
            try:
                from paddlenlp.transformers import AutoConfig
                config = AutoConfig.from_pretrained(model_name_or_path)
                config_dict = {
                    k: v
                    for k, v in vars(config).items() if not k.startswith('_')
                }
                for key, value in config_dict.items():
                    setattr(self, key, value)
            except Exception:
                llm_logger.error(
                    "The current model is not supported; you can use `paddlenlp` to register your model."
                )
                raise ValueError(
                    "The current model is not supported; you can use `paddlenlp` to register your model."
                )
        else:
            with open(config_file, "r", encoding="utf-8") as f:
                config_dict = json.load(f)
            for key, value in config_dict.items():
                try:
                    setattr(self, key, value)
                except Exception:
                    continue

        if isinstance(self.architectures, list):
            self.architectures = self.architectures[0]
        self.model_name_or_path = model_name_or_path
        self.override_name_from_config()
        self.read_from_env()

    def override_name_from_config(self):
        """
        Update attribute names from model configuration.
        Handles special cases like:
        - Renaming infer_model_mp_num to tensor_parallel_size
        - Adjusting num_hidden_layers based on remove_tail_layer
        - Setting default mla_use_absorb value
        """
        if not self.is_unified_ckpt and hasattr(self, "infer_model_mp_num"):
            self.tensor_parallel_size = self.infer_model_mp_num
            del self.infer_model_mp_num

        if hasattr(self, "num_hidden_layers"):
            if hasattr(self, "remove_tail_layer"):
                if self.remove_tail_layer is True:
                    self.num_hidden_layers -= 1
                elif isinstance(self.remove_tail_layer, int):
                    self.num_hidden_layers -= self.remove_tail_layer

            self.num_layers = self.num_hidden_layers
            del self.num_hidden_layers

        if not hasattr(self, "mla_use_absorb"):
            self.mla_use_absorb = False

    def read_from_env(self):
        """
        Load configuration from environment variables.
        Sets default values if the env vars are not found.
        Reads:
        - MAX_STOP_SEQS_NUM (default: 5)
        - STOP_SEQS_MAX_LEN (default: 8)
        - ELLM_DYNAMIC_QUANT_TYPE (default: 'default')
        - ELLM_DYNAMIC_USE_STOP_SEQS (default: 0)
        - COMPRESSION_RATIO (default: 1.0)
        - ROPE_THETA (default: 10000)
        """
        self.max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", "5"))
        self.stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", "8"))

        self.ellm_dynamic_quant_type = os.getenv("ELLM_DYNAMIC_QUANT_TYPE",
                                                 "default")
        # Whether to use stop sequences in dynamic graph inference
        self.ellm_dynamic_use_stop_seqs = int(
            os.getenv("ELLM_DYNAMIC_USE_STOP_SEQS", "0")) == 1

        def reset_config_value(key, value):
            if not hasattr(self, key.lower()):
                if os.getenv(key, None):
                    # Values taken from the environment are parsed with eval,
                    # so "1.0" becomes the float 1.0 and "10000" the int 10000.
                    value = eval(os.getenv(key))
                    llm_logger.info(
                        f"Get parameter `{key}` = {value} from environment.")
                else:
                    llm_logger.info(
                        f"Parameter `{key}` will use default value {value}.")
                setattr(self, key.lower(), value)

        reset_config_value("COMPRESSION_RATIO", 1.0)
        reset_config_value("ROPE_THETA", 10000)
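
    # Example of the environment-driven overrides above (illustrative values,
    # not from the original file): running with `ROPE_THETA=1000000` and
    # `COMPRESSION_RATIO=0.8` set yields `self.rope_theta == 1000000` and
    # `self.compression_ratio == 0.8`; unset variables fall back to the
    # defaults passed to reset_config_value.
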
    def _get_download_model(self, model_name, model_type="default"):
        # TODO: Implement dynamic graph for self-downloading models
        pass

    def print(self):
        """
        Print current model configuration.
        Logs all attributes and their values.
        """
        llm_logger.info("Model Configuration Information :")
        for k, v in self.__dict__.items():
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
        llm_logger.info(
            "=============================================================")


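# Illustrative usage sketch (not part of the original file): constructing a
# ModelConfig from a hypothetical local checkpoint directory that contains a
# config.json, then dumping the resulting attributes to the log.
#
#     model_cfg = ModelConfig("/path/to/ernie_checkpoint")  # hypothetical path
#     model_cfg.print()
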
class CacheConfig:
    """
    Configuration for key-value cache management.

    Attributes:
        block_size (int): Tokens per cache block
        gpu_memory_utilization (float): GPU memory usage fraction (0-1)
        cache_dtype (str): Data type for cache (default: 'bfloat16')
        num_gpu_blocks_override (Optional[int]): Manual GPU blocks override
        kv_cache_ratio (float): Fraction of total blocks usable for prefill (default: 0.75)
        enc_dec_block_num (int): Encoder-decoder blocks count
        enable_prefix_caching (bool): Prefix caching enable flag
        total_block_num (int): Total available blocks
        prefill_kvcache_block_num (int): Blocks allocated for prefill
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cache_dtype: str = "bfloat16",
        num_gpu_blocks_override: Optional[int] = None,
        kv_cache_ratio: float = 0.75,
        enc_dec_block_num: int = 2,
        enable_prefix_caching: bool = False,
    ):
        """
        Initialize cache configuration.

        Args:
            block_size (int): Tokens per cache block
            gpu_memory_utilization (float): GPU memory usage target (0-1)
            cache_dtype (str): Cache data type (default: 'bfloat16')
            num_gpu_blocks_override (Optional[int]): Manual GPU blocks setting
            kv_cache_ratio (float): Fraction of total blocks usable for prefill (default: 0.75)
            enc_dec_block_num (int): Encoder-decoder blocks count
            enable_prefix_caching (bool): Enable prefix sharing
        """
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.num_gpu_blocks_override = num_gpu_blocks_override
        self.kv_cache_ratio = kv_cache_ratio
        self.enc_dec_block_num = enc_dec_block_num
        self.cache_dtype = cache_dtype
        self.enable_prefix_caching = enable_prefix_caching
        self._verify_args()

    def metrics_info(self):
        """
        Convert config to metrics dictionary.

        Returns:
            Dict[str, str]: Key-value pairs of all config attributes
        """
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self):
        """Validate configuration arguments."""
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization cannot exceed 1.0. Got "
                f"{self.gpu_memory_utilization}.")
        if self.kv_cache_ratio > 1.0:
            raise ValueError("KV cache ratio cannot exceed 1.0. Got "
                             f"{self.kv_cache_ratio}.")

    def postprocess(self, num_total_tokens, number_of_tasks):
        """
        Calculate block allocation based on tokens and tasks.

        Args:
            num_total_tokens (int): Total tokens to process
            number_of_tasks (int): Number of parallel tasks

        Sets:
            dec_token_num (int): Tokens reserved for decoding (enc_dec_block_num * block_size)
            total_block_num (int): Total blocks needed
            prefill_kvcache_block_num (int): Blocks for prefill phase
        """
        self.dec_token_num = self.enc_dec_block_num * self.block_size
        if self.num_gpu_blocks_override is not None:
            self.total_block_num = self.num_gpu_blocks_override
            self.prefill_kvcache_block_num = int(self.total_block_num *
                                                 self.kv_cache_ratio)
        else:
            length = num_total_tokens // number_of_tasks
            block_num = (length + self.block_size - 1 +
                         self.enc_dec_block_num) // self.block_size
            self.total_block_num = block_num * number_of_tasks
            self.prefill_kvcache_block_num = self.total_block_num
        llm_logger.info(
            f"Doing profile, the total_block_num:{self.total_block_num}")
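
    # Worked example for the profile branch above (illustrative numbers, no
    # override set): block_size=64, enc_dec_block_num=2, num_total_tokens=8192,
    # number_of_tasks=8 gives
    #   dec_token_num   = 2 * 64 = 128
    #   length          = 8192 // 8 = 1024
    #   block_num       = (1024 + 64 - 1 + 2) // 64 = 17
    #   total_block_num = 17 * 8 = 136, prefill_kvcache_block_num = 136
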
    def reset(self, num_gpu_blocks):
        """
        Reset GPU block allocation.

        Args:
            num_gpu_blocks (int): New total blocks count

        Updates:
            total_block_num (int)
            prefill_kvcache_block_num (int)
        """
        self.total_block_num = num_gpu_blocks
        self.prefill_kvcache_block_num = int(self.total_block_num *
                                             self.kv_cache_ratio)
        llm_logger.info(
            (f"Reset block num, the total_block_num:{self.total_block_num},"
             f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"))

    def print(self):
        """Print current cache configuration."""
        llm_logger.info("Cache Configuration Information :")
        for k, v in self.__dict__.items():
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
        llm_logger.info(
            "=============================================================")


class Config:
    """
    Main engine configuration class combining all components.

    Attributes:
        model_config (ModelConfig): Model settings
        cache_config (CacheConfig): Cache management settings
        scheduler_config (SchedulerConfig): Task scheduling settings
        model_name_or_path (str): Model identifier/path
        tokenizer (str): Tokenizer identifier
        tensor_parallel_size (int): Parallelism degree (default: 8)
        nnode (int): Node count (default: 1)
        max_model_len (int): Max sequence length (default: 8192)
        max_num_seqs (int): Max concurrent sequences (default: 8)
        max_num_batched_tokens (Optional[int]): Max batched tokens
        pod_ips (Optional[List[str]]): Cluster node IPs
        mm_processor_kwargs (Optional[Dict]): Multi-modal processor args
        speculative_config (Optional[Dict]): Speculative execution settings
        use_warmup (bool): Warmup enable flag
        enable_mm (bool): Multi-modal enable flag
        enable_chunked_prefill (bool): Chunked prefill enable flag
        device_ids (str): GPU device IDs
        tp_num_per_node (int): Tensor parallelism per node
        host_ip (str): Current host IP
        paddle_commit_id (str): PaddlePaddle version
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        scheduler_config: SchedulerConfig,
        model_name_or_path: str = None,
        tokenizer: str = None,
        tensor_parallel_size: int = 8,
        nnode: int = 1,
        max_model_len: int = 8192,
        max_num_seqs: int = 8,
        max_num_batched_tokens: Optional[int] = None,
        pod_ips: Optional[List[str]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        speculative_config: Optional[Dict[str, Any]] = None,
        use_warmup: bool = False,
        engine_worker_queue_port: int = 8002,
        enable_mm: bool = False,
        enable_chunked_prefill: bool = False,
    ):
        """
        Initialize engine configuration.

        Args:
            model_config (ModelConfig): Model settings
            cache_config (CacheConfig): Cache settings
            scheduler_config (SchedulerConfig): Scheduler settings
            model_name_or_path (str): Model identifier (default: None)
            tokenizer (str): Tokenizer identifier (default: None)
            tensor_parallel_size (int): Parallelism degree (default: 8)
            nnode (int): Node count (default: 1)
            max_model_len (int): Max sequence length (default: 8192)
            max_num_seqs (int): Max concurrent sequences (default: 8)
            max_num_batched_tokens (Optional[int]): Max batched tokens (default: None)
            pod_ips (Optional[List[str]]): Cluster node IPs (default: None)
            mm_processor_kwargs (Optional[Dict]): Multi-modal args (default: None)
            speculative_config (Optional[Dict]): Speculative settings (default: None)
            use_warmup (bool): Warmup flag (default: False)
            engine_worker_queue_port (int): Worker queue port (default: 8002)
            enable_mm (bool): Multi-modal flag (default: False)
            enable_chunked_prefill (bool): Chunked prefill flag (default: False)
        """
        self.model_config = model_config
        self.cache_config = cache_config
        self.scheduler_config = scheduler_config
        self.model_name_or_path = model_name_or_path
        self.tokenizer = tokenizer
        self.max_num_batched_tokens = max_num_batched_tokens
        self.tensor_parallel_size = tensor_parallel_size
        self.nnode = nnode
        self.pod_ips = pod_ips
        self.max_model_len = max_model_len
        self.max_num_seqs = max_num_seqs
        self.mm_processor_kwargs = mm_processor_kwargs
        self.enable_mm = enable_mm
        self.speculative_config = speculative_config
        self.use_warmup = use_warmup
        self.enable_chunked_prefill = enable_chunked_prefill

        # TODO
        self.max_prefill_batch = 3
        if enable_mm:
            # TODO: Currently multi-modal prefill only supports parallelism=1
            # (needs optimization).
            self.max_prefill_batch = 1

        self.engine_worker_queue_port = engine_worker_queue_port
        self.device_ids = ",".join(
            [str(i) for i in range(self.tensor_parallel_size)])
        self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)

        self.read_from_config()
        self.postprocess()
        self.check()
        self.print()

    def postprocess(self):
        """
        Calculate derived parameters:
        - Validates GPU device count matches tensor_parallel_size
        - Computes tensor parallelism per node
        - Gets host IP and Paddle version
        - Sets default max_num_batched_tokens if not provided
        - Initializes cache configuration
        """
        if len(self.device_ids.split(',')) > self.tensor_parallel_size:
            self.device_ids = ",".join(
                self.device_ids.split(',')[:self.tensor_parallel_size])
        assert len(
            self.device_ids.split(',')
        ) == self.tensor_parallel_size, f"The number of available GPUs is {len(self.device_ids.split(','))}, which is less than the required tensor parallel size {self.tensor_parallel_size}."

        assert self.tensor_parallel_size % self.nnode == 0, f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by nnode: {self.nnode}"
        self.tp_num_per_node = self.tensor_parallel_size // self.nnode
        self.host_ip = get_host_ip()

        import paddle
        self.paddle_commit_id = paddle.version.commit

        if self.max_num_batched_tokens is None:
            if self.enable_chunked_prefill:
                self.max_num_batched_tokens = 2048
            else:
                self.max_num_batched_tokens = self.max_model_len
        self.cache_config.postprocess(self.max_num_batched_tokens,
                                      self.max_num_seqs)
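
    # Worked example (illustrative): with tensor_parallel_size=2, nnode=1 and
    # CUDA_VISIBLE_DEVICES="0,1,2,3", device_ids is truncated to "0,1" and
    # tp_num_per_node becomes 2. With enable_chunked_prefill=False and
    # max_num_batched_tokens left unset, it defaults to max_model_len before
    # being handed to cache_config.postprocess together with max_num_seqs.
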
    def check(self):
        """
        Validate configuration values:
        - max_num_seqs <= 256
        - engine_worker_queue_port available
        - 1 <= tensor_parallel_size <= 8
        - nnode >= 1
        - max_model_len >= 16
        - max_num_seqs >= 1
        - Validates scheduler configuration
        """
        assert (
            self.max_num_seqs <= 256
        ), "The parameter `max_num_seqs` is not allowed to exceed 256, " "but it is currently {}.".format(
            self.max_num_seqs)
        assert (
            is_port_available('0.0.0.0', self.engine_worker_queue_port)
        ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
        assert (
            8 >= self.tensor_parallel_size > 0
        ), f"tensor_parallel_size: {self.tensor_parallel_size} should be between 1 and 8"
        assert (self.nnode >= 1), f"nnode: {self.nnode} should be no less than 1"
        assert (
            self.max_model_len >= 16
        ), f"max_model_len: {self.max_model_len} should be no less than 16"
        assert (
            self.max_num_seqs >= 1
        ), f"max_num_seqs: {self.max_num_seqs} should be no less than 1"

        self.scheduler_config.check()

    def print(self, file=None):
        """
        Print or save current configuration.

        Args:
            file (Optional[str]): File path to save config (default: None)
        """
        llm_logger.info(
            "=================== Configuration Information ===============")
        for k, v in self.__dict__.items():
            if k == "generation_config" and v is not None:
                for gck, gcv in v.to_dict().items():
                    llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
            elif k == "cache_config" or k == "model_config" or k == "scheduler_config":
                v.print()
            else:
                llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
        llm_logger.info(
            "=============================================================")
        if file is not None:
            with open(file, "a") as f:
                now_time = datetime.now()
                f.write(f"{now_time} configuration information as below,\n")
                for k, v in self.__dict__.items():
                    f.write("{:<20}:{:<6}{}\n".format(k, "", v))

    def read_from_config(self):
        """
        Update configuration from model JSON file.
        Handles special cases:
        - infer_model_block_size -> block_size
        - return_full_hidden_states
        - infer_model_dtype -> cache_dtype
        """

        def reset_value(cls, value_name, key):
            if hasattr(cls, key):
                value = getattr(cls, key)
                setattr(cls, value_name, value)
                llm_logger.info(
                    f"Reset parameter {value_name} = {value} from configuration."
                )

        reset_value(self.cache_config, "block_size", "infer_model_block_size")
        reset_value(self.model_config, "return_full_hidden_states",
                    "return_full_hidden_states")
        reset_value(self.cache_config, "cache_dtype", "infer_model_dtype")
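
    # Example of the helper above (illustrative): reset_value(self.cache_config,
    # "block_size", "infer_model_block_size") copies
    # cache_config.infer_model_block_size into cache_config.block_size when that
    # source attribute exists and logs the reset; otherwise block_size is left
    # untouched.
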
    def __str__(self) -> str:
        # default=str keeps nested, non-JSON-serializable objects such as
        # ModelConfig and CacheConfig from raising a TypeError when dumped.
        return json.dumps(self.__dict__, indent=4, default=str)
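

# Minimal smoke-test sketch (illustrative, not part of the original module):
# it exercises only CacheConfig, whose constructor and helpers are fully
# defined above, with made-up sizing values.
if __name__ == "__main__":
    demo_cache_cfg = CacheConfig(block_size=64, gpu_memory_utilization=0.9)
    # Profile-style allocation: 8192 total tokens spread over 8 tasks.
    demo_cache_cfg.postprocess(num_total_tokens=8192, number_of_tasks=8)
    # Shrink the pool as an engine would after measuring real GPU memory.
    demo_cache_cfg.reset(num_gpu_blocks=128)
    demo_cache_cfg.print()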