diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py
index 202891bc0..9acac7e96 100644
--- a/fastdeploy/inter_communicator/engine_worker_queue.py
+++ b/fastdeploy/inter_communicator/engine_worker_queue.py
@@ -101,6 +101,9 @@ class EngineWorkerQueue:
             self.finish_request_barrier = [
                 threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
             ]
+            self.worker_process_tp_barrier = [
+                threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
+            ]
 
             self.finish_add_cache_task_barrier = [
                 threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
@@ -193,6 +196,10 @@ class EngineWorkerQueue:
                 "get_finish_add_cache_task_barrier",
                 callable=lambda idx: self.finish_add_cache_task_barrier[idx],
             )
+            QueueManager.register(
+                "get_worker_process_tp_barrier",
+                callable=lambda idx: self.worker_process_tp_barrier[idx],
+            )
             self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey)
             self.manager.start()
         else:
@@ -217,6 +224,7 @@ class EngineWorkerQueue:
             QueueManager.register("get_connect_rdma_tasks")
             QueueManager.register("get_connect_rdma_tasks_responses")
             QueueManager.register("get_connect_task_lock")
+            QueueManager.register("get_worker_process_tp_barrier")
             self.manager = QueueManager(address=self.address, authkey=self.authkey)
             self._connect_with_retry()
 
@@ -239,6 +247,7 @@ class EngineWorkerQueue:
             self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier(
                 self.local_data_parallel_id
             )
+            self.worker_process_tp_barrier = self.manager.get_worker_process_tp_barrier(self.local_data_parallel_id)
             self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
             self.finished_add_cache_task_queue = self.manager.get_finish_add_cache_task_queue(
                 self.local_data_parallel_id
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 902bf9461..44d3e0150 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -256,6 +256,12 @@ class PaddleDisWorkerProc:
         paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
         return model_weights_signal_tensor.item()
 
+    def _tp_barrier_wait(self):
+        if current_platform.is_xpu():
+            self.task_queue.worker_process_tp_barrier.wait()
+        else:
+            paddle.distributed.barrier(self.parallel_config.tp_group)
+
     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distributed Workers.
         TODO(gongshaotian): support remote calling of functions that control worker.
@@ -299,7 +305,7 @@ class PaddleDisWorkerProc:
 
             if self.parallel_config.tensor_parallel_size > 1:
                 # Synchronize the signal for other workers
-                paddle.distributed.barrier(self.parallel_config.tp_group)
+                self._tp_barrier_wait()
 
             if self.fd_config.load_config.dynamic_load_weight:
                 if self.parallel_config.enable_expert_parallel:
@@ -350,7 +356,7 @@ class PaddleDisWorkerProc:
 
             if (not self.parallel_config.use_ep) and (not self.worker.model_runner.not_need_stop()):
                 if self.ranks > 1:
-                    paddle.distributed.barrier(self.parallel_config.tp_group)
+                    self._tp_barrier_wait()
 
                 time.sleep(0.001)
                 continue
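
Note on the mechanism (illustrative, not part of the diff): `worker_process_tp_barrier` is a `threading.Barrier` that lives in the `EngineWorkerQueue` manager process and is handed to each worker process as a `BaseManager` proxy, so a `wait()` call blocks in the manager until `num_client` tensor-parallel ranks have arrived, with no collective communication involved. The sketch below shows that pattern in isolation; `BarrierManager`, `get_tp_barrier`, `NUM_WORKERS`, and the address/authkey are made-up names for the example, not FastDeploy identifiers.

```python
# Minimal sketch (assumed names, not FastDeploy code): a threading.Barrier shared
# across processes through a multiprocessing BaseManager, the same mechanism
# EngineWorkerQueue uses to expose worker_process_tp_barrier to worker processes.
import multiprocessing
import threading
from multiprocessing.managers import BaseManager

NUM_WORKERS = 2  # assumed number of tensor-parallel worker processes

# The barrier object lives in the manager's server process; clients hold proxies.
_TP_BARRIER = threading.Barrier(NUM_WORKERS)


class BarrierManager(BaseManager):
    pass


def _get_tp_barrier():
    return _TP_BARRIER


# Register at module level so both the server and connecting clients know the typeid.
BarrierManager.register("get_tp_barrier", callable=_get_tp_barrier)


def worker(rank: int, address, authkey: bytes) -> None:
    manager = BarrierManager(address=address, authkey=authkey)
    manager.connect()  # the real code retries here (_connect_with_retry)
    tp_barrier = manager.get_tp_barrier()  # proxy to the server-side Barrier
    print(f"rank {rank}: waiting at TP barrier")
    tp_barrier.wait()  # blocks until all NUM_WORKERS ranks have called wait()
    print(f"rank {rank}: released")


if __name__ == "__main__":
    authkey = b"demo"
    # Start the manager server in a background process (the "is_server" role).
    server = BarrierManager(address=("127.0.0.1", 0), authkey=authkey)
    server.start()

    procs = [
        multiprocessing.Process(target=worker, args=(r, server.address, authkey))
        for r in range(NUM_WORKERS)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    server.shutdown()
```

On XPU, `_tp_barrier_wait()` routes through this proxy path instead of `paddle.distributed.barrier`, which is the point of the change; on other platforms the existing collective barrier on `tp_group` is kept.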