[Feature] [PD Disaggregation] simplify configuration for pd-disaggregated deployment, and refactor post-init and usage for all ports (#5415)

* [feat] simplify configuration for pd-disaggregated deployment, and refactor post-init and usage for all ports

* [fix] fix some bugs

* [fix] fix rdma port for cache manager/messager

* [fix] temporarily disable port availability check to see if CI tests pass

* [feat] simplify args for multi api server

* [fix] fix dp

* [fix] fix port for xpu

* [fix] add tests for ports post processing & fix ci

* [test] fix test_multi_api_server

* [fix] fix rdma_comm_ports args for multi_api_server

* [fix] fix test_common_engine

* [fix] fix test_cache_transfer_manager

* [chore] automatically set FD_ENABLE_MULTI_API_SERVER

* [fix] prevent the api server from creating engine_args twice

* [fix] fix test_run_batch

* [fix] fix test_metrics

* [fix] fix splitwise connector init

* [test] add test_rdma_transfer and test_expert_service

* [fix] fix code syntax

* [fix] fix test_rdma_transfer and build wheel with rdma script
Author: Yonghua Li
Date: 2025-12-17 15:50:42 +08:00
Committed by: GitHub
Parent: cdc0004894
Commit: 0c8c6369ed
34 changed files with 1323 additions and 409 deletions
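At a high level, the diff renames every worker-side use of `engine_worker_queue_port` (a list holding one port per data-parallel rank) to `local_engine_worker_queue_port` (that list resolved to the current rank's port). A minimal sketch of the mapping; the port numbers and DP id below are illustrative, not from the diff:

```python
# engine_worker_queue_port holds one port per data-parallel rank;
# local_engine_worker_queue_port is the entry for the current rank,
# resolved once instead of re-indexed at every use site.
engine_worker_queue_port = [8002, 8003, 8004, 8005]  # illustrative ports
local_data_parallel_id = 2                           # illustrative DP rank
local_engine_worker_queue_port = engine_worker_queue_port[local_data_parallel_id]
assert local_engine_worker_queue_port == 8004
```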


@@ -173,11 +173,7 @@ class PaddleDisWorkerProc:
             exist_swapped_task_signal:
             model_weights_status:
         """
-        if (
-            self.parallel_config.enable_expert_parallel
-            and self.parallel_config.data_parallel_size > 1
-            and not envs.FD_ENABLE_MULTI_API_SERVER
-        ):
+        if self.parallel_config.data_parallel_size > 1 and not envs.FD_ENABLE_MULTI_API_SERVER:
             launched_expert_service_signal_data = np.zeros(
                 shape=[self.parallel_config.data_parallel_size // self.fd_config.nnode], dtype=np.int32
             )
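The hunk above sizes the launch signal at one slot per local data-parallel rank. A toy check of that arithmetic, with made-up `dp_size` and `nnode` values:

```python
import numpy as np

# One int32 slot per local DP rank: dp_size // nnode entries per node.
# The values below are made up for illustration.
dp_size, nnode = 8, 2
launched_expert_service_signal_data = np.zeros(shape=[dp_size // nnode], dtype=np.int32)
assert launched_expert_service_signal_data.shape == (4,)  # 4 local DP ranks
```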
@@ -217,7 +213,7 @@ class PaddleDisWorkerProc:
             name="worker_healthy_live_signal",
             array=workers_alive,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_worker_queue_port,
+            suffix=self.parallel_config.local_engine_worker_queue_port,
             create=False,
         )
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
@@ -229,7 +225,7 @@ class PaddleDisWorkerProc:
             name="model_weights_status",
             array=workers_model_weights,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_worker_queue_port,
+            suffix=self.parallel_config.local_engine_worker_queue_port,
             create=False,
         )
@@ -239,7 +235,7 @@ class PaddleDisWorkerProc:
             name="exist_task_signal",
             array=workers_exist_task,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_worker_queue_port,
+            suffix=self.parallel_config.local_engine_worker_queue_port,
             create=False,
         )
@@ -249,7 +245,7 @@ class PaddleDisWorkerProc:
             name="exist_swapped_task_signal",
             array=workers_swapped_task,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_worker_queue_port,
+            suffix=self.parallel_config.local_engine_worker_queue_port,
             create=False,
         )
@@ -259,7 +255,7 @@ class PaddleDisWorkerProc:
             name="exist_prefill_task_signal",
             array=exist_prefill_task_signal_data,
             dtype=np.int32,
-            suffix=self.parallel_config.engine_worker_queue_port,
+            suffix=self.parallel_config.local_engine_worker_queue_port,
             create=False,
         )
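The five hunks above all make the same substitution: the IPC signal suffix switches from the raw port field to the rank-local port. Since these signals live in named shared memory, suffixing each name with a port that is unique per engine instance keeps instances on one host from colliding. A rough sketch of that naming scheme using the standard library (this is not FastDeploy's actual `IPCSignal` implementation, just an illustration of the idea):

```python
from multiprocessing import shared_memory

import numpy as np


def open_signal(name: str, suffix: int, size: int = 1):
    """Illustrative helper, not FastDeploy's IPCSignal: attach to (or
    create) a shared int32 array whose name embeds a port suffix."""
    shm_name = f"{name}_{suffix}"
    try:
        shm = shared_memory.SharedMemory(name=shm_name, create=True, size=size * 4)
    except FileExistsError:
        shm = shared_memory.SharedMemory(name=shm_name)
    arr = np.ndarray((size,), dtype=np.int32, buffer=shm.buf)
    return shm, arr  # keep shm referenced so the buffer stays alive


# Two engine instances on the same host resolve to distinct names:
shm_a, sig_a = open_signal("worker_healthy_live_signal", suffix=8002)
shm_b, sig_b = open_signal("worker_healthy_live_signal", suffix=8003)
assert shm_a.name != shm_b.name

del sig_a, sig_b  # release buffer views before closing
for shm in (shm_a, shm_b):
    shm.close()
    shm.unlink()
```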
@@ -304,11 +300,11 @@ class PaddleDisWorkerProc:
             rank=self.local_rank,
             ep_size=self.ranks,
             fd_config=self.fd_config,
-            ipc_signal_suffix=self.parallel_config.engine_worker_queue_port,
+            ipc_signal_suffix=self.parallel_config.local_engine_worker_queue_port,
         )
         dp_ipc_signal_suffix = (
-            f"{self.parallel_config.engine_worker_queue_port}_dp{self.parallel_config.local_data_parallel_id}"
+            f"{self.parallel_config.local_engine_worker_queue_port}_dp{self.parallel_config.local_data_parallel_id}"
         )
         if local_rank == 0:  # master rank0
             signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32)
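For data-parallel groups, the suffix additionally embeds the DP id, so signals scoped to one DP group do not clash with another group behind the same engine. With illustrative values, the f-string above produces:

```python
# Illustrative values; the f-string mirrors the one in the hunk above.
local_engine_worker_queue_port = 8002
local_data_parallel_id = 1
dp_ipc_signal_suffix = f"{local_engine_worker_queue_port}_dp{local_data_parallel_id}"
assert dp_ipc_signal_suffix == "8002_dp1"
```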
@@ -355,7 +351,7 @@ class PaddleDisWorkerProc:
             [MODEL_MAIN_NAME],
             self.local_rank,
             self.ranks,
-            shm_uuid=self.parallel_config.engine_worker_queue_port,
+            shm_uuid=self.parallel_config.local_engine_worker_queue_port,
             eplb_config=self.eplb_config,
             logger=logger,
         )
@@ -468,7 +464,7 @@ class PaddleDisWorkerProc:
             self.model_weights_status,
             # model_weights_signal
             self.worker.model_runner,
-            self.parallel_config.engine_worker_queue_port,
+            self.parallel_config.local_engine_worker_queue_port,
         )
         logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
         self.task_queue.clear_data()
@@ -596,10 +592,10 @@ class PaddleDisWorkerProc:
         if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
             task_address = (
                 self.parallel_config.pod_ip,
-                self.parallel_config.engine_worker_queue_port,
+                self.parallel_config.local_engine_worker_queue_port,
             )
         else:
-            task_address = f"/dev/shm/fd_task_queue_{self.parallel_config.engine_worker_queue_port}.sock"
+            task_address = f"/dev/shm/fd_task_queue_{self.parallel_config.local_engine_worker_queue_port}.sock"
         logger.info(f"connect task queue address {task_address}")
         self.task_queue = TaskQueue(
             address=task_address,
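The task queue address takes one of two forms depending on `FD_ENGINE_TASK_QUEUE_WITH_SHM`: a (host, port) pair for TCP, or a Unix domain socket path under /dev/shm. A condensed sketch of that branch; `pod_ip` and the port are placeholder values, and reading the env var as an integer flag is an assumption here:

```python
import os

# Placeholder values standing in for parallel_config fields.
pod_ip = "10.0.0.5"
local_engine_worker_queue_port = 8002

if not int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")):
    # TCP endpoint: reachable from other nodes.
    task_address = (pod_ip, local_engine_worker_queue_port)
else:
    # Unix domain socket in shared memory: same-host only, no port binding.
    task_address = f"/dev/shm/fd_task_queue_{local_engine_worker_queue_port}.sock"
```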
@@ -937,10 +933,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     parallel_config.num_experts_per_rank = num_experts_per_rank
     parallel_config.num_experts_start_offset = num_experts_start_offset

-    if args.load_strategy != "meta":
-        parallel_config.engine_worker_queue_port = parallel_config.engine_worker_queue_port[
-            parallel_config.local_data_parallel_id
-        ]
     parallel_config.set_communicate_group()

     load_config = LoadConfig(vars(args))
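This removal is the heart of the refactor: previously each worker re-indexed `engine_worker_queue_port` in place (skipping it for the "meta" load strategy), mutating the field from a list to an int. The PR instead resolves the local port during config post-init, which the rename hunks above then consume. A minimal sketch of that post-init pattern with a hypothetical stand-in class; only the two port fields and `local_data_parallel_id` come from the diff, everything else is an assumption:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ParallelConfigSketch:
    """Hypothetical stand-in for FastDeploy's ParallelConfig."""

    engine_worker_queue_port: List[int] = field(default_factory=lambda: [8002])
    local_data_parallel_id: int = 0
    local_engine_worker_queue_port: int = field(init=False)

    def __post_init__(self):
        # Resolve the per-rank port once; engine_worker_queue_port keeps
        # its list type for all ranks instead of being mutated in place.
        self.local_engine_worker_queue_port = self.engine_worker_queue_port[
            self.local_data_parallel_id
        ]


cfg = ParallelConfigSketch(engine_worker_queue_port=[8002, 8003], local_data_parallel_id=1)
assert cfg.local_engine_worker_queue_port == 8003
assert cfg.engine_worker_queue_port == [8002, 8003]  # unchanged
```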
@@ -1015,6 +1007,8 @@
         eplb_config=eplb_config,
         routing_replay_config=routing_replay_config,
     )
+    logger.info(f"parallel_config.local_engine_worker_queue_port {parallel_config.local_engine_worker_queue_port}")
+
     update_fd_config_for_mm(fd_config)
     if fd_config.load_config.load_choices == "default_v1" and not v1_loader_support(fd_config):
         fd_config.load_config.load_choices = "default"