Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-01 23:02:36 +08:00
Unify server-side and model-side Config(Part-5) (#3497)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
* move config
* fix xpu
* fix
* fix vl
* fix vl
* fix unitest
* fix args
* add unitest
* fix test
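With this part of the series, engine-side arguments (max_model_len, max_num_seqs, splitwise_role, ips, and so on) become plain keyword parameters of FDConfig.__init__ instead of dataclass fields, as the diff below shows. A minimal, self-contained sketch of the resulting pattern; the Toy* names here are illustrative stand-ins, not the FastDeploy API:

from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyParallelConfig:  # stand-in for ParallelConfig
    tensor_parallel_size: int = 1
    expert_parallel_size: int = 1


class ToyUnifiedConfig:  # stand-in for the unified FDConfig
    def __init__(
        self,
        parallel_config: Optional[ToyParallelConfig] = None,
        max_model_len: int = 8192,
        max_num_seqs: int = 8,
        splitwise_role: str = "mixed",
    ):
        # server-side knobs and model-side sub-configs meet in one constructor
        self.parallel_config = parallel_config or ToyParallelConfig()
        self.max_model_len = max_model_len
        self.max_num_seqs = max_num_seqs
        self.splitwise_role = splitwise_role


cfg = ToyUnifiedConfig(parallel_config=ToyParallelConfig(tensor_parallel_size=2), max_num_seqs=16)
print(cfg.max_model_len, cfg.parallel_config.tensor_parallel_size)  # 8192 2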
@@ -18,17 +18,20 @@ from __future__ import annotations

import json
import os
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum
from typing import Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

import paddle
from paddleformers.transformers.configuration_utils import PretrainedConfig

import fastdeploy
from fastdeploy import envs
from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
from fastdeploy.multimodal.registry import MultimodalRegistry
from fastdeploy.platforms import current_platform
from fastdeploy.utils import check_unified_ckpt, get_logger
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger

logger = get_logger("config", "config.log")

@@ -120,7 +123,6 @@ class ModelConfig:
        self.max_model_len = 0
        self.dtype = ""
        self.enable_logprob = False
        self.enable_mm = False
        self.enable_redundant_experts = False
        self.redundant_experts_num = 0
        self.seed = 0

@@ -154,6 +156,12 @@ class ModelConfig:
        if ErnieArchitectures.contains_ernie_arch(self.architectures):
            self.ori_vocab_size = args.get("ori_vocab_size", self.ori_vocab_size)

        architectures = self.architectures[0]
        if MultimodalRegistry.contains_model(architectures):
            self.enable_mm = True
        else:
            self.enable_mm = False

        self.is_unified_ckpt = check_unified_ckpt(self.model)

        self.override_name_from_config()

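The two ModelConfig hunks above work together: the hard-coded `self.enable_mm = False` default is dropped, and the flag is instead derived from whether the first architecture name is registered as multimodal. A rough standalone illustration of that lookup; the registry contents and architecture names here are made up, FastDeploy's MultimodalRegistry holds the real set:

MULTIMODAL_ARCHS = {"ToyVLForConditionalGeneration"}  # hypothetical registry contents

def derive_enable_mm(architectures):
    # mirrors the diff: only the first architecture entry is consulted
    return architectures[0] in MULTIMODAL_ARCHS

print(derive_enable_mm(["ToyVLForConditionalGeneration"]))  # True
print(derive_enable_mm(["ToyForCausalLM"]))                 # False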
@@ -934,19 +942,53 @@ class FDConfig:
    simplifies passing around the distinct configurations in the codebase.
    """

    model_config: ModelConfig = field(default=None, init=True) # type: ignore
    def __init__(
        self,
        model_config: ModelConfig = None,
        cache_config: CacheConfig = None,
        parallel_config: ParallelConfig = None,
        load_config: LoadConfig = None,
        commit_config: CommitConfig = CommitConfig(),
        scheduler_config: SchedulerConfig = None,
        device_config: DeviceConfig = None,
        decoding_config: DecodingConfig = None,
        quant_config: QuantConfigBase = None,
        graph_opt_config: GraphOptimizationConfig = None,
        speculative_config: SpeculativeConfig = None,
        tokenizer: str = None,
        max_model_len: int = 8192,
        max_num_seqs: int = 8,
        max_num_batched_tokens: Optional[int] = None,
        ips: str = None,
        use_warmup: bool = False,
        engine_worker_queue_port: int = 8002,
        limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        splitwise_role: str = "mixed",
        innode_prefill_ports: Optional[List[int]] = None,
        max_num_partial_prefills: int = 1,
        max_long_partial_prefills: int = 1,
        long_prefill_token_threshold: int = 0,
        reasoning_parser: str = None,
        guided_decoding_backend: Optional[str] = None,
        disable_any_whitespace: bool = False,
        early_stop_config: Optional[Dict[str, Any]] = None,
        tool_parser: str = None,
        test_mode=False,
    ):
        self.model_config: ModelConfig = model_config # type: ignore
        self.cache_config: CacheConfig = cache_config # type: ignore
        self.scheduler_config: SchedulerConfig = scheduler_config # type: ignore
        self.parallel_config = parallel_config # type: ignore
        self.speculative_config: SpeculativeConfig = speculative_config
        self.device_config: DeviceConfig = device_config # type: ignore
        self.load_config: LoadConfig = load_config
        self.quant_config: Optional[QuantConfigBase] = quant_config
        self.graph_opt_config: Optional[GraphOptimizationConfig] = graph_opt_config
        self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
        self.decoding_config: DecodingConfig = decoding_config # type: ignore
        self.cache_config: CacheConfig = cache_config # type: ignore

    parallel_config: ParallelConfig = field(default=None, init=True)
    speculative_config: SpeculativeConfig = field(default=None, init=True) # type: ignore
    device_config: DeviceConfig = field(default=None, init=True) # type: ignore
    load_config: LoadConfig = field(default=None, init=True)
    quant_config: Optional[QuantConfigBase] = None
    graph_opt_config: Optional[GraphOptimizationConfig] = None
    early_stop_config: Optional[EarlyStopConfig] = None
    decoding_config: DecodingConfig = field(default=None, init=True) # type: ignore
    cache_config: CacheConfig = field(default=None, init=True) # type: ignore

    def __post_init__(self):
        # Initialize cuda graph capture list
        if self.graph_opt_config.cudagraph_capture_sizes is None:
            self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)

@@ -955,3 +997,278 @@ class FDConfig:
        # TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
        if self.graph_opt_config.graph_opt_level == 2:
            self.graph_opt_config.graph_opt_level = 1

        self.tokenizer = tokenizer
        self.max_num_batched_tokens = max_num_batched_tokens
        self.ips = ips
        self.tool_parser = tool_parser

        if self.ips is None:
            self.master_ip = "0.0.0.0"
        elif isinstance(self.ips, list):
            self.master_ip = self.ips[0]
        else:
            self.ips = self.ips.split(",")
            self.master_ip = self.ips[0]

        if self.ips is None:
            self.nnode = 1
            self.node_rank = 0
        else:
            self.nnode = len(self.ips)

            for idx, ip in enumerate(self.ips):
                if ip == self.master_ip:
                    self.node_rank = idx

        self.max_model_len = max_model_len
        self.max_num_seqs = max_num_seqs
        self.limit_mm_per_prompt = limit_mm_per_prompt
        self.mm_processor_kwargs = mm_processor_kwargs
        self.use_warmup = use_warmup
        self.splitwise_role = splitwise_role
        self.innode_prefill_ports = innode_prefill_ports
        self.max_num_partial_prefills = max_num_partial_prefills
        self.max_long_partial_prefills = max_long_partial_prefills
        self.long_prefill_token_threshold = long_prefill_token_threshold
        self.reasoning_parser = reasoning_parser
        self.guided_decoding_backend = guided_decoding_backend
        self.disable_any_whitespace = disable_any_whitespace
        self._str_to_list("innode_prefill_ports", int)

        # TODO
        self.max_prefill_batch = 3
        if current_platform.is_xpu():
            self.max_prefill_batch = 1
        if self.model_config is not None and self.model_config.enable_mm:
            self.max_prefill_batch = 1  # TODO: the multimodal prefill stage currently only supports a parallelism of 1; to be optimized

        num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.expert_parallel_size
        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
        if num_ranks > self.max_chips_per_node:
            self.worker_num_per_node = self.max_chips_per_node
            nnode = ceil_div(num_ranks, self.worker_num_per_node)
            assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
        else:
            self.worker_num_per_node = num_ranks

        self.engine_worker_queue_port = engine_worker_queue_port
        self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
        self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
        if current_platform.is_xpu():
            self.device_ids = os.getenv("XPU_VISIBLE_DEVICES", self.device_ids)

        self.read_from_config()
        self.postprocess()
        if test_mode:
            return
        self.check()
        self.print()

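The multi-node bookkeeping above is plain string and integer arithmetic. Below is a standalone re-statement of that logic (not the FastDeploy API): split a comma-separated ips string, take the first entry as the master, and cap workers per node at the chip count, with the IP addresses and 8-chips-per-node assumption chosen only for illustration.

def ceil_div(a, b):
    return -(-a // b)

def resolve_topology(ips, tensor_parallel_size, expert_parallel_size, max_chips_per_node=8):
    # mirrors the __init__ logic above
    if isinstance(ips, str):
        ips = ips.split(",")
    master_ip = "0.0.0.0" if ips is None else ips[0]
    nnode = 1 if ips is None else len(ips)

    num_ranks = tensor_parallel_size * expert_parallel_size
    if num_ranks > max_chips_per_node:
        worker_num_per_node = max_chips_per_node
        assert ceil_div(num_ranks, worker_num_per_node) == nnode
    else:
        worker_num_per_node = num_ranks
    return master_ip, nnode, worker_num_per_node

# TP=16 across "10.0.0.1,10.0.0.2": 8 workers per node, first ip is the master
print(resolve_topology("10.0.0.1,10.0.0.2", 16, 1))  # ('10.0.0.1', 2, 8)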
    def postprocess(self):
        """
        calculate some parameters
        """
        assert (
            self.device_ids.split(",").__len__() == self.worker_num_per_node
        ), f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}"

        self.local_device_ids = self.device_ids.split(",")[: self.parallel_config.tensor_parallel_size]

        self.host_ip = get_host_ip()

        if self.ips is None or self.host_ip == self.master_ip:
            self.is_master = True
        else:
            self.is_master = False

        if self.parallel_config.tensor_parallel_size <= self.worker_num_per_node:
            self.is_master = True

        self.paddle_commit_id = paddle.version.commit

        if self.max_num_batched_tokens is None:
            if self.cache_config.enable_chunked_prefill:
                self.max_num_batched_tokens = 2048
            else:
                if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                    self.max_num_batched_tokens = self.max_model_len
                else:
                    self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM

        if self.long_prefill_token_threshold == 0:
            self.long_prefill_token_threshold = int(self.max_model_len * 0.04)

        self.cache_config.postprocess(self.max_num_batched_tokens, self.max_num_seqs)
        self.cache_config.max_block_num_per_seq = int(self.max_model_len // self.cache_config.block_size)

        if self.guided_decoding_backend == "auto":
            if self.model_config.enable_mm:
                self.guided_decoding_backend = "off"
            else:
                self.guided_decoding_backend = "xgrammar"

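The defaulting in postprocess is simple arithmetic: with chunked prefill disabled, the V1 KV-cache scheduler off, and max_model_len = 32768, max_num_batched_tokens falls back to 32768, and an unset long_prefill_token_threshold becomes int(32768 * 0.04) = 1310. A small standalone check of those two rules, again a re-statement of the branch above rather than the FastDeploy API:

def default_batched_tokens(enable_chunked_prefill, use_v1_kvcache_scheduler, max_model_len):
    # chunked prefill caps the budget at 2048; otherwise fall back to
    # max_model_len, or 8192 under the V1 scheduler
    if enable_chunked_prefill:
        return 2048
    return 8192 if use_v1_kvcache_scheduler else max_model_len

def default_long_prefill_threshold(max_model_len):
    return int(max_model_len * 0.04)

print(default_batched_tokens(False, False, 32768))  # 32768
print(default_batched_tokens(True, False, 32768))   # 2048
print(default_long_prefill_threshold(32768))         # 1310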
    def check(self):
        """
        check the legality of config
        """
        assert self.max_num_seqs <= 256, (
            "The parameter `max_num_seqs` is not allowed to exceed 256, " f"but now it's {self.max_num_seqs}."
        )
        assert self.nnode >= 1, f"nnode: {self.nnode} should no less than 1"
        assert self.max_model_len >= 16, f"max_model_len: {self.max_model_len} should be larger than 16"
        assert self.max_num_seqs >= 1, f"max_num_seqs: {self.max_num_seqs} should be larger than 1"
        assert self.max_num_batched_tokens >= self.max_num_seqs, (
            f"max_num_batched_tokens: {self.max_num_batched_tokens} "
            f"should be larger than or equal to max_num_seqs: {self.max_num_seqs}"
        )
        assert self.max_num_batched_tokens <= self.max_model_len * self.max_num_seqs, (
            f"max_num_batched_tokens: {self.max_num_batched_tokens} should be larger"
            f"than or equal to max_num_seqs: {self.max_num_seqs} * max_model_len: {self.max_model_len}"
        )
        assert (
            self.max_num_partial_prefills >= 1
        ), f"max_num_partial_prefills: {self.max_num_partial_prefills} should be larger than or equal to 1"

        assert (
            self.max_long_partial_prefills >= 1
        ), f"max_long_partial_prefills: {self.max_long_partial_prefills} should be larger than or equal to 1"
        assert self.max_long_partial_prefills <= self.max_num_partial_prefills, (
            f"max_long_partial_prefills: {self.max_long_partial_prefills} should "
            f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}"
        )
        assert self.splitwise_role in ["mixed", "prefill", "decode"]
        # TODO(@wufeisheng): TP and EP need to be supported simultaneously.
        assert (self.parallel_config.tensor_parallel_size == 1 and self.parallel_config.expert_parallel_size >= 1) or (
            self.parallel_config.tensor_parallel_size >= 1 and self.parallel_config.expert_parallel_size == 1
        ), "TP and EP cannot be enabled at the same time"

        if not self.cache_config.enable_chunked_prefill:
            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                assert self.max_num_batched_tokens >= self.max_model_len, (
                    f"max_num_batched_tokens: {self.max_num_batched_tokens} "
                    f"should be larger than or equal to max_model_len: {self.max_model_len}"
                )
        else:
            assert self.max_num_batched_tokens >= self.cache_config.block_size, (
                f"max_num_batched_tokens: {self.max_num_batched_tokens} "
                f"should be larger than or equal to block_size: {self.cache_config.block_size}"
            )

        if self.max_num_partial_prefills > 1:
            assert (
                self.cache_config.enable_chunked_prefill is True
            ), "Chunked prefill must be enabled to set max_num_partial_prefills > 1"
            assert self.long_prefill_token_threshold < self.max_model_len, (
                f"long_prefill_token_threshold: {self.long_prefill_token_threshold} should be less than"
                f" max_model_len: {self.max_model_len}"
            )

        if self.guided_decoding_backend is not None:
            assert self.guided_decoding_backend in [
                "xgrammar",
                "XGrammar",
                "auto",
                "off",
            ], f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."

            if self.guided_decoding_backend != "off":
                # TODO: mm support guided_decoding
                assert (
                    self.model_config.enable_mm is False
                ), "Multimodal model currently do not support guided_decoding"

                # TODO: speculative decoding support guided_decoding

                # TODO: xpu support guided_decoding
                assert not current_platform.is_xpu(), "XPU currently do not support guided_decoding"

                try:
                    import xgrammar  # noqa
                except Exception as e:
                    raise Exception(
                        f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
                    )
        if self.scheduler_config is not None:
            self.scheduler_config.check()

    def print(self):
        """
        print all config
        """
        logger.info("=================== Configuration Information ===============")
        for k, v in self.__dict__.items():
            if k == "generation_config" and v is not None:
                for gck, gcv in v.to_dict().items():
                    logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
            elif (
                k == "cache_config"
                or k == "model_config"
                or k == "scheduler_config"
                or k == "parallel_config"
                or k == "commit_config"
            ):
                if v is not None:
                    v.print()
            else:
                logger.info("{:<20}:{:<6}{}".format(k, "", v))
        logger.info("=============================================================")

def init_cache_info(self):
|
||||
"""
|
||||
initialize cache info
|
||||
"""
|
||||
disaggregate_info = {}
|
||||
if self.splitwise_role != "mixed":
|
||||
disaggregate_info["role"] = self.splitwise_role
|
||||
disaggregate_info["cache_info"] = dict()
|
||||
current_protocol = self.cache_config.cache_transfer_protocol.split(",")
|
||||
disaggregate_info["transfer_protocol"] = current_protocol
|
||||
for protocol in current_protocol:
|
||||
if protocol == "ipc":
|
||||
disaggregate_info["cache_info"][protocol] = {
|
||||
"ip": self.host_ip,
|
||||
"port": self.engine_worker_queue_port,
|
||||
"device_ids": self.local_device_ids,
|
||||
}
|
||||
elif protocol == "rdma":
|
||||
disaggregate_info["cache_info"][protocol] = {
|
||||
"ip": self.host_ip,
|
||||
"port": self.cache_config.pd_comm_port[0],
|
||||
"rdma_port": self.cache_config.rdma_comm_ports,
|
||||
}
|
||||
self.disaggregate_info = disaggregate_info
|
||||
logger.info(f"disaggregate_info: {self.disaggregate_info}")
|
||||
|
||||
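For a disaggregated deployment, init_cache_info publishes one entry per transfer protocol. As a concrete illustration, a node with splitwise_role="prefill" and cache_transfer_protocol="ipc,rdma" would end up with a dictionary shaped like the sketch below; the address, ports, and device ids are made-up values:

disaggregate_info = {
    "role": "prefill",
    "transfer_protocol": ["ipc", "rdma"],
    "cache_info": {
        "ipc": {"ip": "10.0.0.1", "port": 8002, "device_ids": ["0", "1"]},        # engine_worker_queue_port
        "rdma": {"ip": "10.0.0.1", "port": 9000, "rdma_port": ["7100", "7101"]},  # pd_comm_port[0] / rdma_comm_ports
    },
}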
def read_from_config(self):
|
||||
"""
|
||||
reset model config from json file
|
||||
"""
|
||||
|
||||
def reset_value(cls, value_name, key):
|
||||
if hasattr(cls, key):
|
||||
value = getattr(cls, key)
|
||||
setattr(cls, value_name, value)
|
||||
logger.info(f"Reset parameter {value_name} = {value} from configuration.")
|
||||
|
||||
reset_value(self.cache_config, "block_size", "infer_model_block_size")
|
||||
reset_value(
|
||||
self.model_config,
|
||||
"return_full_hidden_states",
|
||||
"return_full_hidden_states",
|
||||
)
|
||||
reset_value(self.cache_config, "cache_dtype", "infer_model_dtype")
|
||||
|
||||
    def _check_master(self):
        return self.is_master

    def _str_to_list(self, attr_name, default_type):
        if hasattr(self, attr_name):
            val = getattr(self, attr_name)
            if type(val) is str:
                setattr(self, attr_name, [default_type(i) for i in val.split(",")])
            else:
                setattr(self, attr_name, val)

    def __str__(self) -> str:
        return json.dumps(self.__dict__, indent=4)
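_str_to_list lets list-valued options arrive either as real lists or as comma-separated strings; innode_prefill_ports is normalized this way in __init__. A standalone version of the same conversion, shown here only as a sketch of the behavior:

def str_to_list(val, default_type=int):
    # comma-separated string -> typed list; anything else passes through unchanged
    if isinstance(val, str):
        return [default_type(i) for i in val.split(",")]
    return val

print(str_to_list("8001,8002"))   # [8001, 8002]
print(str_to_list([8001, 8002]))  # [8001, 8002]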