[Feature] [PD Disaggregation] simplify configuration for pd-disaggregated deployment, and refactor post-init and usage for all ports (#5415)

* [feat] simplify configuration for pd-disaggregated deployment, and refactor post-init and usage for all ports

* [fix] fix some bugs

* [fix] fix rdma port for cache manager/messager

* [fix] temporarily disable the port availability check to see if it can pass the CI test

* [feat] simplify args for multi api server

* [fix] fix dp

* [fix] fix port for xpu

* [fix] add tests for ports post processing & fix ci

* [test] fix test_multi_api_server

* [fix] fix rdma_comm_ports args for multi_api_server

* [fix] fix test_common_engine

* [fix] fix test_cache_transfer_manager

* [chore] automatically set FD_ENABLE_MULTI_API_SERVER

* [fix] prevent the api server from creating engine_args twice

* [fix] fix test_run_batch

* [fix] fix test_metrics

* [fix] fix splitwise connector init

* [test] add test_rdma_transfer and test_expert_service

* [fix] fix code syntax

* [fix] fix test_rdma_transfer and build wheel with rdma script
This commit is contained in:
Yonghua Li
2025-12-17 15:50:42 +08:00
committed by GitHub
parent cdc0004894
commit 0c8c6369ed
34 changed files with 1323 additions and 409 deletions

View File

@@ -47,7 +47,9 @@ from fastdeploy.utils import (
DeprecatedOptionWarning,
FlexibleArgumentParser,
console_logger,
find_free_ports,
is_port_available,
parse_ports,
parse_quantization,
)
@@ -224,7 +226,7 @@ class EngineArgs:
The amount of CPU memory to offload to.
"""
cache_queue_port: str = "0"
cache_queue_port: Optional[Union[int, str, list]] = None
"""
Port for cache queue.
"""
@@ -266,7 +268,7 @@ class EngineArgs:
# This optimization is enabled by default, and can be disabled by using this flag.
"""
engine_worker_queue_port: str = "0"
engine_worker_queue_port: Optional[Union[int, str, list]] = None
"""
Port for worker queue communication.
"""
@@ -301,17 +303,17 @@ class EngineArgs:
Chunk size of moe input.
"""
cache_transfer_protocol: str = "ipc"
cache_transfer_protocol: str = "ipc,rdma"
"""
Protocol to use for cache transfer.
"""
pd_comm_port: Optional[List[int]] = None
pd_comm_port: Optional[Union[int, str, list]] = None
"""
Port for splitwise communication.
"""
rdma_comm_ports: Optional[List[int]] = None
rdma_comm_ports: Optional[Union[int, str, list]] = None
"""
Ports for rdma communication.
"""
@@ -497,6 +499,11 @@ class EngineArgs:
Flag to rollout routing replay(r3)
"""
skip_port_check: bool = False
"""
Whether to skip port availability check. Default is False (not skip).
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -508,8 +515,6 @@ class EngineArgs:
self.enable_prefix_caching = False
if not current_platform.is_cuda() and not current_platform.is_xpu() and not current_platform.is_intel_hpu():
self.enable_prefix_caching = False
# if self.dynamic_load_weight:
# self.enable_prefix_caching = False
if self.enable_logprob:
if not current_platform.is_cuda() and not current_platform.is_xpu():
raise NotImplementedError("Only CUDA and XPU platforms support logprob.")
@@ -530,33 +535,69 @@ class EngineArgs:
f"scheduler, please provide --router argument."
)
if "rdma" in self.cache_transfer_protocol:
if self.rdma_comm_ports is None:
raise ValueError(
"Please set --rdma_comm_ports argument when using " "rdma cache transfer protocol."
)
num_nodes = len(self.ips) if self.ips else 1
if self.data_parallel_size % num_nodes != 0:
raise ValueError(
f"data_parallel_size ({self.data_parallel_size}) must be divisible by "
f"num_nodes ({num_nodes})."
)
dp_per_node = self.data_parallel_size // num_nodes
expected_ports = self.tensor_parallel_size * dp_per_node
if len(self.rdma_comm_ports) != expected_ports:
raise ValueError(
f"The number of rdma_comm_ports must equal "
f"tensor_parallel_size * (data_parallel_size / num_nodes) = "
f"{self.tensor_parallel_size} * ({self.data_parallel_size} / {num_nodes}) "
f"= {expected_ports}, but got {len(self.rdma_comm_ports)}."
)
if not (current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_maca()):
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if "PaddleOCR" in get_model_architecture(self.model, self.model_config_name):
envs.FD_ENABLE_MAX_PREFILL = 1
self.post_init_all_ports()
def post_init_all_ports(self):
def post_init_ports(name: str, ports: list, num_total_ports: int):
ports = parse_ports(ports)
num_cur_dp_ports = num_total_ports
if envs.FD_ENABLE_MULTI_API_SERVER:
num_cur_dp_ports //= self.data_parallel_size
if ports is None:
ports = find_free_ports(num_ports=num_cur_dp_ports)
console_logger.info(
f"Parameter `{name}` is not specified, found available ports for possible use: {ports}"
)
else:
num_input_ports = len(ports)
if num_input_ports != num_total_ports:
ports = find_free_ports(num_ports=num_cur_dp_ports)
console_logger.warn(
f"Parameter `{name}` expects {num_total_ports} ports, but got {num_input_ports}. Ignore them and assign new ones: {ports}"
)
else:
console_logger.info(f"Using `{name}`: {ports}")
if not self.skip_port_check:
for port in ports:
assert is_port_available("0.0.0.0", port), f"Parameter `{name}`:{port} is already in use."
console_logger.debug(f"post init {name}: {ports}")
return ports
num_nodes = len(self.ips) if self.ips else 1
if self.data_parallel_size % num_nodes != 0:
raise ValueError(
f"data_parallel_size ({self.data_parallel_size}) must be divisible by num_nodes ({num_nodes})."
)
self.engine_worker_queue_port = post_init_ports(
"engine_worker_queue_port",
self.engine_worker_queue_port,
self.data_parallel_size // num_nodes,
)
self.cache_queue_port = post_init_ports(
"cache_queue_port",
self.cache_queue_port,
self.data_parallel_size // num_nodes,
)
self.rdma_comm_ports = post_init_ports(
"rdma_comm_ports",
self.rdma_comm_ports,
self.tensor_parallel_size * self.data_parallel_size // num_nodes,
)
self.pd_comm_port = post_init_ports(
"pd_comm_port",
self.pd_comm_port,
self.data_parallel_size // num_nodes,
)
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""
@@ -1166,7 +1207,7 @@ class EngineArgs:
return parser
@classmethod
def from_cli_args(cls, args: FlexibleArgumentParser) -> "EngineArgs":
def from_cli_args(cls, args: FlexibleArgumentParser, skip_port_check=False) -> "EngineArgs":
"""
Create an instance of EngineArgs from command line arguments.
"""
@@ -1174,7 +1215,7 @@ class EngineArgs:
for field in dataclass_fields(cls):
if hasattr(args, field.name):
args_dict[field.name] = getattr(args, field.name)
return cls(**args_dict)
return cls(**args_dict, skip_port_check=skip_port_check)
def create_speculative_config(self) -> SpeculativeConfig:
""" """
@@ -1253,7 +1294,7 @@ class EngineArgs:
routing_replay_args[k] = v
return RoutingReplayConfig(routing_replay_args)
def create_engine_config(self, port_availability_check=True) -> FDConfig:
def create_engine_config(self) -> FDConfig:
"""
Create and return a Config object based on the current settings.
"""
@@ -1282,11 +1323,6 @@ class EngineArgs:
else:
self.max_num_batched_tokens = self.max_model_len
if isinstance(self.engine_worker_queue_port, int):
self.engine_worker_queue_port = str(self.engine_worker_queue_port)
if isinstance(self.engine_worker_queue_port, str):
self.engine_worker_queue_port = self.engine_worker_queue_port.split(",")
all_dict = asdict(self)
all_dict["model_cfg"] = model_cfg
cache_cfg = CacheConfig(all_dict)
@@ -1302,10 +1338,6 @@ class EngineArgs:
early_stop_cfg = self.create_early_stop_config()
early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
structured_outputs_config: StructuredOutputsConfig = StructuredOutputsConfig(args=all_dict)
if port_availability_check:
assert is_port_available(
"0.0.0.0", int(self.engine_worker_queue_port[parallel_cfg.local_data_parallel_id])
), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
return FDConfig(
model_config=model_cfg,