mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Cherry-Pick][BugFix] fix rl model_weights_signal to support tp>1 #5639 (#5637)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
This commit is contained in:
@@ -423,17 +423,11 @@ class PaddleDisWorkerProc:
|
||||
while True:
|
||||
# run eplb
|
||||
self._run_eplb(tp_rank)
|
||||
if tp_rank == 0:
|
||||
|
||||
if self.fd_config.load_config.dynamic_load_weight:
|
||||
if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
|
||||
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
|
||||
if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
|
||||
self.model_weights_signal[0] = self._broadcast_model_weights_signal(
|
||||
src=0, group=self.parallel_config.ep_group
|
||||
)
|
||||
if self.fd_config.load_config.dynamic_load_weight and tp_size > 1:
|
||||
self.model_weights_signal[0] = self._broadcast_model_weights_signal(
|
||||
src=0, group=self.parallel_config.tp_group
|
||||
)
|
||||
self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None)
|
||||
|
||||
self.insert_step = False
|
||||
req_dicts = None
|
||||
@@ -455,11 +449,8 @@ class PaddleDisWorkerProc:
|
||||
self._tp_barrier_wait()
|
||||
|
||||
if self.fd_config.load_config.dynamic_load_weight:
|
||||
if self.parallel_config.enable_expert_parallel:
|
||||
paddle.distributed.barrier(self.parallel_config.ep_group)
|
||||
else:
|
||||
paddle.distributed.barrier(self.parallel_config.tp_group)
|
||||
if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
|
||||
paddle.distributed.barrier()
|
||||
logger.info(
|
||||
f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user