diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index db9aa6b61..05781352d 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -507,9 +507,20 @@ class EngineArgs:
                 raise ValueError(
                     "Please set --rdma_comm_ports argument when using " "rdma cache transfer protocol."
                 )
-            if len(self.rdma_comm_ports) != self.tensor_parallel_size * self.data_parallel_size:
+            num_nodes = len(self.ips) if self.ips else 1
+            if self.data_parallel_size % num_nodes != 0:
                 raise ValueError(
-                    f"The number of rdma comm ports must be equal to number of ranks ({self.data_parallel_size=} * {self.tensor_parallel_size=} = {self.data_parallel_size * self.tensor_parallel_size}), but got {len(self.rdma_comm_ports)}."
+                    f"data_parallel_size ({self.data_parallel_size}) must be divisible by "
+                    f"num_nodes ({num_nodes})."
+                )
+            dp_per_node = self.data_parallel_size // num_nodes
+            expected_ports = self.tensor_parallel_size * dp_per_node
+            if len(self.rdma_comm_ports) != expected_ports:
+                raise ValueError(
+                    f"The number of rdma_comm_ports must equal "
+                    f"tensor_parallel_size * (data_parallel_size / num_nodes) = "
+                    f"{self.tensor_parallel_size} * ({self.data_parallel_size} / {num_nodes}) "
+                    f"= {expected_ports}, but got {len(self.rdma_comm_ports)}."
                 )
 
         if not current_platform.is_cuda() and not current_platform.is_xpu():
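
Note for reviewers: below is a minimal, standalone sketch of the validation rule this diff introduces. The old check required one RDMA port per global rank (tensor_parallel_size * data_parallel_size); the new check is per node, since each node only launches data_parallel_size / num_nodes data-parallel groups. The _Args dataclass and check_rdma_ports helper are illustrative stand-ins, not FastDeploy API; only the attribute names mirror EngineArgs.

    # Standalone sketch of the per-node RDMA port-count check (assumed names,
    # not FastDeploy API; attribute names mirror the EngineArgs fields the
    # diff reads: tensor_parallel_size, data_parallel_size, ips,
    # rdma_comm_ports).
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class _Args:
        tensor_parallel_size: int = 1
        data_parallel_size: int = 1
        ips: Optional[list] = None            # one entry per node; None => single node
        rdma_comm_ports: Optional[list] = None

    def check_rdma_ports(args: _Args) -> int:
        """Require tensor_parallel_size * dp_per_node ports on each node."""
        num_nodes = len(args.ips) if args.ips else 1
        if args.data_parallel_size % num_nodes != 0:
            raise ValueError(
                f"data_parallel_size ({args.data_parallel_size}) must be "
                f"divisible by num_nodes ({num_nodes})."
            )
        dp_per_node = args.data_parallel_size // num_nodes
        expected = args.tensor_parallel_size * dp_per_node
        got = len(args.rdma_comm_ports or [])
        if got != expected:
            raise ValueError(f"expected {expected} rdma_comm_ports, got {got}")
        return expected

    # Example: TP=2, DP=4 across 2 nodes -> each node runs 4 ranks, so the
    # per-node rdma_comm_ports list needs 2 * (4 / 2) = 4 entries.
    assert check_rdma_ports(_Args(2, 4, ["10.0.0.1", "10.0.0.2"], [8001, 8002, 8003, 8004])) == 4

The divisibility pre-check matters: without it, an uneven DP split would silently truncate under integer division and the port count would validate against the wrong expected value.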