[BugFix] fix num of rdma_comm_ports check (#5168)
* fix num of rdma_comm_ports check
* update
* update
* update
@@ -507,9 +507,20 @@ class EngineArgs:
                 raise ValueError(
                     "Please set --rdma_comm_ports argument when using " "rdma cache transfer protocol."
                 )
-            if len(self.rdma_comm_ports) != self.tensor_parallel_size * self.data_parallel_size:
+            num_nodes = len(self.ips) if self.ips else 1
+            if self.data_parallel_size % num_nodes != 0:
                 raise ValueError(
-                    f"The number of rdma comm ports must be equal to number of ranks ({self.data_parallel_size=} * {self.tensor_parallel_size=} = {self.data_parallel_size * self.tensor_parallel_size}), but got {len(self.rdma_comm_ports)}."
+                    f"data_parallel_size ({self.data_parallel_size}) must be divisible by "
+                    f"num_nodes ({num_nodes})."
                 )
+            dp_per_node = self.data_parallel_size // num_nodes
+            expected_ports = self.tensor_parallel_size * dp_per_node
+            if len(self.rdma_comm_ports) != expected_ports:
+                raise ValueError(
+                    f"The number of rdma_comm_ports must equal "
+                    f"tensor_parallel_size * (data_parallel_size / num_nodes) = "
+                    f"{self.tensor_parallel_size} * ({self.data_parallel_size} / {num_nodes}) "
+                    f"= {expected_ports}, but got {len(self.rdma_comm_ports)}."
+                )
 
         if not current_platform.is_cuda() and not current_platform.is_xpu():
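For context, the arithmetic the new hunk implements can be read on its own: each node hosts data_parallel_size / num_nodes DP replicas, each replica spans tensor_parallel_size ranks, so a node needs one RDMA comm port per local rank rather than one per global rank. Below is a minimal standalone sketch of that check; the function name check_rdma_comm_ports and the example values are illustrative, not part of FastDeploy's API.

    # Hypothetical standalone version of the per-node port-count check.
    def check_rdma_comm_ports(rdma_comm_ports, tensor_parallel_size, data_parallel_size, ips=None):
        # One node unless a list of node IPs is given (mirrors self.ips).
        num_nodes = len(ips) if ips else 1
        if data_parallel_size % num_nodes != 0:
            raise ValueError(
                f"data_parallel_size ({data_parallel_size}) must be divisible by num_nodes ({num_nodes})."
            )
        # Each node serves data_parallel_size / num_nodes DP replicas, and each
        # replica spans tensor_parallel_size ranks, so the node needs one port
        # per local rank.
        dp_per_node = data_parallel_size // num_nodes
        expected_ports = tensor_parallel_size * dp_per_node
        if len(rdma_comm_ports) != expected_ports:
            raise ValueError(
                f"Expected {expected_ports} rdma_comm_ports, but got {len(rdma_comm_ports)}."
            )

    # Example (values are hypothetical): tp=2, dp=4 across 2 nodes means each
    # node runs 2 DP replicas, so --rdma_comm_ports must list 2 * 2 = 4 ports.
    check_rdma_comm_ports(
        [8301, 8302, 8303, 8304],
        tensor_parallel_size=2,
        data_parallel_size=4,
        ips=["10.0.0.1", "10.0.0.2"],
    )

Under the old check, the same deployment would have demanded tensor_parallel_size * data_parallel_size = 8 ports per node, which double-counts ranks that live on the other node.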
||||