[BugFix] fix rl model_weights_signal to support tp>1 (#5639)

Author: Yuanle Liu
Date: 2025-12-18 20:43:58 +08:00 (committed by GitHub)
Parent: d739af5e6e
Commit: b47674c796


@@ -417,17 +417,11 @@ class PaddleDisWorkerProc:
         while True:
             # run eplb
             self._run_eplb(tp_rank)
             if tp_rank == 0:
                 if self.fd_config.load_config.dynamic_load_weight:
                     if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
                         self.model_weights_signal[0] = int(self.model_weights_status.value[0])
-            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
-                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
-                    src=0, group=self.parallel_config.ep_group
-                )
-            if self.fd_config.load_config.dynamic_load_weight and tp_size > 1:
-                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
-                    src=0, group=self.parallel_config.tp_group
-                )
+            self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None)
             req_dicts = None
             self.worker_healthy_live_signal.value[tp_rank % self.max_chips_per_node] = int(time.time())
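
The hunk above collapses the two group-specific broadcasts into a single broadcast over the default communication group (group=None), so the signal written by rank 0 reaches every worker whether tensor parallelism, expert parallelism, or both are enabled. Below is a minimal, hypothetical sketch of that pattern, assuming _broadcast_model_weights_signal wraps paddle.distributed.broadcast; the helper name broadcast_model_weights_signal and the launch setup are illustrative and not the repository's actual code.

```python
import numpy as np
import paddle
import paddle.distributed as dist


def broadcast_model_weights_signal(signal: np.ndarray, src: int = 0) -> int:
    """Broadcast signal[0] from rank `src` to all ranks and return the received value."""
    tensor = paddle.to_tensor(signal[:1], dtype="int32")
    # group=None selects the default (global) group, covering all TP/EP ranks.
    dist.broadcast(tensor, src=src, group=None)
    return int(tensor[0])


if __name__ == "__main__":
    dist.init_parallel_env()
    # Only rank 0 observes the shared weight status; the broadcast fans it out.
    signal = np.array([1 if dist.get_rank() == 0 else 0], dtype=np.int32)
    signal[0] = broadcast_model_weights_signal(signal, src=0)
    print(f"rank {dist.get_rank()} sees model_weights_signal = {signal[0]}")
```

Launched under paddle.distributed.launch with two or more processes, every rank prints the value set on rank 0, which is the behaviour the tp>1 fix relies on.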
@@ -447,10 +441,7 @@ class PaddleDisWorkerProc:
             self._tp_barrier_wait() if tp_size > 1 else None
             if self.fd_config.load_config.dynamic_load_weight:
-                if self.parallel_config.enable_expert_parallel:
-                    paddle.distributed.barrier(self.parallel_config.ep_group)
-                else:
-                    paddle.distributed.barrier(self.parallel_config.tp_group)
+                paddle.distributed.barrier()
                 if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
                     logger.info(
                         f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"