Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 08:37:06 +08:00
Unify server-side and model-side Config(Part-5) (#3497)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Squashed commits:
* move config
* fix xpu
* fix
* fix vl
* fix vl
* fix unitest
* fix args
* add unitest
* fix test
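The squashed commits above boil down to one structural change: the engine-side `Config` class (the old `fastdeploy/engine/config.py`) is deleted, `EngineArgs.create_engine_config()` now returns the unified `FDConfig`, and flat engine attributes move onto the sub-configs (for example `cfg.tensor_parallel_size` becomes `cfg.parallel_config.tensor_parallel_size`, and `cfg.enable_mm` becomes `cfg.model_config.enable_mm`). A minimal sketch of the resulting access pattern follows; the attribute names are taken from the diff, while the helper function itself is only illustrative and not part of the PR.

# Illustrative only: how call sites read engine knobs after this PR.
# `cfg` is the FDConfig returned by EngineArgs.create_engine_config().
def summarize_cfg(cfg) -> str:
    # Old (removed) flat attributes: cfg.tensor_parallel_size, cfg.enable_mm, cfg.enable_logprob, cfg.load_choices
    # New unified-config attributes, as used throughout this diff:
    tp = cfg.parallel_config.tensor_parallel_size
    mm = cfg.model_config.enable_mm
    logprob = cfg.model_config.enable_logprob
    load_choices = cfg.load_config.load_choices
    return f"tp={tp}, enable_mm={mm}, enable_logprob={logprob}, load_choices={load_choices}"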
@@ -18,17 +18,20 @@ from __future__ import annotations
 
 import json
 import os
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from enum import Enum
-from typing import Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
+import paddle
 from paddleformers.transformers.configuration_utils import PretrainedConfig
 
 import fastdeploy
 from fastdeploy import envs
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
+from fastdeploy.multimodal.registry import MultimodalRegistry
 from fastdeploy.platforms import current_platform
-from fastdeploy.utils import check_unified_ckpt, get_logger
+from fastdeploy.scheduler import SchedulerConfig
+from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger
 
 logger = get_logger("config", "config.log")
 
@@ -120,7 +123,6 @@ class ModelConfig:
         self.max_model_len = 0
         self.dtype = ""
         self.enable_logprob = False
-        self.enable_mm = False
         self.enable_redundant_experts = False
         self.redundant_experts_num = 0
         self.seed = 0
@@ -154,6 +156,12 @@ class ModelConfig:
         if ErnieArchitectures.contains_ernie_arch(self.architectures):
             self.ori_vocab_size = args.get("ori_vocab_size", self.ori_vocab_size)
 
+        architectures = self.architectures[0]
+        if MultimodalRegistry.contains_model(architectures):
+            self.enable_mm = True
+        else:
+            self.enable_mm = False
+
         self.is_unified_ckpt = check_unified_ckpt(self.model)
 
         self.override_name_from_config()
@@ -934,19 +942,53 @@ class FDConfig:
     simplifies passing around the distinct configurations in the codebase.
     """
 
-    model_config: ModelConfig = field(default=None, init=True)  # type: ignore
+    def __init__(
+        self,
+        model_config: ModelConfig = None,
+        cache_config: CacheConfig = None,
+        parallel_config: ParallelConfig = None,
+        load_config: LoadConfig = None,
+        commit_config: CommitConfig = CommitConfig(),
+        scheduler_config: SchedulerConfig = None,
+        device_config: DeviceConfig = None,
+        decoding_config: DecodingConfig = None,
+        quant_config: QuantConfigBase = None,
+        graph_opt_config: GraphOptimizationConfig = None,
+        speculative_config: SpeculativeConfig = None,
+        tokenizer: str = None,
+        max_model_len: int = 8192,
+        max_num_seqs: int = 8,
+        max_num_batched_tokens: Optional[int] = None,
+        ips: str = None,
+        use_warmup: bool = False,
+        engine_worker_queue_port: int = 8002,
+        limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+        splitwise_role: str = "mixed",
+        innode_prefill_ports: Optional[List[int]] = None,
+        max_num_partial_prefills: int = 1,
+        max_long_partial_prefills: int = 1,
+        long_prefill_token_threshold: int = 0,
+        reasoning_parser: str = None,
+        guided_decoding_backend: Optional[str] = None,
+        disable_any_whitespace: bool = False,
+        early_stop_config: Optional[Dict[str, Any]] = None,
+        tool_parser: str = None,
+        test_mode=False,
+    ):
+        self.model_config: ModelConfig = model_config  # type: ignore
+        self.cache_config: CacheConfig = cache_config  # type: ignore
+        self.scheduler_config: SchedulerConfig = scheduler_config  # type: ignore
+        self.parallel_config = parallel_config  # type: ignore
+        self.speculative_config: SpeculativeConfig = speculative_config
+        self.device_config: DeviceConfig = device_config  # type: ignore
+        self.load_config: LoadConfig = load_config
+        self.quant_config: Optional[QuantConfigBase] = quant_config
+        self.graph_opt_config: Optional[GraphOptimizationConfig] = graph_opt_config
+        self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
+        self.decoding_config: DecodingConfig = decoding_config  # type: ignore
+        self.cache_config: CacheConfig = cache_config  # type: ignore
 
-    parallel_config: ParallelConfig = field(default=None, init=True)
-    speculative_config: SpeculativeConfig = field(default=None, init=True)  # type: ignore
-    device_config: DeviceConfig = field(default=None, init=True)  # type: ignore
-    load_config: LoadConfig = field(default=None, init=True)
-    quant_config: Optional[QuantConfigBase] = None
-    graph_opt_config: Optional[GraphOptimizationConfig] = None
-    early_stop_config: Optional[EarlyStopConfig] = None
-    decoding_config: DecodingConfig = field(default=None, init=True)  # type: ignore
-    cache_config: CacheConfig = field(default=None, init=True)  # type: ignore
-
-    def __post_init__(self):
         # Initialize cuda graph capture list
         if self.graph_opt_config.cudagraph_capture_sizes is None:
             self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
@@ -955,3 +997,278 @@ class FDConfig:
         # TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
         if self.graph_opt_config.graph_opt_level == 2:
             self.graph_opt_config.graph_opt_level = 1
+
+        self.tokenizer = tokenizer
+        self.max_num_batched_tokens = max_num_batched_tokens
+        self.ips = ips
+        self.tool_parser = tool_parser
+
+        if self.ips is None:
+            self.master_ip = "0.0.0.0"
+        elif isinstance(self.ips, list):
+            self.master_ip = self.ips[0]
+        else:
+            self.ips = self.ips.split(",")
+            self.master_ip = self.ips[0]
+
+        if self.ips is None:
+            self.nnode = 1
+            self.node_rank = 0
+        else:
+            self.nnode = len(self.ips)
+
+            for idx, ip in enumerate(self.ips):
+                if ip == self.master_ip:
+                    self.node_rank = idx
+
+        self.max_model_len = max_model_len
+        self.max_num_seqs = max_num_seqs
+        self.limit_mm_per_prompt = limit_mm_per_prompt
+        self.mm_processor_kwargs = mm_processor_kwargs
+        self.use_warmup = use_warmup
+        self.splitwise_role = splitwise_role
+        self.innode_prefill_ports = innode_prefill_ports
+        self.max_num_partial_prefills = max_num_partial_prefills
+        self.max_long_partial_prefills = max_long_partial_prefills
+        self.long_prefill_token_threshold = long_prefill_token_threshold
+        self.reasoning_parser = reasoning_parser
+        self.guided_decoding_backend = guided_decoding_backend
+        self.disable_any_whitespace = disable_any_whitespace
+        self._str_to_list("innode_prefill_ports", int)
+
+        # TODO
+        self.max_prefill_batch = 3
+        if current_platform.is_xpu():
+            self.max_prefill_batch = 1
+        if self.model_config is not None and self.model_config.enable_mm:
+            self.max_prefill_batch = 1  # TODO:当前多模prefill阶段只支持并行度为1,待优化
+
+        num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.expert_parallel_size
+        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
+        if num_ranks > self.max_chips_per_node:
+            self.worker_num_per_node = self.max_chips_per_node
+            nnode = ceil_div(num_ranks, self.worker_num_per_node)
+            assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
+        else:
+            self.worker_num_per_node = num_ranks
+
+        self.engine_worker_queue_port = engine_worker_queue_port
+        self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
+        self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
+        if current_platform.is_xpu():
+            self.device_ids = os.getenv("XPU_VISIBLE_DEVICES", self.device_ids)
+
+        self.read_from_config()
+        self.postprocess()
+        if test_mode:
+            return
+        self.check()
+        self.print()
+
+    def postprocess(self):
+        """
+        calculate some parameters
+        """
+        assert (
+            self.device_ids.split(",").__len__() == self.worker_num_per_node
+        ), f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}"
+
+        self.local_device_ids = self.device_ids.split(",")[: self.parallel_config.tensor_parallel_size]
+
+        self.host_ip = get_host_ip()
+
+        if self.ips is None or self.host_ip == self.master_ip:
+            self.is_master = True
+        else:
+            self.is_master = False
+
+        if self.parallel_config.tensor_parallel_size <= self.worker_num_per_node:
+            self.is_master = True
+
+        self.paddle_commit_id = paddle.version.commit
+
+        if self.max_num_batched_tokens is None:
+            if self.cache_config.enable_chunked_prefill:
+                self.max_num_batched_tokens = 2048
+            else:
+                if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
+                    self.max_num_batched_tokens = self.max_model_len
+                else:
+                    self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
+
+        if self.long_prefill_token_threshold == 0:
+            self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
+
+        self.cache_config.postprocess(self.max_num_batched_tokens, self.max_num_seqs)
+        self.cache_config.max_block_num_per_seq = int(self.max_model_len // self.cache_config.block_size)
+
+        if self.guided_decoding_backend == "auto":
+            if self.model_config.enable_mm:
+                self.guided_decoding_backend = "off"
+            else:
+                self.guided_decoding_backend = "xgrammar"
+
+    def check(self):
+        """
+        check the legality of config
+        """
+        assert self.max_num_seqs <= 256, (
+            "The parameter `max_num_seqs` is not allowed to exceed 256, " f"but now it's {self.max_num_seqs}."
+        )
+        assert self.nnode >= 1, f"nnode: {self.nnode} should no less than 1"
+        assert self.max_model_len >= 16, f"max_model_len: {self.max_model_len} should be larger than 16"
+        assert self.max_num_seqs >= 1, f"max_num_seqs: {self.max_num_seqs} should be larger than 1"
+        assert self.max_num_batched_tokens >= self.max_num_seqs, (
+            f"max_num_batched_tokens: {self.max_num_batched_tokens} "
+            f"should be larger than or equal to max_num_seqs: {self.max_num_seqs}"
+        )
+        assert self.max_num_batched_tokens <= self.max_model_len * self.max_num_seqs, (
+            f"max_num_batched_tokens: {self.max_num_batched_tokens} should be larger"
+            f"than or equal to max_num_seqs: {self.max_num_seqs} * max_model_len: {self.max_model_len}"
+        )
+        assert (
+            self.max_num_partial_prefills >= 1
+        ), f"max_num_partial_prefills: {self.max_num_partial_prefills} should be larger than or equal to 1"
+
+        assert (
+            self.max_long_partial_prefills >= 1
+        ), f"max_long_partial_prefills: {self.max_long_partial_prefills} should be larger than or equal to 1"
+        assert self.max_long_partial_prefills <= self.max_num_partial_prefills, (
+            f"max_long_partial_prefills: {self.max_long_partial_prefills} should "
+            f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}"
+        )
+        assert self.splitwise_role in ["mixed", "prefill", "decode"]
+        # TODO(@wufeisheng): TP and EP need to be supported simultaneously.
+        assert (self.parallel_config.tensor_parallel_size == 1 and self.parallel_config.expert_parallel_size >= 1) or (
+            self.parallel_config.tensor_parallel_size >= 1 and self.parallel_config.expert_parallel_size == 1
+        ), "TP and EP cannot be enabled at the same time"
+
+        if not self.cache_config.enable_chunked_prefill:
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
+                assert self.max_num_batched_tokens >= self.max_model_len, (
+                    f"max_num_batched_tokens: {self.max_num_batched_tokens} "
+                    f"should be larger than or equal to max_model_len: {self.max_model_len}"
+                )
+        else:
+            assert self.max_num_batched_tokens >= self.cache_config.block_size, (
+                f"max_num_batched_tokens: {self.max_num_batched_tokens} "
+                f"should be larger than or equal to block_size: {self.cache_config.block_size}"
+            )
+
+        if self.max_num_partial_prefills > 1:
+            assert (
+                self.cache_config.enable_chunked_prefill is True
+            ), "Chunked prefill must be enabled to set max_num_partial_prefills > 1"
+            assert self.long_prefill_token_threshold < self.max_model_len, (
+                f"long_prefill_token_threshold: {self.long_prefill_token_threshold} should be less than"
+                f" max_model_len: {self.max_model_len}"
+            )
+
+        if self.guided_decoding_backend is not None:
+            assert self.guided_decoding_backend in [
+                "xgrammar",
+                "XGrammar",
+                "auto",
+                "off",
+            ], f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."
+
+            if self.guided_decoding_backend != "off":
+                # TODO: mm support guided_decoding
+                assert (
+                    self.model_config.enable_mm is False
+                ), "Multimodal model currently do not support guided_decoding"
+
+                # TODO: speculative decoding support guided_decoding
+
+                # TODO: xpu support guided_decoding
+                assert not current_platform.is_xpu(), "XPU currently do not support guided_decoding"
+
+                try:
+                    import xgrammar  # noqa
+                except Exception as e:
+                    raise Exception(
+                        f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
+                    )
+        if self.scheduler_config is not None:
+            self.scheduler_config.check()
+
+    def print(self):
+        """
+        print all config
+        """
+        logger.info("=================== Configuration Information ===============")
+        for k, v in self.__dict__.items():
+            if k == "generation_config" and v is not None:
+                for gck, gcv in v.to_dict().items():
+                    logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
+            elif (
+                k == "cache_config"
+                or k == "model_config"
+                or k == "scheduler_config"
+                or k == "parallel_config"
+                or k == "commit_config"
+            ):
+                if v is not None:
+                    v.print()
+            else:
+                logger.info("{:<20}:{:<6}{}".format(k, "", v))
+        logger.info("=============================================================")
+
+    def init_cache_info(self):
+        """
+        initialize cache info
+        """
+        disaggregate_info = {}
+        if self.splitwise_role != "mixed":
+            disaggregate_info["role"] = self.splitwise_role
+            disaggregate_info["cache_info"] = dict()
+            current_protocol = self.cache_config.cache_transfer_protocol.split(",")
+            disaggregate_info["transfer_protocol"] = current_protocol
+            for protocol in current_protocol:
+                if protocol == "ipc":
+                    disaggregate_info["cache_info"][protocol] = {
+                        "ip": self.host_ip,
+                        "port": self.engine_worker_queue_port,
+                        "device_ids": self.local_device_ids,
+                    }
+                elif protocol == "rdma":
+                    disaggregate_info["cache_info"][protocol] = {
+                        "ip": self.host_ip,
+                        "port": self.cache_config.pd_comm_port[0],
+                        "rdma_port": self.cache_config.rdma_comm_ports,
+                    }
+        self.disaggregate_info = disaggregate_info
+        logger.info(f"disaggregate_info: {self.disaggregate_info}")
+
+    def read_from_config(self):
+        """
+        reset model config from json file
+        """
+
+        def reset_value(cls, value_name, key):
+            if hasattr(cls, key):
+                value = getattr(cls, key)
+                setattr(cls, value_name, value)
+                logger.info(f"Reset parameter {value_name} = {value} from configuration.")
+
+        reset_value(self.cache_config, "block_size", "infer_model_block_size")
+        reset_value(
+            self.model_config,
+            "return_full_hidden_states",
+            "return_full_hidden_states",
+        )
+        reset_value(self.cache_config, "cache_dtype", "infer_model_dtype")
+
+    def _check_master(self):
+        return self.is_master
+
+    def _str_to_list(self, attr_name, default_type):
+        if hasattr(self, attr_name):
+            val = getattr(self, attr_name)
+            if type(val) is str:
+                setattr(self, attr_name, [default_type(i) for i in val.split(",")])
+            else:
+                setattr(self, attr_name, val)
+
+    def __str__(self) -> str:
+        return json.dumps(self.__dict__, indent=4)
@@ -23,6 +23,7 @@ from typing import Any, Dict, List, Optional
 from fastdeploy.config import (
     CacheConfig,
     EarlyStopConfig,
+    FDConfig,
     GraphOptimizationConfig,
     LoadConfig,
     ModelConfig,
@@ -30,10 +31,13 @@ from fastdeploy.config import (
     SpeculativeConfig,
     TaskOption,
 )
-from fastdeploy.engine.config import Config
 from fastdeploy.platforms import current_platform
 from fastdeploy.scheduler.config import SchedulerConfig
-from fastdeploy.utils import DeprecatedOptionWarning, FlexibleArgumentParser
+from fastdeploy.utils import (
+    DeprecatedOptionWarning,
+    FlexibleArgumentParser,
+    is_port_available,
+)
 
 
 def nullable_str(x: str) -> Optional[str]:
@@ -912,7 +916,7 @@ class EngineArgs:
                 early_stop_args[k] = v
         return EarlyStopConfig(early_stop_args)
 
-    def create_engine_config(self) -> Config:
+    def create_engine_config(self) -> FDConfig:
         """
         Create and return a Config object based on the current settings.
         """
@@ -947,8 +951,11 @@ class EngineArgs:
             self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
         ), "enable_custom_all_reduce must be used with tensor_parallel_size>1"
 
-        return Config(
-            model_name_or_path=self.model,
+        assert is_port_available(
+            "0.0.0.0", self.engine_worker_queue_port
+        ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
+
+        return FDConfig(
             model_config=model_cfg,
             scheduler_config=scheduler_cfg,
             tokenizer=self.tokenizer,
@@ -956,7 +963,6 @@ class EngineArgs:
             load_config=load_cfg,
             parallel_config=parallel_cfg,
             max_model_len=self.max_model_len,
-            tensor_parallel_size=self.tensor_parallel_size,
            max_num_seqs=self.max_num_seqs,
             speculative_config=speculative_cfg,
             max_num_batched_tokens=self.max_num_batched_tokens,
@@ -965,7 +971,6 @@ class EngineArgs:
             engine_worker_queue_port=self.engine_worker_queue_port,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             mm_processor_kwargs=self.mm_processor_kwargs,
-            # enable_mm=self.enable_mm,
             reasoning_parser=self.reasoning_parser,
             tool_parser=self.tool_call_parser,
             splitwise_role=self.splitwise_role,
@@ -973,10 +978,8 @@ class EngineArgs:
             max_num_partial_prefills=self.max_num_partial_prefills,
             max_long_partial_prefills=self.max_long_partial_prefills,
             long_prefill_token_threshold=self.long_prefill_token_threshold,
-            graph_optimization_config=graph_opt_cfg,
+            graph_opt_config=graph_opt_cfg,
             guided_decoding_backend=self.guided_decoding_backend,
             disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
-            enable_logprob=self.enable_logprob,
             early_stop_config=early_stop_cfg,
-            load_choices=self.load_choices,
         )
@@ -1,435 +0,0 @@
-"""
-# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-
-import json
-import os
-from datetime import datetime
-from typing import Any, Dict, List, Optional
-
-from fastdeploy.config import (
-    CacheConfig,
-    CommitConfig,
-    LoadConfig,
-    ModelConfig,
-    ParallelConfig,
-)
-from fastdeploy.multimodal.registry import MultimodalRegistry
-from fastdeploy.platforms import current_platform
-from fastdeploy.scheduler import SchedulerConfig
-from fastdeploy.utils import ceil_div, get_host_ip, is_port_available, llm_logger
-
-
-class Config:
-    """
-    Initial configuration class.
-
-    Attributes:
-        model_config (ModelConfig): Model configuration object.
-        cache_config (CacheConfig): Cache configuration object.
-        model_name_or_path (str): Directory path to the model or the model name.
-        tokenizer (Optional[str]): Default is the model.
-        max_num_batched_tokens (Optional[int]): Maximum number of batched tokens.
-        tensor_parallel_size (int): Tensor parallel size.
-        nnode (int): Number of nodes.
-        max_model_len (int): Maximum model length. Default is 8192.
-        max_num_seqs (int): Maximum number of sequences. Default is 8.
-        mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor.
-        speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration.
-        use_warmup (bool): Flag to use warmup.
-        engine_worker_queue_port (int): Port for engine worker queue.
-        enable_mm (bool): Flag to enable multi-modal processing.
-        reasoning_parser(str): Flag specifies the reasoning parser to use for
-            extracting reasoning content from the model output
-        splitwise_role (str): Splitwise role.
-        innode_prefill_ports (Optional[List[int]]): Innode prefill ports.
-            Temporary configuration, will be removed in the future.
-        load_choices(str):The format of the model weights to load. .Default is default
-    """
-
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        scheduler_config: SchedulerConfig,
-        parallel_config: ParallelConfig,
-        load_config: LoadConfig,
-        commit_config: CommitConfig = CommitConfig(),
-        model_name_or_path: str = None,
-        tokenizer: str = None,
-        tensor_parallel_size: int = 8,
-        max_model_len: int = 8192,
-        max_num_seqs: int = 8,
-        max_num_batched_tokens: Optional[int] = None,
-        ips: str = None,
-        speculative_config: Optional[Dict[str, Any]] = None,
-        graph_optimization_config: Optional[Dict[str, Any]] = None,
-        use_warmup: bool = False,
-        engine_worker_queue_port: int = 8002,
-        limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
-        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
-        # enable_mm: bool = False,
-        splitwise_role: str = "mixed",
-        innode_prefill_ports: Optional[List[int]] = None,
-        max_num_partial_prefills: int = 1,
-        max_long_partial_prefills: int = 1,
-        long_prefill_token_threshold: int = 0,
-        reasoning_parser: str = None,
-        tool_parser: str = None,
-        guided_decoding_backend: Optional[str] = None,
-        disable_any_whitespace: bool = False,
-        enable_logprob: bool = False,
-        early_stop_config: Optional[Dict[str, Any]] = None,
-        load_choices: str = "default",
-    ):
-        """
-        Initialize the Config class.
-
-        Args:
-            model_config (ModelConfig): Model configuration object.
-            cache_config (CacheConfig): Cache configuration object.
-            parallel_config (ParallelConfig): Parallel configuration object.
-            scheduler_config (SchedulerConfig): Scheduler configuration object.
-            model_name_or_path (str): Model directory path or model name.
-            tokenizer (str): Default is the model.
-            tensor_parallel_size (int): Tensor parallel size. Default is 8.
-            max_model_len (int): Maximum model length. Default is 8192.
-            max_num_seqs (int): Maximum number of sequences. Default is 8.
-            max_num_batched_tokens (Optional[int]): Maximum number of batched tokens. Default is None.
-            mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor. Default is None.
-            speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration. Default is None.
-            graph_optimization_config (Optional[Dict[str, Any]]): Graph optimizaion backend execution configuration. Default is None.
-            use_warmup (bool): Flag to use warmup. Default is False.
-            engine_worker_queue_port (int): Engine worker queue port. Default is 8002.
-            enable_mm (bool): Flag to enable multi-modal processing. Default is False.
-            splitwise_role (str): Splitwise role. Default is "mixed".
-            innode_prefill_ports (Optional[List[int]]): Innode prefill ports. Default is None.
-            reasoning_parser (str): Flag specifies the reasoning parser to use for
-                extracting reasoning content from the model output. Default is None.
-            guided_decoding_backend(str): Guided decoding backend. Default is None.
-            disable_any_whitespace(bool): Disable any whitespace when using guided decoding.
-                Default is False.
-            enable_logprob(bool): Enable logprob. Default is False.
-            early_stop_config (Optional[Dict[str, Any]]): Early stop configuration. Default is None.
-            load_choices(str):The format of the model weights to load. .Default is default
-        """
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.scheduler_config = scheduler_config
-        self.parallel_config = parallel_config
-        self.load_config = load_config
-        self.commit_config = commit_config
-        self.model_name_or_path = model_name_or_path
-        self.tokenizer = tokenizer
-        self.max_num_batched_tokens = max_num_batched_tokens
-        self.tensor_parallel_size = tensor_parallel_size
-        self.ips = ips
-
-        if self.ips is None:
-            self.master_ip = "0.0.0.0"
-        elif isinstance(self.ips, list):
-            self.master_ip = self.ips[0]
-        else:
-            self.ips = self.ips.split(",")
-            self.master_ip = self.ips[0]
-
-        if self.ips is None:
-            self.nnode = 1
-            self.node_rank = 0
-        else:
-            self.nnode = len(self.ips)
-
-            for idx, ip in enumerate(self.ips):
-                if ip == self.master_ip:
-                    self.node_rank = idx
-
-        self.max_model_len = max_model_len
-        self.max_num_seqs = max_num_seqs
-        self.limit_mm_per_prompt = limit_mm_per_prompt
-        self.mm_processor_kwargs = mm_processor_kwargs
-        # self.enable_mm = enable_mm
-        self.speculative_config = speculative_config
-        self.use_warmup = use_warmup
-        self.splitwise_role = splitwise_role
-        self.innode_prefill_ports = innode_prefill_ports
-        self.max_num_partial_prefills = max_num_partial_prefills
-        self.max_long_partial_prefills = max_long_partial_prefills
-        self.long_prefill_token_threshold = long_prefill_token_threshold
-        self.reasoning_parser = reasoning_parser
-        self.tool_parser = tool_parser
-        self.graph_optimization_config = graph_optimization_config
-        self.early_stop_config = early_stop_config
-        self.guided_decoding_backend = guided_decoding_backend
-        self.disable_any_whitespace = disable_any_whitespace
-        self._str_to_list("innode_prefill_ports", int)
-        self.load_choices = load_choices
-
-        assert self.splitwise_role in ["mixed", "prefill", "decode"]
-
-        import fastdeploy.model_executor.models  # noqa: F401
-
-        architectures = self.model_config.architectures[0]
-        if MultimodalRegistry.contains_model(architectures):
-            self.enable_mm = True
-        else:
-            self.enable_mm = False
-
-        # TODO
-        self.max_prefill_batch = 3
-        if current_platform.is_xpu():
-            self.max_prefill_batch = 1
-        if self.enable_mm:
-            self.max_prefill_batch = 1  # TODO:当前多模prefill阶段只支持并行度为1,待优化
-
-        # TODO(@wufeisheng): TP and EP need to be supported simultaneously.
-        assert (self.tensor_parallel_size == 1 and self.parallel_config.expert_parallel_size >= 1) or (
-            self.tensor_parallel_size >= 1 and self.parallel_config.expert_parallel_size == 1
-        ), "TP and EP cannot be enabled at the same time"
-
-        num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
-        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
-        if num_ranks > self.max_chips_per_node:
-            self.worker_num_per_node = self.max_chips_per_node
-            nnode = ceil_div(num_ranks, self.worker_num_per_node)
-            assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
-        else:
-            self.worker_num_per_node = num_ranks
-
-        self.engine_worker_queue_port = engine_worker_queue_port
-        self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
-        self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
-        if current_platform.is_xpu():
-            self.device_ids = os.getenv("XPU_VISIBLE_DEVICES", self.device_ids)
-
-        self.enable_logprob = enable_logprob
-
-        self.read_from_config()
-        self.postprocess()
-        self.check()
-        self.print()
-
-    def postprocess(self):
-        """
-        calculate some parameters
-        """
-        assert (
-            self.device_ids.split(",").__len__() == self.worker_num_per_node
-        ), f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}"
-
-        self.local_device_ids = self.device_ids.split(",")[: self.tensor_parallel_size]
-
-        self.host_ip = get_host_ip()
-
-        if self.ips is None or self.host_ip == self.master_ip:
-            self.is_master = True
-        else:
-            self.is_master = False
-
-        if self.tensor_parallel_size <= self.worker_num_per_node:
-            self.is_master = True
-
-        import paddle
-
-        self.paddle_commit_id = paddle.version.commit
-
-        if self.max_num_batched_tokens is None:
-            if self.cache_config.enable_chunked_prefill:
-                self.max_num_batched_tokens = 2048
-            else:
-                if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
-                    self.max_num_batched_tokens = self.max_model_len
-                else:
-                    self.max_num_batched_tokens = 8192  # if set to max_model_len, it's easy to be OOM
-
-        if self.long_prefill_token_threshold == 0:
-            self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
-
-        self.cache_config.postprocess(self.max_num_batched_tokens, self.max_num_seqs)
-        self.cache_config.max_block_num_per_seq = int(self.max_model_len // self.cache_config.block_size)
-
-        if self.guided_decoding_backend == "auto":
-            if self.enable_mm:
-                self.guided_decoding_backend = "off"
-            else:
-                self.guided_decoding_backend = "xgrammar"
-
-    def check(self):
-        """
-        check the legality of config
-        """
-        assert self.max_num_seqs <= 256, (
-            "The parameter `max_num_seqs` is not allowed to exceed 256, " f"but now it's {self.max_num_seqs}."
-        )
-        assert is_port_available(
-            "0.0.0.0", self.engine_worker_queue_port
-        ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
-        assert self.nnode >= 1, f"nnode: {self.nnode} should no less than 1"
-        assert self.max_model_len >= 16, f"max_model_len: {self.max_model_len} should be larger than 16"
-        assert self.max_num_seqs >= 1, f"max_num_seqs: {self.max_num_seqs} should be larger than 1"
-        assert self.max_num_batched_tokens >= self.max_num_seqs, (
-            f"max_num_batched_tokens: {self.max_num_batched_tokens} "
-            f"should be larger than or equal to max_num_seqs: {self.max_num_seqs}"
-        )
-        assert self.max_num_batched_tokens <= self.max_model_len * self.max_num_seqs, (
-            f"max_num_batched_tokens: {self.max_num_batched_tokens} should be larger"
-            f"than or equal to max_num_seqs: {self.max_num_seqs} * max_model_len: {self.max_model_len}"
-        )
-        assert (
-            self.max_num_partial_prefills >= 1
-        ), f"max_num_partial_prefills: {self.max_num_partial_prefills} should be larger than or equal to 1"
-
-        assert (
-            self.max_long_partial_prefills >= 1
-        ), f"max_long_partial_prefills: {self.max_long_partial_prefills} should be larger than or equal to 1"
-        assert self.max_long_partial_prefills <= self.max_num_partial_prefills, (
-            f"max_long_partial_prefills: {self.max_long_partial_prefills} should "
-            f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}"
-        )
-
-        if not self.cache_config.enable_chunked_prefill:
-            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
-                assert self.max_num_batched_tokens >= self.max_model_len, (
-                    f"max_num_batched_tokens: {self.max_num_batched_tokens} "
-                    f"should be larger than or equal to max_model_len: {self.max_model_len}"
-                )
-        else:
-            assert self.max_num_batched_tokens >= self.cache_config.block_size, (
-                f"max_num_batched_tokens: {self.max_num_batched_tokens} "
-                f"should be larger than or equal to block_size: {self.cache_config.block_size}"
-            )
-
-        if self.max_num_partial_prefills > 1:
-            assert (
-                self.cache_config.enable_chunked_prefill is True
-            ), "Chunked prefill must be enabled to set max_num_partial_prefills > 1"
-            assert self.long_prefill_token_threshold < self.max_model_len, (
-                f"long_prefill_token_threshold: {self.long_prefill_token_threshold} should be less than"
-                f" max_model_len: {self.max_model_len}"
-            )
-
-        if self.guided_decoding_backend is not None:
-            assert self.guided_decoding_backend in [
-                "xgrammar",
-                "XGrammar",
-                "auto",
-                "off",
-            ], f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."
-
-            if self.guided_decoding_backend != "off":
-                # TODO: mm support guided_decoding
-                assert self.enable_mm is False, "Multimodal model currently do not support guided_decoding"
-
-                # TODO: speculative decoding support guided_decoding
-
-                # TODO: xpu support guided_decoding
-                assert not current_platform.is_xpu(), "XPU currently do not support guided_decoding"
-
-                try:
-                    import xgrammar  # noqa
-                except Exception as e:
-                    raise Exception(
-                        f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
-                    )
-
-        self.scheduler_config.check()
-
-    def print(self, file=None):
-        """
-        print all config
-
-        Args:
-            file (str): the path of file to save config
-        """
-        llm_logger.info("=================== Configuration Information ===============")
-        for k, v in self.__dict__.items():
-            if k == "generation_config" and v is not None:
-                for gck, gcv in v.to_dict().items():
-                    llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
-            elif (
-                k == "cache_config"
-                or k == "model_config"
-                or k == "scheduler_config"
-                or k == "parallel_config"
-                or k == "commit_config"
-            ):
-                v.print()
-            else:
-                llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
-        llm_logger.info("=============================================================")
-        if file is not None:
-            f = open(file, "a")
-            now_time = datetime.now()
-            f.write(f"{now_time} configuration information as below,\n")
-            for k, v in self.__dict__.items():
-                f.write("{:<20}:{:<6}{}\n".format(k, "", v))
-            f.close()
-
-    def init_cache_info(self):
-        """
-        initialize cache info
-        """
-        disaggregate_info = {}
-        if self.splitwise_role != "mixed":
-            disaggregate_info["role"] = self.splitwise_role
-            disaggregate_info["cache_info"] = dict()
-            current_protocol = self.cache_config.cache_transfer_protocol.split(",")
-            disaggregate_info["transfer_protocol"] = current_protocol
-            for protocol in current_protocol:
-                if protocol == "ipc":
-                    disaggregate_info["cache_info"][protocol] = {
-                        "ip": self.host_ip,
-                        "port": self.engine_worker_queue_port,
-                        "device_ids": self.local_device_ids,
-                    }
-                elif protocol == "rdma":
-                    disaggregate_info["cache_info"][protocol] = {
-                        "ip": self.host_ip,
-                        "port": self.cache_config.pd_comm_port[0],
-                        "rdma_port": self.cache_config.rdma_comm_ports,
-                    }
-        self.disaggregate_info = disaggregate_info
-        llm_logger.info(f"disaggregate_info: {self.disaggregate_info}")
-
-    def read_from_config(self):
-        """
-        reset model config from json file
-        """
-
-        def reset_value(cls, value_name, key):
-            if hasattr(cls, key):
-                value = getattr(cls, key)
-                setattr(cls, value_name, value)
-                llm_logger.info(f"Reset parameter {value_name} = {value} from configuration.")
-
-        reset_value(self.cache_config, "block_size", "infer_model_block_size")
-        reset_value(
-            self.model_config,
-            "return_full_hidden_states",
-            "return_full_hidden_states",
-        )
-        reset_value(self.cache_config, "cache_dtype", "infer_model_dtype")
-
-    def _check_master(self):
-        return self.is_master
-
-    def _str_to_list(self, attr_name, default_type):
-        if hasattr(self, attr_name):
-            val = getattr(self, attr_name)
-            if type(val) is str:
-                setattr(self, attr_name, [default_type(i) for i in val.split(",")])
-            else:
-                setattr(self, attr_name, val)
-
-    def __str__(self) -> str:
-        return json.dumps(self.__dict__, indent=4)
@@ -105,7 +105,7 @@ class LLMEngine:
             cfg.reasoning_parser,
             cfg.limit_mm_per_prompt,
             cfg.mm_processor_kwargs,
-            cfg.enable_mm,
+            cfg.model_config.enable_mm,
             cfg.tool_parser,
         )
 
@@ -113,7 +113,7 @@ class LLMEngine:
 
         if envs.ENABLE_V1_KVCACHE_SCHEDULER:
             self.resource_manager = ResourceManagerV1(
-                cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
+                cfg.max_num_seqs, cfg, cfg.parallel_config.tensor_parallel_size, cfg.splitwise_role
             )
             if cfg.splitwise_role != "mixed":
                 raise NotImplementedError(
@@ -121,7 +121,7 @@ class LLMEngine:
                 )
         else:
             self.resource_manager = ResourceManager(
-                cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
+                cfg.max_num_seqs, cfg, cfg.parallel_config.tensor_parallel_size, cfg.splitwise_role
             )
 
         os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port)
@@ -191,7 +191,7 @@ class LLMEngine:
             device_ids = self.cfg.device_ids.split(",")
             self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
                 cache_config=self.cfg.cache_config,
-                tensor_parallel_size=self.cfg.tensor_parallel_size,
+                tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size,
                 device_ids=device_ids,
                 pod_ip=self.cfg.master_ip,
                 engine_worker_queue_port=self.cfg.engine_worker_queue_port,
@@ -387,7 +387,7 @@ class LLMEngine:
         while self.running:
             try:
                 block = True if len(added_requests) == 0 else False
-                if not self.cfg.enable_mm:
+                if not self.cfg.model_config.enable_mm:
                     err, data = self.zmq_server.receive_json_once(block)
                 else:
                     err, data = self.zmq_server.receive_pyobj_once(block)
@@ -807,7 +807,7 @@ class LLMEngine:
         for task in tasks:
             task.inference_start_time = time.time()
         if not is_prefill:
-            if not self.cfg.enable_mm:
+            if not self.cfg.model_config.enable_mm:
                 self.update_requests_chunk_size(tasks)
             else:
                 self.update_mm_requests_chunk_size(tasks)
@@ -1049,7 +1049,7 @@ class LLMEngine:
         if self.cfg.splitwise_role == "prefill":
             variables["FLAGS_fmt_write_cache_completed_signal"] = 1
 
-        if self.cfg.enable_mm:
+        if self.cfg.model_config.enable_mm:
             variables["FLAGS_max_partition_size"] = 1024
 
         command_prefix = ""
@@ -1084,9 +1084,9 @@ class LLMEngine:
             f" --devices {self.cfg.device_ids} {py_script}"
             f" --max_num_seqs {self.cfg.max_num_seqs} --max_model_len {self.cfg.max_model_len}"
             f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}"
-            f" --model {self.cfg.model_name_or_path!s}"
+            f" --model {self.cfg.model_config.model!s}"
             f" --device_ids {self.cfg.device_ids}"
-            f" --tensor_parallel_size {self.cfg.tensor_parallel_size}"
+            f" --tensor_parallel_size {self.cfg.parallel_config.tensor_parallel_size}"
            f" --engine_worker_queue_port {self.cfg.engine_worker_queue_port!s}"
             f" --pod_ip {self.cfg.master_ip}"
             f" --total_block_num {self.cfg.cache_config.total_block_num}"
@@ -1103,11 +1103,11 @@ class LLMEngine:
             f" --quantization {self.cfg.model_config.quantization}"
             f" --ori_vocab_size {ori_vocab_size}"
             f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'"
-            f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'"
+            f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
             f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
             f" --load_strategy {self.cfg.load_config.load_strategy}"
             f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
-            f" --load_choices {self.cfg.load_choices}"
+            f" --load_choices {self.cfg.load_config.load_choices}"
         )
 
         worker_append_flag = {
@@ -1118,8 +1118,7 @@ class LLMEngine:
             "dynamic_load_weight": self.cfg.load_config.dynamic_load_weight,
             "disable_any_whitespace": self.cfg.disable_any_whitespace,
             "enable_custom_all_reduce": self.cfg.parallel_config.enable_custom_all_reduce,
-            "enable_logprob": self.cfg.enable_logprob,
-            "enable_mm": self.cfg.enable_mm,
+            "enable_logprob": self.cfg.model_config.enable_logprob,
         }
         for worker_flag, value in worker_append_flag.items():
             if value:
@@ -1216,7 +1215,7 @@ class LLMEngine:
             device_ids = self.cfg.device_ids.split(",")
             self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
                 cache_config=self.cfg.cache_config,
-                tensor_parallel_size=self.cfg.tensor_parallel_size,
+                tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size,
                 device_ids=device_ids,
                 pod_ip=self.cfg.master_ip,
                 engine_worker_queue_port=self.cfg.engine_worker_queue_port,
@@ -1370,7 +1369,7 @@ class LLMEngine:
         self.engine_worker_queue_server = EngineWorkerQueue(
             address=address,
             is_server=True,
-            num_client=self.cfg.tensor_parallel_size,
+            num_client=self.cfg.parallel_config.tensor_parallel_size,
             local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
         )
 
@@ -1382,7 +1381,7 @@ class LLMEngine:
             ),
             authkey=b"cache_queue_service",
             is_server=True,
-            num_client=self.cfg.tensor_parallel_size,
+            num_client=self.cfg.parallel_config.tensor_parallel_size,
             client_id=-1,
             local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
         )
@@ -1390,7 +1389,7 @@ class LLMEngine:
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
-            num_client=self.cfg.tensor_parallel_size,
+            num_client=self.cfg.parallel_config.tensor_parallel_size,
             client_id=0,
             local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
             local_data_parallel_id=min(
@@ -50,8 +50,8 @@ class ExpertService:
 cfg (Config): Config object containing all the configuration parameters.
 """
 self.cfg = cfg
-start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % cfg.worker_num_per_node
-end_pos = start_pos + self.cfg.tensor_parallel_size
+start_pos = (local_data_parallel_id * self.cfg.parallel_config.tensor_parallel_size) % cfg.worker_num_per_node
+end_pos = start_pos + self.cfg.parallel_config.tensor_parallel_size
 if cfg.splitwise_role != "mixed":
 self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
 self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
@@ -69,13 +69,13 @@ class ExpertService:
 address=address,
 is_server=False,
 client_id=0,
-num_client=cfg.tensor_parallel_size,
+num_client=cfg.parallel_config.tensor_parallel_size,
 local_data_parallel_id=local_data_parallel_id,
 )
 self.resource_manager = ResourceManager(
 cfg.max_num_seqs,
 cfg,
-cfg.tensor_parallel_size,
+cfg.parallel_config.tensor_parallel_size,
 cfg.splitwise_role,
 local_data_parallel_id,
 )
@@ -125,7 +125,7 @@ class ExpertService:
 if self.cfg.splitwise_role != "mixed":
 self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
 cache_config=self.cfg.cache_config,
-tensor_parallel_size=self.cfg.tensor_parallel_size,
+tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size,
 device_ids=self.cfg.local_device_ids,
 pod_ip=self.cfg.master_ip,
 engine_worker_queue_port=self.cfg.engine_worker_queue_port,
@@ -343,7 +343,7 @@ class ExpertService:
 if not is_decode:
 llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
 if not is_prefill and self.cfg.cache_config.enable_chunked_prefill:
-if not self.cfg.enable_mm:
+if not self.cfg.model_config.enable_mm:
 self.update_requests_chunk_size(tasks)
 else:
 self.update_mm_requests_chunk_size(tasks)
@@ -22,7 +22,7 @@ import uuid
 import numpy as np

 from fastdeploy import envs
-from fastdeploy.engine.config import ModelConfig
+from fastdeploy.config import ModelConfig
 from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
 from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
 from fastdeploy.input.preprocess import InputPreprocessor
@@ -16,8 +16,7 @@

 from typing import Any, Dict, Optional

-from fastdeploy.config import ErnieArchitectures
-from fastdeploy.engine.config import ModelConfig
+from fastdeploy.config import ErnieArchitectures, ModelConfig
 from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
 from fastdeploy.reasoning import ReasoningParserManager

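Both import hunks reflect the same consolidation: ModelConfig is now imported from fastdeploy.config rather than fastdeploy.engine.config, so callers pull every config class from a single module, e.g.:

# Unified import location after this change, as shown in the hunks above.
from fastdeploy.config import ErnieArchitectures, ModelConfig  # noqa: F401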
@@ -43,7 +43,6 @@ from fastdeploy.model_executor.models.ernie4_5_moe import (
 Ernie4_5_MLP,
 )
 from fastdeploy.model_executor.models.model_base import ModelForCasualLM
-from fastdeploy.multimodal.registry import MultimodalRegistry
 from fastdeploy.platforms import current_platform

 if current_platform.is_cuda():
@@ -504,7 +503,6 @@ class Ernie4_5_VLModel(nn.Layer):
 return out


-@MultimodalRegistry.register_model()
 class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
 """
 Ernie4_5_VLMoeForConditionalGeneration
@@ -22,7 +22,7 @@ class MultimodalRegistry:
 A registry for multimodal models
 """

-mm_models: set[str] = set()
+mm_models: set[str] = {"Ernie4_5_VLMoeForConditionalGeneration"}

 @classmethod
 def register_model(cls, name: str = "") -> Callable:
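With the class decorator removed from Ernie4_5_VLMoeForConditionalGeneration above, the registry instead ships with that architecture name pre-populated in mm_models. A simplified sketch of the registry pattern these two hunks imply (not the actual FastDeploy source; contains_model is a hypothetical helper):

# Simplified sketch of the registry pattern shown above; the real MultimodalRegistry
# may differ in detail.
from typing import Callable


class MultimodalRegistrySketch:
    # Pre-seeded with the VL architecture, mirroring the change in this diff.
    mm_models: set[str] = {"Ernie4_5_VLMoeForConditionalGeneration"}

    @classmethod
    def register_model(cls, name: str = "") -> Callable:
        """Class decorator that records a model class (or an explicit name) as multimodal."""
        def decorator(model_cls):
            cls.mm_models.add(name or model_cls.__name__)
            return model_cls
        return decorator

    @classmethod
    def contains_model(cls, name: str) -> bool:
        # Hypothetical helper: checks whether an architecture is registered as multimodal.
        return name in cls.mm_models


assert MultimodalRegistrySketch.contains_model("Ernie4_5_VLMoeForConditionalGeneration")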
@@ -57,7 +57,7 @@ class TokenProcessor:
 self.split_connector = split_connector

 self.speculative_decoding = self.cfg.speculative_config.method is not None
-self.use_logprobs = self.cfg.enable_logprob
+self.use_logprobs = self.cfg.model_config.enable_logprob

 if self.speculative_decoding:
 self.output_tokens = paddle.full(
@@ -320,7 +320,7 @@ class SplitwiseConnector:
 """
 self.connect_innode_instances[port] = EngineWorkerQueue(
 address=("0.0.0.0", int(port)),
-num_client=self.cfg.tensor_parallel_size,
+num_client=self.cfg.parallel_config.tensor_parallel_size,
 client_id=0,
 )

@@ -587,7 +587,6 @@ def parse_args():
 "'ipc': real-time IPC streaming with automatic resharding, "
 "'ipc_snapshot': load from disk snapshot of IPC weights.",
 )
-parser.add_argument("--enable_mm", action="store_true", help="Whether to enable vl model")
 parser.add_argument(
 "--enable_logprob",
 action="store_true",
@@ -708,8 +707,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
 else:
 logger.info("No quantization config found and use original weight and act dtype.")

-# Set VL tag
-model_config.enable_mm = args.enable_mm
 logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
 logger.info(f"- Load strategy: {load_config.load_strategy}")

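Together with the registry change, these two hunks drop the --enable_mm worker flag and the manual model_config.enable_mm assignment, so the multimodal switch is expected to come from the model side rather than the CLI. A hedged, illustrative sketch of how such a derivation could look; the helper and constant below are hypothetical, not the actual FastDeploy code path:

# Illustrative sketch only: shows how enable_mm could be inferred from the declared
# model architectures via a registry, instead of being passed as a worker CLI flag.
MM_ARCHITECTURES = {"Ernie4_5_VLMoeForConditionalGeneration"}  # stand-in for MultimodalRegistry.mm_models


def infer_enable_mm(architectures: list[str]) -> bool:
    """Return True if any declared architecture is a registered multimodal model."""
    return any(arch in MM_ARCHITECTURES for arch in architectures)


# Example: a model that declares the VL architecture is treated as multimodal.
assert infer_enable_mm(["Ernie4_5_VLMoeForConditionalGeneration"]) is True
assert infer_enable_mm(["Ernie4_5_MoeForCausalLM"]) is False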
@@ -16,7 +16,12 @@

 import paddle

-from fastdeploy.config import FDConfig, GraphOptimizationConfig, ParallelConfig
+from fastdeploy.config import (
+    CacheConfig,
+    FDConfig,
+    GraphOptimizationConfig,
+    ParallelConfig,
+)
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.graph_optimization.decorator import (
 support_graph_optimization,
@@ -144,7 +149,13 @@ def run_test_case():
 graph_opt_config.use_cudagraph = True
 parallel_config = ParallelConfig(args={})
 parallel_config.max_num_seqs = 1
-fd_config = FDConfig(graph_opt_config=graph_opt_config, parallel_config=parallel_config)
+cache_config = CacheConfig({})
+# Initialize cuda graph capture list
+graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
+graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
+fd_config = FDConfig(
+    graph_opt_config=graph_opt_config, parallel_config=parallel_config, cache_config=cache_config, test_mode=True
+)

 # Run Test Case1
 test_model1 = TestModel1(fd_config=fd_config)
@@ -16,7 +16,12 @@

 import paddle

-from fastdeploy.config import FDConfig, GraphOptimizationConfig, ParallelConfig
+from fastdeploy.config import (
+    CacheConfig,
+    FDConfig,
+    GraphOptimizationConfig,
+    ParallelConfig,
+)
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.graph_optimization.decorator import (
 support_graph_optimization,
@@ -90,7 +95,13 @@ def run_test_case():
 graph_opt_config.use_cudagraph = True
 parallel_config = ParallelConfig(args={})
 parallel_config.max_num_seqs = 1
-fd_config = FDConfig(graph_opt_config=graph_opt_config, parallel_config=parallel_config)
+cache_config = CacheConfig({})
+# Initialize cuda graph capture list
+graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
+graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
+fd_config = FDConfig(
+    graph_opt_config=graph_opt_config, parallel_config=parallel_config, cache_config=cache_config, test_mode=True
+)

 # Run Test Case1
 test_model1 = TestModel1(fd_config=fd_config)
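Both graph-optimization test updates follow the same recipe: construct a CacheConfig, size the CUDA-graph capture list from max_num_seqs, and pass cache_config plus test_mode=True to FDConfig. The helper below merely restates that setup as a standalone function, reusing only the calls that appear in the hunks above (build_cudagraph_test_config itself is not part of the repository):

# Standalone restatement of the test setup shown above; it relies only on calls
# that appear in this diff (ParallelConfig, GraphOptimizationConfig, CacheConfig, FDConfig).
from fastdeploy.config import (
    CacheConfig,
    FDConfig,
    GraphOptimizationConfig,
    ParallelConfig,
)


def build_cudagraph_test_config(max_num_seqs: int = 1) -> FDConfig:
    graph_opt_config = GraphOptimizationConfig({})
    graph_opt_config.use_cudagraph = True

    parallel_config = ParallelConfig(args={})
    parallel_config.max_num_seqs = max_num_seqs

    cache_config = CacheConfig({})

    # Initialize the cuda graph capture list before constructing FDConfig,
    # exactly as the updated tests do.
    graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
    graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)

    return FDConfig(
        graph_opt_config=graph_opt_config,
        parallel_config=parallel_config,
        cache_config=cache_config,
        test_mode=True,
    )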
tests/utils/test_config.py (new file, 81 lines)
@@ -0,0 +1,81 @@
+import unittest
+
+from fastdeploy.config import (
+    CacheConfig,
+    FDConfig,
+    GraphOptimizationConfig,
+    ParallelConfig,
+)
+
+
+class TestConfig(unittest.TestCase):
+    def test_fdconfig_nnode(self):
+        parallel_config = ParallelConfig({"tensor_parallel_size": 16, "expert_parallel_size": 1})
+        graph_opt_config = GraphOptimizationConfig({})
+        cache_config = CacheConfig({})
+        fd_config = FDConfig(
+            parallel_config=parallel_config,
+            graph_opt_config=graph_opt_config,
+            cache_config=cache_config,
+            ips=["1.1.1.1", "0.0.0.0"],
+            test_mode=True,
+        )
+        assert fd_config.nnode == 2
+        assert fd_config.is_master is False
+
+    def test_fdconfig_ips(self):
+        parallel_config = ParallelConfig({})
+        graph_opt_config = GraphOptimizationConfig({})
+        cache_config = CacheConfig({})
+        fd_config = FDConfig(
+            parallel_config=parallel_config,
+            graph_opt_config=graph_opt_config,
+            cache_config=cache_config,
+            ips="0.0.0.0",
+            test_mode=True,
+        )
+        assert fd_config.master_ip == "0.0.0.0"
+
+    def test_fdconfig_max_num_tokens(self):
+        parallel_config = ParallelConfig({})
+        graph_opt_config = GraphOptimizationConfig({})
+        cache_config = CacheConfig({})
+        cache_config.enable_chunked_prefill = True
+        fd_config = FDConfig(
+            parallel_config=parallel_config,
+            graph_opt_config=graph_opt_config,
+            cache_config=cache_config,
+            ips="0.0.0.0",
+            test_mode=True,
+        )
+        assert fd_config.max_num_batched_tokens == 2048
+
+        cache_config.enable_chunked_prefill = False
+        fd_config = FDConfig(
+            parallel_config=parallel_config,
+            graph_opt_config=graph_opt_config,
+            cache_config=cache_config,
+            ips="0.0.0.0",
+            test_mode=True,
+        )
+        assert fd_config.max_num_batched_tokens == 8192
+
+    def test_fdconfig_init_cache(self):
+        parallel_config = ParallelConfig({})
+        graph_opt_config = GraphOptimizationConfig({})
+        cache_config = CacheConfig({})
+        cache_config.cache_transfer_protocol = "rdma,ipc"
+        cache_config.pd_comm_port = "2334"
+        fd_config = FDConfig(
+            parallel_config=parallel_config,
+            graph_opt_config=graph_opt_config,
+            cache_config=cache_config,
+            splitwise_role="prefill",
+            test_mode=True,
+        )
+        fd_config.init_cache_info()
+        assert fd_config.disaggregate_info["role"] == "prefill"
+
+
+if __name__ == "__main__":
+    unittest.main()
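The new test module is plain unittest, so it can be executed directly or through discovery. A minimal discovery-based runner, assuming it is launched from the repository root so that fastdeploy and the tests directory are importable:

# Minimal discovery-based runner for the new config tests; assumes it is executed
# from the repository root so that "fastdeploy" and "tests/utils" resolve.
import unittest

suite = unittest.defaultTestLoader.discover("tests/utils", pattern="test_config.py")
unittest.TextTestRunner(verbosity=2).run(suite)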