Unify server-side and model-side Config(Part-5) (#3497)

* move config

* fix xpu

* fix

* fix vl

* fix vl

* fix unit test

* fix args

* add unit test

* fix test
YuanRisheng
2025-08-21 19:00:21 +08:00
committed by GitHub
parent e5aa7087db
commit c389a4013c
15 changed files with 480 additions and 499 deletions


@@ -18,17 +18,20 @@ from __future__ import annotations
import json
import os
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum
from typing import Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
import paddle
from paddleformers.transformers.configuration_utils import PretrainedConfig
import fastdeploy
from fastdeploy import envs
from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
from fastdeploy.multimodal.registry import MultimodalRegistry
from fastdeploy.platforms import current_platform
from fastdeploy.utils import check_unified_ckpt, get_logger
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger
logger = get_logger("config", "config.log")
@@ -120,7 +123,6 @@ class ModelConfig:
self.max_model_len = 0
self.dtype = ""
self.enable_logprob = False
self.enable_mm = False
self.enable_redundant_experts = False
self.redundant_experts_num = 0
self.seed = 0
@@ -154,6 +156,12 @@ class ModelConfig:
if ErnieArchitectures.contains_ernie_arch(self.architectures):
self.ori_vocab_size = args.get("ori_vocab_size", self.ori_vocab_size)
architectures = self.architectures[0]
if MultimodalRegistry.contains_model(architectures):
self.enable_mm = True
else:
self.enable_mm = False
self.is_unified_ckpt = check_unified_ckpt(self.model)
self.override_name_from_config()
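The hunk above replaces the user-supplied enable_mm flag with automatic detection: a model counts as multimodal when its first architecture name is registered in MultimodalRegistry. A minimal sketch of that lookup, with a hypothetical registry entry, could look like this:

# Illustrative sketch only; the registered architecture name is an assumption.
class MultimodalRegistry:
    _mm_archs = {"Ernie4_5_VLMoeForConditionalGeneration"}

    @classmethod
    def contains_model(cls, name: str) -> bool:
        return name in cls._mm_archs

architectures = ["Ernie4_5_VLMoeForConditionalGeneration"]
enable_mm = MultimodalRegistry.contains_model(architectures[0])
print(enable_mm)  # True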
@@ -934,19 +942,53 @@ class FDConfig:
simplifies passing around the distinct configurations in the codebase.
"""
model_config: ModelConfig = field(default=None, init=True) # type: ignore
def __init__(
self,
model_config: ModelConfig = None,
cache_config: CacheConfig = None,
parallel_config: ParallelConfig = None,
load_config: LoadConfig = None,
commit_config: CommitConfig = CommitConfig(),
scheduler_config: SchedulerConfig = None,
device_config: DeviceConfig = None,
decoding_config: DecodingConfig = None,
quant_config: QuantConfigBase = None,
graph_opt_config: GraphOptimizationConfig = None,
speculative_config: SpeculativeConfig = None,
tokenizer: str = None,
max_model_len: int = 8192,
max_num_seqs: int = 8,
max_num_batched_tokens: Optional[int] = None,
ips: str = None,
use_warmup: bool = False,
engine_worker_queue_port: int = 8002,
limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
splitwise_role: str = "mixed",
innode_prefill_ports: Optional[List[int]] = None,
max_num_partial_prefills: int = 1,
max_long_partial_prefills: int = 1,
long_prefill_token_threshold: int = 0,
reasoning_parser: str = None,
guided_decoding_backend: Optional[str] = None,
disable_any_whitespace: bool = False,
early_stop_config: Optional[Dict[str, Any]] = None,
tool_parser: str = None,
test_mode=False,
):
self.model_config: ModelConfig = model_config # type: ignore
self.cache_config: CacheConfig = cache_config # type: ignore
self.scheduler_config: SchedulerConfig = scheduler_config # type: ignore
self.parallel_config = parallel_config # type: ignore
self.speculative_config: SpeculativeConfig = speculative_config
self.device_config: DeviceConfig = device_config # type: ignore
self.load_config: LoadConfig = load_config
self.quant_config: Optional[QuantConfigBase] = quant_config
self.graph_opt_config: Optional[GraphOptimizationConfig] = graph_opt_config
self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
self.decoding_config: DecodingConfig = decoding_config # type: ignore
self.cache_config: CacheConfig = cache_config # type: ignore
parallel_config: ParallelConfig = field(default=None, init=True)
speculative_config: SpeculativeConfig = field(default=None, init=True) # type: ignore
device_config: DeviceConfig = field(default=None, init=True) # type: ignore
load_config: LoadConfig = field(default=None, init=True)
quant_config: Optional[QuantConfigBase] = None
graph_opt_config: Optional[GraphOptimizationConfig] = None
early_stop_config: Optional[EarlyStopConfig] = None
decoding_config: DecodingConfig = field(default=None, init=True) # type: ignore
cache_config: CacheConfig = field(default=None, init=True) # type: ignore
def __post_init__(self):
# Initialize cuda graph capture list
if self.graph_opt_config.cudagraph_capture_sizes is None:
self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
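The core of this hunk converts FDConfig from dataclass fields plus __post_init__ into an explicit __init__ that also absorbs the engine-level arguments. A toy sketch of the same migration pattern (class and field names are hypothetical, not FastDeploy code):

from dataclasses import dataclass

# Before: configuration assembled from dataclass fields plus __post_init__.
@dataclass
class OldStyleConfig:
    max_model_len: int = 8192
    max_num_seqs: int = 8

    def __post_init__(self):
        self.max_prefill_batch = 3

# After: an explicit __init__ so server-side arguments live on the same object.
class NewStyleConfig:
    def __init__(self, max_model_len: int = 8192, max_num_seqs: int = 8, test_mode: bool = False):
        self.max_model_len = max_model_len
        self.max_num_seqs = max_num_seqs
        self.max_prefill_batch = 3

print(OldStyleConfig().max_prefill_batch, NewStyleConfig().max_prefill_batch)  # 3 3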
@@ -955,3 +997,278 @@ class FDConfig:
# TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
if self.graph_opt_config.graph_opt_level == 2:
self.graph_opt_config.graph_opt_level = 1
self.tokenizer = tokenizer
self.max_num_batched_tokens = max_num_batched_tokens
self.ips = ips
self.tool_parser = tool_parser
if self.ips is None:
self.master_ip = "0.0.0.0"
elif isinstance(self.ips, list):
self.master_ip = self.ips[0]
else:
self.ips = self.ips.split(",")
self.master_ip = self.ips[0]
if self.ips is None:
self.nnode = 1
self.node_rank = 0
else:
self.nnode = len(self.ips)
for idx, ip in enumerate(self.ips):
if ip == self.master_ip:
self.node_rank = idx
self.max_model_len = max_model_len
self.max_num_seqs = max_num_seqs
self.limit_mm_per_prompt = limit_mm_per_prompt
self.mm_processor_kwargs = mm_processor_kwargs
self.use_warmup = use_warmup
self.splitwise_role = splitwise_role
self.innode_prefill_ports = innode_prefill_ports
self.max_num_partial_prefills = max_num_partial_prefills
self.max_long_partial_prefills = max_long_partial_prefills
self.long_prefill_token_threshold = long_prefill_token_threshold
self.reasoning_parser = reasoning_parser
self.guided_decoding_backend = guided_decoding_backend
self.disable_any_whitespace = disable_any_whitespace
self._str_to_list("innode_prefill_ports", int)
# TODO
self.max_prefill_batch = 3
if current_platform.is_xpu():
self.max_prefill_batch = 1
if self.model_config is not None and self.model_config.enable_mm:
self.max_prefill_batch = 1  # TODO: the multimodal prefill stage currently only supports a parallelism of 1; to be optimized
num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.expert_parallel_size
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
if num_ranks > self.max_chips_per_node:
self.worker_num_per_node = self.max_chips_per_node
nnode = ceil_div(num_ranks, self.worker_num_per_node)
assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
else:
self.worker_num_per_node = num_ranks
self.engine_worker_queue_port = engine_worker_queue_port
self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
if current_platform.is_xpu():
self.device_ids = os.getenv("XPU_VISIBLE_DEVICES", self.device_ids)
self.read_from_config()
self.postprocess()
if test_mode:
return
self.check()
self.print()
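The constructor above derives worker placement from the tensor and expert parallel sizes: ranks are packed onto nodes of at most max_chips_per_node devices, and the implied node count must match the ips list. A standalone sketch of that arithmetic (helper names are illustrative, not the FastDeploy API):

def ceil_div(a, b):
    # assumed to match fastdeploy.utils.ceil_div
    return -(-a // b)

def workers_per_node(tp_size, ep_size, max_chips_per_node=8):
    num_ranks = tp_size * ep_size
    if num_ranks > max_chips_per_node:
        return max_chips_per_node, ceil_div(num_ranks, max_chips_per_node)
    return num_ranks, 1

print(workers_per_node(8, 1))   # (8, 1): one node of 8 ranks
print(workers_per_node(16, 1))  # (8, 2): two nodes of 8 ranks each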
def postprocess(self):
"""
calculate some parameters
"""
assert (
self.device_ids.split(",").__len__() == self.worker_num_per_node
), f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}"
self.local_device_ids = self.device_ids.split(",")[: self.parallel_config.tensor_parallel_size]
self.host_ip = get_host_ip()
if self.ips is None or self.host_ip == self.master_ip:
self.is_master = True
else:
self.is_master = False
if self.parallel_config.tensor_parallel_size <= self.worker_num_per_node:
self.is_master = True
self.paddle_commit_id = paddle.version.commit
if self.max_num_batched_tokens is None:
if self.cache_config.enable_chunked_prefill:
self.max_num_batched_tokens = 2048
else:
if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
self.max_num_batched_tokens = self.max_model_len
else:
self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
self.cache_config.postprocess(self.max_num_batched_tokens, self.max_num_seqs)
self.cache_config.max_block_num_per_seq = int(self.max_model_len // self.cache_config.block_size)
if self.guided_decoding_backend == "auto":
if self.model_config.enable_mm:
self.guided_decoding_backend = "off"
else:
self.guided_decoding_backend = "xgrammar"
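The defaulting of max_num_batched_tokens in postprocess() depends on chunked prefill and the V1 KV-cache scheduler switch. A self-contained sketch of that rule (function name is illustrative):

import os

def default_max_num_batched_tokens(enable_chunked_prefill, max_model_len):
    # Mirrors the defaulting rule in postprocess() above (sketch only).
    if enable_chunked_prefill:
        return 2048
    if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
        return max_model_len
    return 8192  # capped instead of max_model_len to reduce OOM risk

print(default_max_num_batched_tokens(True, 32768))   # 2048
print(default_max_num_batched_tokens(False, 32768))  # 32768 when the env var is unset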
def check(self):
"""
check the legality of config
"""
assert self.max_num_seqs <= 256, (
"The parameter `max_num_seqs` is not allowed to exceed 256, " f"but now it's {self.max_num_seqs}."
)
assert self.nnode >= 1, f"nnode: {self.nnode} should no less than 1"
assert self.max_model_len >= 16, f"max_model_len: {self.max_model_len} should be larger than 16"
assert self.max_num_seqs >= 1, f"max_num_seqs: {self.max_num_seqs} should be larger than 1"
assert self.max_num_batched_tokens >= self.max_num_seqs, (
f"max_num_batched_tokens: {self.max_num_batched_tokens} "
f"should be larger than or equal to max_num_seqs: {self.max_num_seqs}"
)
assert self.max_num_batched_tokens <= self.max_model_len * self.max_num_seqs, (
f"max_num_batched_tokens: {self.max_num_batched_tokens} should be larger"
f"than or equal to max_num_seqs: {self.max_num_seqs} * max_model_len: {self.max_model_len}"
)
assert (
self.max_num_partial_prefills >= 1
), f"max_num_partial_prefills: {self.max_num_partial_prefills} should be larger than or equal to 1"
assert (
self.max_long_partial_prefills >= 1
), f"max_long_partial_prefills: {self.max_long_partial_prefills} should be larger than or equal to 1"
assert self.max_long_partial_prefills <= self.max_num_partial_prefills, (
f"max_long_partial_prefills: {self.max_long_partial_prefills} should "
f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}"
)
assert self.splitwise_role in ["mixed", "prefill", "decode"]
# TODO(@wufeisheng): TP and EP need to be supported simultaneously.
assert (self.parallel_config.tensor_parallel_size == 1 and self.parallel_config.expert_parallel_size >= 1) or (
self.parallel_config.tensor_parallel_size >= 1 and self.parallel_config.expert_parallel_size == 1
), "TP and EP cannot be enabled at the same time"
if not self.cache_config.enable_chunked_prefill:
if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
assert self.max_num_batched_tokens >= self.max_model_len, (
f"max_num_batched_tokens: {self.max_num_batched_tokens} "
f"should be larger than or equal to max_model_len: {self.max_model_len}"
)
else:
assert self.max_num_batched_tokens >= self.cache_config.block_size, (
f"max_num_batched_tokens: {self.max_num_batched_tokens} "
f"should be larger than or equal to block_size: {self.cache_config.block_size}"
)
if self.max_num_partial_prefills > 1:
assert (
self.cache_config.enable_chunked_prefill is True
), "Chunked prefill must be enabled to set max_num_partial_prefills > 1"
assert self.long_prefill_token_threshold < self.max_model_len, (
f"long_prefill_token_threshold: {self.long_prefill_token_threshold} should be less than"
f" max_model_len: {self.max_model_len}"
)
if self.guided_decoding_backend is not None:
assert self.guided_decoding_backend in [
"xgrammar",
"XGrammar",
"auto",
"off",
], f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."
if self.guided_decoding_backend != "off":
# TODO: mm support guided_decoding
assert (
self.model_config.enable_mm is False
), "Multimodal model currently do not support guided_decoding"
# TODO: speculative decoding support guided_decoding
# TODO: xpu support guided_decoding
assert not current_platform.is_xpu(), "XPU currently do not support guided_decoding"
try:
import xgrammar # noqa
except Exception as e:
raise Exception(
f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
)
if self.scheduler_config is not None:
self.scheduler_config.check()
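Among the checks above, the TP/EP exclusivity rule is the easiest to trip over; as a standalone sketch (function name is illustrative):

def validate_parallel(tp_size, ep_size):
    # Same rule as the assertion in check(): only one of TP or EP may exceed 1.
    assert (tp_size == 1 and ep_size >= 1) or (tp_size >= 1 and ep_size == 1), (
        "TP and EP cannot be enabled at the same time"
    )

validate_parallel(8, 1)   # pure tensor parallelism: ok
validate_parallel(1, 8)   # pure expert parallelism: ok
# validate_parallel(2, 2)  # would raise AssertionError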
def print(self):
"""
print all config
"""
logger.info("=================== Configuration Information ===============")
for k, v in self.__dict__.items():
if k == "generation_config" and v is not None:
for gck, gcv in v.to_dict().items():
logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
elif (
k == "cache_config"
or k == "model_config"
or k == "scheduler_config"
or k == "parallel_config"
or k == "commit_config"
):
if v is not None:
v.print()
else:
logger.info("{:<20}:{:<6}{}".format(k, "", v))
logger.info("=============================================================")
def init_cache_info(self):
"""
initialize cache info
"""
disaggregate_info = {}
if self.splitwise_role != "mixed":
disaggregate_info["role"] = self.splitwise_role
disaggregate_info["cache_info"] = dict()
current_protocol = self.cache_config.cache_transfer_protocol.split(",")
disaggregate_info["transfer_protocol"] = current_protocol
for protocol in current_protocol:
if protocol == "ipc":
disaggregate_info["cache_info"][protocol] = {
"ip": self.host_ip,
"port": self.engine_worker_queue_port,
"device_ids": self.local_device_ids,
}
elif protocol == "rdma":
disaggregate_info["cache_info"][protocol] = {
"ip": self.host_ip,
"port": self.cache_config.pd_comm_port[0],
"rdma_port": self.cache_config.rdma_comm_ports,
}
self.disaggregate_info = disaggregate_info
logger.info(f"disaggregate_info: {self.disaggregate_info}")
def read_from_config(self):
"""
reset model config from json file
"""
def reset_value(cls, value_name, key):
if hasattr(cls, key):
value = getattr(cls, key)
setattr(cls, value_name, value)
logger.info(f"Reset parameter {value_name} = {value} from configuration.")
reset_value(self.cache_config, "block_size", "infer_model_block_size")
reset_value(
self.model_config,
"return_full_hidden_states",
"return_full_hidden_states",
)
reset_value(self.cache_config, "cache_dtype", "infer_model_dtype")
def _check_master(self):
return self.is_master
def _str_to_list(self, attr_name, default_type):
if hasattr(self, attr_name):
val = getattr(self, attr_name)
if type(val) is str:
setattr(self, attr_name, [default_type(i) for i in val.split(",")])
else:
setattr(self, attr_name, val)
def __str__(self) -> str:
return json.dumps(self.__dict__, indent=4)
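The _str_to_list helper normalizes comma-separated CLI strings such as innode_prefill_ports into typed lists; a standalone sketch of the same behaviour (function name is illustrative):

def str_to_list(val, cast):
    # "8001,8002" -> [8001, 8002]; anything that is not a string passes through.
    return [cast(i) for i in val.split(",")] if isinstance(val, str) else val

print(str_to_list("8001,8002", int))   # [8001, 8002]
print(str_to_list([8001, 8002], int))  # [8001, 8002]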