[LLM] fix multinode bugs (#2945)

* [LLM] fix multinode bugs

* [LLM] fix multinode bugs

* [LLM] fix multinode bugs

* [LLM] fix ci bugs

* fix ci bugs

* fix ci bugs
Author: ltd0924
Date: 2025-07-22 20:23:37 +08:00
Committed by: GitHub
Parent: 69be77c8c0
Commit: b0f1e0eef4
9 changed files with 68 additions and 87 deletions


@@ -143,7 +143,7 @@ class PaddleDisWorkerProc():
         # Initialize task queue
         task_address = (self.parallel_config.pod_ip,
                         self.parallel_config.engine_worker_queue_port)
+        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
         self.task_queue = TaskQueue(
             address=task_address,
             is_server=False,
@@ -162,7 +162,6 @@ class PaddleDisWorkerProc():
             model_weights_status:
         """
         # init worker_ready_signal
-        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
         array_size = min(
             self.max_chips_per_node, self.parallel_config.tensor_parallel_size *
             self.parallel_config.expert_parallel_size)
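
The two hunks above hoist max_chips_per_node into the constructor and drop the duplicate assignment, so every later method that sizes or indexes a per-node signal array sees the same value. A minimal sketch (the parallel sizes below are hypothetical, not taken from the repo) of why the array is sized per node rather than per global rank:

    # Sketch only: a 2-node run with tensor parallelism spanning both nodes.
    max_chips_per_node = 8                      # 16 on Iluvatar hardware
    tensor_parallel_size, expert_parallel_size = 16, 1
    array_size = min(max_chips_per_node,
                     tensor_parallel_size * expert_parallel_size)
    print(array_size)  # -> 8: one slot per local chip, not per global rank
    # Any rank beyond the first node must therefore index these arrays with
    # local_rank % max_chips_per_node, which is what the later hunks change.
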
@@ -183,9 +182,9 @@ class PaddleDisWorkerProc():
             array=workers_alive,
             dtype=np.int32,
             suffix=self.parallel_config.engine_pid,
-            create=False)
-        self.worker_healthy_live_signal.value[self.local_rank % 8] = int(
-            time.time())
+            create=False,
+        )
+        self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(time.time())
         # init model_weights_status
         workers_model_weights = np.zeros(shape=[1], dtype=np.int32)
@@ -271,8 +270,7 @@ class PaddleDisWorkerProc():
                 paddle.distributed.barrier()
             self.insert_step = False
-            self.worker_healthy_live_signal.value[self.local_rank] = int(
-                time.time())
+            self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(time.time())
             # The first worker detects whether there are tasks in the task queue
             if self.local_rank % mp_num_per_node == 0:
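
The two heartbeat hunks above replace the hard-coded % 8 (and the bare local_rank index) with % self.max_chips_per_node, so ranks beyond the first node still land inside the per-node array. A simplified stand-in for the pattern (plain numpy instead of the real shared-memory signal; names and the timeout are assumptions):

    import time
    import numpy as np

    max_chips_per_node = 8
    # Stand-in for worker_healthy_live_signal.value: one timestamp slot per chip.
    heartbeat = np.zeros(max_chips_per_node, dtype=np.int64)

    def touch(local_rank: int) -> None:
        # Each worker stamps its own per-node slot once per step.
        heartbeat[local_rank % max_chips_per_node] = int(time.time())

    def is_alive(local_rank: int, timeout_s: int = 30) -> bool:
        # A monitor treats a worker as dead once its stamp goes stale.
        return time.time() - heartbeat[local_rank % max_chips_per_node] < timeout_s

    touch(9)            # rank 9 on a 2-node, 8-chip-per-node setup -> slot 1
    print(is_alive(9))  # -> True
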
@@ -388,7 +386,7 @@ class PaddleDisWorkerProc():
                 suffix=self.parallel_config.engine_pid,
                 create=False)
             self.get_profile_block_num_signal.value[
-                self.local_rank] = num_blocks_local
+                self.local_rank % self.max_chips_per_node] = num_blocks_local
             # Wait all worker send the signal
             while np.any(self.get_profile_block_num_signal.value <= 0):
@@ -396,7 +394,7 @@ class PaddleDisWorkerProc():
             num_blocks_global = self.get_profile_block_num_signal.value.min(
             ).item()
             self.get_profile_block_num_signal.value[
-                self.local_rank] = num_blocks_global
+                self.local_rank % self.max_chips_per_node] = num_blocks_global
         else:
             num_blocks_global = self.fd_config.parallel_config.total_block_num
         # NOTE(liuzichang): Too big num_blocks_global will lead to error 700
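
The last two hunks apply the same per-node indexing to the profiled KV-cache block count. A self-contained sketch (plain numpy with hypothetical block counts, not the real IPC signal API) of the publish-then-min reduction these hunks take part in:

    import numpy as np

    max_chips_per_node = 8
    # Stand-in for get_profile_block_num_signal.value.
    block_num_signal = np.zeros(max_chips_per_node, dtype=np.int32)

    # Each local rank profiles a (possibly different) number of free blocks
    # and publishes it into its per-node slot.
    profiled = {0: 812, 1: 798, 2: 805, 3: 801, 4: 799, 5: 810, 6: 802, 7: 803}
    for local_rank, num_blocks_local in profiled.items():
        block_num_signal[local_rank % max_chips_per_node] = num_blocks_local

    # Once every slot is non-zero (the while-loop in the hunk above), all ranks
    # adopt the minimum so the allocation fits on the most constrained device.
    assert not np.any(block_num_signal <= 0)
    num_blocks_global = int(block_num_signal.min())
    print(num_blocks_global)  # -> 798
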