[fix] fix ep group all-reduce (#4140)

* [fix] fix ep group all-reduce

* [fix] fix clear/update lock not working when workers > 1

* [chore] add preemption triggered info log

* [fix] fix code style

* fix model_weights_signal (#4092)

* fix model_weights_signal

---------

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
Authored by 李泳桦 on 2025-09-18 10:34:49 +08:00, committed by GitHub
parent cffde70949, commit 0fa28b1068
6 changed files with 41 additions and 26 deletions


@@ -270,6 +270,11 @@ class PaddleDisWorkerProc:
             create=False,
         )
 
+    def _broadcast_model_weights_signal(self, src: int, group) -> int:
+        model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
+        paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
+        return model_weights_signal_tensor.item()
+
     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distributed Workers.
         TODO(gongshaotian): support remote calling of functions that control worker.
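The helper added above exists because the weight-update signal now lives in a host-side numpy array (see the next hunk), while paddle.distributed.broadcast operates on paddle tensors: the value is copied into a fresh one-element tensor, broadcast over the requested process group, and the resulting scalar is handed back to the caller. Below is a minimal, self-contained sketch of the same pattern, assuming a paddle.distributed run with at least two ranks; the scaffolding names (broadcast_signal, signal, the demo filename) are illustrative, not FastDeploy's real objects.

# demo_broadcast_signal.py -- minimal sketch; launch with e.g.
#   python -m paddle.distributed.launch --nproc_per_node=2 demo_broadcast_signal.py
import numpy as np
import paddle
import paddle.distributed as dist


def broadcast_signal(signal: np.ndarray, src: int, group=None) -> int:
    """Copy the host-side value into a fresh paddle tensor, broadcast it, return the scalar."""
    tensor = paddle.full(shape=[1], fill_value=signal[0], dtype="int32")
    dist.broadcast(tensor, src=src, group=group)
    return tensor.item()


if __name__ == "__main__":
    dist.init_parallel_env()
    signal = np.zeros([1], dtype=np.int32)       # per-rank host-side state
    if dist.get_rank() == 0:
        signal[0] = 1                            # pretend rank 0 observed an update request
    signal[0] = broadcast_signal(signal, src=0)  # default (global) group for this demo
    print(f"rank {dist.get_rank()}: signal = {signal[0]}")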
@@ -279,15 +284,19 @@ class PaddleDisWorkerProc:
         req_ids = []
         num_running_requests = 0
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
+        self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                 if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
             if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
-            if self.fd_config.load_config.dynamic_load_weight:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.tp_group)
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.ep_group
+                )
+            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.tensor_parallel_size > 1:
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.tp_group
+                )
 
             self.insert_step = False
             req_dicts = None
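In the loop above, only the first rank of each tensor-parallel group reads the shared ModelWeightsStatus into the numpy signal; with dynamic weight loading the value is then fanned out over the expert-parallel group, and over the tensor-parallel group only when that group actually has more than one rank. A condensed, illustrative view of that control flow follows; the cfg object and group handles are placeholders standing in for FastDeploy's parallel_config/load_config, and broadcast_signal is the helper sketched earlier, not the real method.

# Condensed control flow of the patched section; not FastDeploy's real API.
def sync_model_weights_signal(signal, cfg, model_weights_status, ep_group, tp_group, normal=0):
    # Only the first rank of each tensor-parallel group inspects the shared status.
    if cfg.local_rank % cfg.tensor_parallel_size == 0:
        if model_weights_status.value[0] != normal:  # `normal` stands in for ModelWeightsStatus.NORMAL
            signal[0] = int(model_weights_status.value[0])

    # Fan the value out across expert-parallel ranks first ...
    if cfg.dynamic_load_weight and cfg.enable_expert_parallel:
        signal[0] = broadcast_signal(signal, src=0, group=ep_group)

    # ... then across tensor-parallel ranks, but only when the TP group has peers.
    if cfg.dynamic_load_weight and cfg.tensor_parallel_size > 1:
        signal[0] = broadcast_signal(signal, src=0, group=tp_group)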
@@ -315,7 +324,9 @@ class PaddleDisWorkerProc:
             else:
                 paddle.distributed.barrier(self.parallel_config.tp_group)
             if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
-                logger.info(f"Rank: {self.local_rank} has updated parameters.")
+                logger.info(
+                    f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
+                )
                 from fastdeploy.rl.dynamic_weight_manager import (
                     DynamicWeightManager,
                 )
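The new log line spells out the signal's meaning: -1 asks the worker to clear its parameters, 1 asks it to update them, and the NORMAL state leaves this branch untouched. A small sketch of that mapping follows, with the caveat that ModelWeightsStatus is defined elsewhere in FastDeploy, so the NORMAL value and the enum name used here are assumptions for illustration only.

from enum import IntEnum


class WeightsSignal(IntEnum):
    """Values implied by the new log message; NORMAL = 0 is assumed, not taken from this diff."""
    CLEAR = -1   # "-1:clear"  -> drop the currently loaded parameters
    NORMAL = 0   # steady state, nothing to do this iteration
    UPDATE = 1   # "1:update"  -> load refreshed parameters


def describe(signal_value: int) -> str:
    if signal_value == WeightsSignal.NORMAL:
        return "nothing to do"
    return "clear parameters" if signal_value == WeightsSignal.CLEAR else "update parameters"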
@@ -327,6 +338,7 @@ class PaddleDisWorkerProc:
                     self.parallel_config.engine_worker_queue_port,
                 )
                 self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
+                logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
 
             if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")