[BugFix] fix num of rdma_comm_ports check (#5168)

* fix num of rdma_comm_ports check

* update

* update

* update
Author: Yuanle Liu
Date: 2025-11-21 18:31:14 +08:00
Committed by: GitHub
parent d2298dcb0c
commit 5bcf79d780


@@ -507,9 +507,20 @@ class EngineArgs:
             raise ValueError(
                 "Please set --rdma_comm_ports argument when using " "rdma cache transfer protocol."
             )
-        if len(self.rdma_comm_ports) != self.tensor_parallel_size * self.data_parallel_size:
+        num_nodes = len(self.ips) if self.ips else 1
+        if self.data_parallel_size % num_nodes != 0:
             raise ValueError(
-                f"The number of rdma comm ports must be equal to number of ranks ({self.data_parallel_size=} * {self.tensor_parallel_size=} = {self.data_parallel_size * self.tensor_parallel_size}), but got {len(self.rdma_comm_ports)}."
+                f"data_parallel_size ({self.data_parallel_size}) must be divisible by "
+                f"num_nodes ({num_nodes})."
             )
+        dp_per_node = self.data_parallel_size // num_nodes
+        expected_ports = self.tensor_parallel_size * dp_per_node
+        if len(self.rdma_comm_ports) != expected_ports:
+            raise ValueError(
+                f"The number of rdma_comm_ports must equal "
+                f"tensor_parallel_size * (data_parallel_size / num_nodes) = "
+                f"{self.tensor_parallel_size} * ({self.data_parallel_size} / {num_nodes}) "
+                f"= {expected_ports}, but got {len(self.rdma_comm_ports)}."
+            )
         if not current_platform.is_cuda() and not current_platform.is_xpu():
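
For clarity, below is a minimal standalone sketch of the per-node check introduced by this commit, using the same arithmetic as the diff above. The helper name check_rdma_comm_ports and the example values are hypothetical and only illustrate the rule that rdma_comm_ports are counted per node rather than across the whole cluster.

# Standalone sketch of the per-node rdma_comm_ports check (illustration only;
# check_rdma_comm_ports is a hypothetical helper, not part of EngineArgs).
def check_rdma_comm_ports(rdma_comm_ports, tensor_parallel_size, data_parallel_size, ips=None):
    # Each entry in `ips` is one node; a single-node deployment may leave it unset.
    num_nodes = len(ips) if ips else 1
    if data_parallel_size % num_nodes != 0:
        raise ValueError(
            f"data_parallel_size ({data_parallel_size}) must be divisible by num_nodes ({num_nodes})."
        )
    # Ports are configured per node, so only the DP ranks hosted on one node count.
    dp_per_node = data_parallel_size // num_nodes
    expected_ports = tensor_parallel_size * dp_per_node
    if len(rdma_comm_ports) != expected_ports:
        raise ValueError(
            f"Expected {expected_ports} rdma_comm_ports "
            f"({tensor_parallel_size} TP ranks x {dp_per_node} DP groups per node), "
            f"got {len(rdma_comm_ports)}."
        )

# Example: 2 nodes, TP=8, DP=4 -> 2 DP groups per node -> 8 * 2 = 16 ports per node.
# The old check demanded TP * DP = 32 ports, which over-counts on multi-node setups.
check_rdma_comm_ports(list(range(16)), tensor_parallel_size=8, data_parallel_size=4, ips=["10.0.0.1", "10.0.0.2"])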