[BugFix] fix dp&ep&tp and multi-node inference (#3629)

* rm log

* fix bug

* fix bug

* fix dp&ep&tp and multi-node inference

* fix

---------

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
Authored by: gaoziyuan
Date: 2025-08-28 19:09:10 +08:00
Committed by: GitHub
Parent: 17731a8acd
Commit: fc635acc47
7 changed files with 48 additions and 34 deletions


@@ -170,6 +170,7 @@ class PaddleDisWorkerProc:
self.max_chips_per_node,
self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size,
)
workers_ready = np.zeros(shape=[array_size], dtype=np.int32)
self.worker_ready_signal = IPCSignal(
name="worker_ready_signal",
@@ -179,7 +180,6 @@ class PaddleDisWorkerProc:
create=False,
)
self.worker_ready_signal.value[self.local_rank % self.max_chips_per_node] = 1
# init worker_healthy_live_signal
workers_alive = np.zeros(shape=[min(array_size, self.parallel_config.tensor_parallel_size)], dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal(
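
Note: the two hunks above size the per-node ready/alive signal arrays so they cover every worker that tensor and data parallelism place on one node, and each worker marks its own slot by local rank. A minimal standalone sketch of that indexing, using plain numpy and hypothetical values in place of parallel_config:

    import numpy as np

    # Hypothetical stand-ins for self.max_chips_per_node and parallel_config.
    max_chips_per_node = 8       # accelerators physically present on one node
    tensor_parallel_size = 4     # TP degree
    data_parallel_size = 2       # DP degree

    # One slot per worker on this node, capped by the chips the node actually has,
    # mirroring the min(...) in the first hunk.
    array_size = min(max_chips_per_node, tensor_parallel_size * data_parallel_size)  # -> 8
    workers_ready = np.zeros(shape=[array_size], dtype=np.int32)

    # Each local worker marks its own per-node slot, as with worker_ready_signal above.
    local_rank = 5               # hypothetical global rank of this worker
    workers_ready[local_rank % max_chips_per_node] = 1

    # The healthy-live array is additionally capped at the TP degree.
    workers_alive = np.zeros(shape=[min(array_size, tensor_parallel_size)], dtype=np.int32)
    print(workers_ready, workers_alive.shape)  # [0 0 0 0 0 1 0 0] (4,)
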
@@ -231,7 +231,6 @@ class PaddleDisWorkerProc:
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
logger.info("gaoziyuan test init_health_status")
def event_loop_normal(self) -> None:
"""Main event loop for Paddle Distrubuted Workers.
@@ -255,7 +254,8 @@ class PaddleDisWorkerProc:
self.insert_step = False
req_dicts = None
self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(time.time())
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())
# The first worker detects whether there are tasks in the task queue
if self.local_rank % mp_num_per_node == 0:
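
Note: the heartbeat hunk above first folds the global rank into a TP-local rank before picking a slot; since the healthy-live array is capped at tensor_parallel_size (see the earlier hunk), this keeps the write in bounds when several data-parallel groups share one node. A minimal sketch, assuming a hypothetical dp=2 x tp=4 deployment on an 8-chip node:

    import time
    import numpy as np

    max_chips_per_node = 8
    tensor_parallel_size = 4
    data_parallel_size = 2

    array_size = min(max_chips_per_node, tensor_parallel_size * data_parallel_size)  # 8
    # Healthy-live array is sized to the TP degree, i.e. length 4 here.
    workers_alive = np.zeros(shape=[min(array_size, tensor_parallel_size)], dtype=np.int32)

    local_rank = 5   # worker 1 of the second DP group on this node
    # Previous indexing: 5 % 8 == 5 -> out of bounds for a length-4 array.
    # Fixed indexing folds the rank into its TP group first:
    tp_local_rank = local_rank % tensor_parallel_size          # -> 1
    workers_alive[tp_local_rank % max_chips_per_node] = int(time.time())
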
@@ -267,7 +267,7 @@ class PaddleDisWorkerProc:
if self.nnode > 1 and self.parallel_config.tensor_parallel_size > self.max_chips_per_node:
self.task_queue.read_finish_flag.set(1)
else:
self.exist_task_signal.value[self.fd_config.parallel_config.data_parallel_rank] = 1
self.exist_task_signal.value[0] = 1
if self.parallel_config.tensor_parallel_size > 1:
# Synchronize the signal for other workers
@@ -285,17 +285,14 @@ class PaddleDisWorkerProc:
self.parallel_config.engine_pid,
)
if (
self.exist_task_signal.value[self.fd_config.parallel_config.data_parallel_rank] == 1
or self.task_queue.read_finish_flag.get() == 1
):
if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
logger.info(f"Rank: {self.local_rank} Detected new requests.")
self.insert_step = True
tasks, read_finish = self.task_queue.get_tasks()
if read_finish:
# Ensure that every worker gets the task
self.exist_task_signal.value[self.fd_config.parallel_config.data_parallel_rank] = 0
self.exist_task_signal.value[0] = 0
self.task_queue.read_finish_flag.set(0)
req_dicts = []
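
Note: the hunks above replace indexing exist_task_signal by the data-parallel rank with a fixed index 0, presumably because each data-parallel group now reads its own signal rather than sharing one indexed array. A minimal sketch of that single-slot handshake, with hypothetical helper names:

    import numpy as np

    # Hypothetical stand-in for one DP group's flags: a single slot suffices once
    # every group owns its own signal instead of sharing one indexed array.
    exist_task_signal = np.zeros(shape=[1], dtype=np.int32)
    read_finish_flag = np.zeros(shape=[1], dtype=np.int32)

    def announce_tasks():
        # The group's first worker detects queued requests and raises the flag.
        exist_task_signal[0] = 1

    def consume_tasks():
        if exist_task_signal[0] == 1 or read_finish_flag[0] == 1:
            # ... fetch tasks from the queue; once every worker has them, clear the flags.
            exist_task_signal[0] = 0
            read_finish_flag[0] = 0

    announce_tasks()
    consume_tasks()
    assert exist_task_signal[0] == 0
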
@@ -413,7 +410,7 @@ class PaddleDisWorkerProc:
is_server=False,
num_client=self.parallel_config.tensor_parallel_size,
client_id=self.parallel_config.tensor_parallel_rank,
local_data_parallel_id=self.parallel_config.expert_parallel_rank,
local_data_parallel_id=self.parallel_config.data_parallel_rank,
)
def load_model(self) -> None:
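
Note: the last hunk keys the engine worker queue client by data_parallel_rank instead of expert_parallel_rank, so a worker attaches to its own DP group's queue even when the EP and DP ranks differ. A rough sketch of that wiring with hypothetical names (only the parameters shown in the hunk are modeled):

    # Hypothetical stand-in for the engine worker queue client.
    class EngineWorkerQueueClient:
        def __init__(self, is_server: bool, num_client: int, client_id: int,
                     local_data_parallel_id: int):
            self.is_server = is_server
            self.num_client = num_client                            # TP workers sharing the queue
            self.client_id = client_id                              # this worker's TP rank
            self.local_data_parallel_id = local_data_parallel_id    # which DP group's queue to join

    tensor_parallel_size, tensor_parallel_rank = 4, 1
    data_parallel_rank, expert_parallel_rank = 0, 3                 # EP rank no longer equals DP rank

    client = EngineWorkerQueueClient(
        is_server=False,
        num_client=tensor_parallel_size,
        client_id=tensor_parallel_rank,
        local_data_parallel_id=data_parallel_rank,                  # fixed: DP rank, not EP rank
    )
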