[xpu] use cpu barrier (#4181)
@@ -101,6 +101,9 @@ class EngineWorkerQueue:
             self.finish_request_barrier = [
                 threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
             ]
+            self.worker_process_tp_barrier = [
+                threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
+            ]
             self.finish_add_cache_task_barrier = [
                 threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
@@ -193,6 +196,10 @@ class EngineWorkerQueue:
                 "get_finish_add_cache_task_barrier",
                 callable=lambda idx: self.finish_add_cache_task_barrier[idx],
             )
+            QueueManager.register(
+                "get_worker_process_tp_barrier",
+                callable=lambda idx: self.worker_process_tp_barrier[idx],
+            )
             self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey)
             self.manager.start()
         else:
@@ -217,6 +224,7 @@ class EngineWorkerQueue:
             QueueManager.register("get_connect_rdma_tasks")
             QueueManager.register("get_connect_rdma_tasks_responses")
             QueueManager.register("get_connect_task_lock")
+            QueueManager.register("get_worker_process_tp_barrier")
             self.manager = QueueManager(address=self.address, authkey=self.authkey)
             self._connect_with_retry()
@@ -239,6 +247,7 @@ class EngineWorkerQueue:
             self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier(
                 self.local_data_parallel_id
             )
+            self.worker_process_tp_barrier = self.manager.get_worker_process_tp_barrier(self.local_data_parallel_id)
             self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
             self.finished_add_cache_task_queue = self.manager.get_finish_add_cache_task_queue(
                 self.local_data_parallel_id
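In the hunks above, the server side of EngineWorkerQueue creates one threading.Barrier per data-parallel group and publishes it through QueueManager, and each client process looks its barrier up by index after connecting. Below is a minimal, standalone sketch of that sharing pattern, not FastDeploy code; the address, authkey, and client count are placeholder assumptions.

# Sketch: share one CPU-side threading.Barrier across processes via a BaseManager.
import threading
from multiprocessing.managers import BaseManager

NUM_CLIENTS = 2  # assumed number of worker processes that must rendezvous


class QueueManager(BaseManager):
    pass


def start_server():
    # The barrier lives in the manager server; every proxy refers to this object.
    barrier = threading.Barrier(NUM_CLIENTS)
    QueueManager.register("get_worker_process_tp_barrier", callable=lambda: barrier)
    manager = QueueManager(address=("127.0.0.1", 56541), authkey=b"demo")
    manager.start()
    return manager


def client_wait():
    # Clients register the typeid without a callable, connect, and fetch a proxy.
    QueueManager.register("get_worker_process_tp_barrier")
    manager = QueueManager(address=("127.0.0.1", 56541), authkey=b"demo")
    manager.connect()
    # wait() is forwarded to the single server-side Barrier, so all clients
    # rendezvous on the CPU with no device collective involved.
    manager.get_worker_process_tp_barrier().wait()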
@@ -256,6 +256,12 @@ class PaddleDisWorkerProc:
         paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
         return model_weights_signal_tensor.item()
 
+    def _tp_barrier_wait(self):
+        if current_platform.is_xpu():
+            self.task_queue.worker_process_tp_barrier.wait()
+        else:
+            paddle.distributed.barrier(self.parallel_config.tp_group)
+
     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distributed Workers.
         TODO(gongshaotian): support remote calling of functions that control worker.
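As the commit title suggests, the new helper makes XPU worker processes rendezvous on the manager-hosted CPU barrier (task_queue.worker_process_tp_barrier) instead of issuing the paddle.distributed.barrier device collective, while other platforms keep the original behavior; the two call-site changes below switch to this helper.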
@@ -299,7 +305,7 @@ class PaddleDisWorkerProc:
 
             if self.parallel_config.tensor_parallel_size > 1:
                 # Synchronize the signal for other workers
-                paddle.distributed.barrier(self.parallel_config.tp_group)
+                self._tp_barrier_wait()
 
             if self.fd_config.load_config.dynamic_load_weight:
                 if self.parallel_config.enable_expert_parallel:
@@ -350,7 +356,7 @@ class PaddleDisWorkerProc:
 
             if (not self.parallel_config.use_ep) and (not self.worker.model_runner.not_need_stop()):
                 if self.ranks > 1:
-                    paddle.distributed.barrier(self.parallel_config.tp_group)
+                    self._tp_barrier_wait()
 
                 time.sleep(0.001)
                 continue