diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 668229183..e5b184b3c 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -417,17 +417,11 @@ class PaddleDisWorkerProc:
         while True:
             # run eplb
             self._run_eplb(tp_rank)
-            if tp_rank == 0:
+
+            if self.fd_config.load_config.dynamic_load_weight:
                 if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
-                if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
-                    self.model_weights_signal[0] = self._broadcast_model_weights_signal(
-                        src=0, group=self.parallel_config.ep_group
-                    )
-                if self.fd_config.load_config.dynamic_load_weight and tp_size > 1:
-                    self.model_weights_signal[0] = self._broadcast_model_weights_signal(
-                        src=0, group=self.parallel_config.tp_group
-                    )
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None)

             req_dicts = None
             self.worker_healthy_live_signal.value[tp_rank % self.max_chips_per_node] = int(time.time())
@@ -447,10 +441,7 @@ class PaddleDisWorkerProc:
             self._tp_barrier_wait() if tp_size > 1 else None

             if self.fd_config.load_config.dynamic_load_weight:
-                if self.parallel_config.enable_expert_parallel:
-                    paddle.distributed.barrier(self.parallel_config.ep_group)
-                else:
-                    paddle.distributed.barrier(self.parallel_config.tp_group)
+                paddle.distributed.barrier()
             if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
                 logger.info(
                     f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
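
The simplification above hinges on `group=None` selecting the default (global) process group, so one broadcast and one barrier cover both the TP and EP ranks that previously needed separate collectives. A minimal sketch of how such a signal broadcast could be written, assuming standard `paddle.distributed` semantics (the helper name and signature here are illustrative; the repository's actual `_broadcast_model_weights_signal` may differ):

```python
# Hypothetical sketch, not the repository's implementation: broadcast rank 0's
# weight-status signal to every rank via the default communication group,
# mirroring the `group=None` call sites in the diff above.
import paddle
import paddle.distributed as dist


def broadcast_model_weights_signal(local_signal: int, src: int = 0, group=None) -> int:
    """Broadcast `local_signal` from `src`; every rank returns the same value."""
    signal_tensor = paddle.to_tensor([local_signal], dtype="int32")
    # With group=None, paddle.distributed.broadcast uses the default (global)
    # process group, so TP and EP ranks no longer need separate broadcasts.
    dist.broadcast(signal_tensor, src=src, group=group)
    return int(signal_tensor.item())
```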