[BugFix] fix instability after clearing weight (#5493)

* [BugFix] fix instability after clearing weight

* [chore] add todo
Author: Yonghua Li
Date: 2025-12-11 10:22:35 +08:00
Committed by: GitHub
parent d79438bb86
commit 2ec76352da
2 changed files with 23 additions and 15 deletions

fastdeploy/rl/dynamic_weight_manager.py

@@ -62,17 +62,18 @@ class DynamicWeightManager:
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
self.state_dict[name] = param
def update_parameters(self, pid: int = 0) -> None:
def update_parameters(self, pid: int = 0, restart_process_group=False) -> None:
"""Core method to update model parameters based on strategy."""
start_time = time.perf_counter()
paddle.device.cuda.empty_cache()
# step1 : restart paddle process group
# if not self.first_load:
# paddle.distributed.restart_process_group()
# paddle.distributed.restart_process_group(self.parallel_config.tp_group)
# if self.parallel_config.enable_expert_parallel:
# paddle.distributed.restart_process_group(self.parallel_config.ep_group)
if not self.first_load:
if restart_process_group:
paddle.distributed.restart_process_group()
paddle.distributed.restart_process_group(self.parallel_config.tp_group)
if self.parallel_config.enable_expert_parallel:
paddle.distributed.restart_process_group(self.parallel_config.ep_group)
# step2 : recreat deepep buffer when enable expert parallel
if self.parallel_config.enable_expert_parallel and not self.first_load:
@@ -123,7 +124,7 @@ class DynamicWeightManager:
         self._update_model_from_state(state_dict, "raw")
         logger.info(f"IPC update parameters completed from file: {self.ipc_path}")

-    def clear_parameters(self, pid: int = 0) -> None:
+    def clear_parameters(self, pid: int = 0, shutdown_process_group=False) -> None:
         """Clear all model parameters and free memory."""
         logger.info("start clear paramaters")
@@ -135,8 +136,9 @@ class DynamicWeightManager:
             DeepEPBufferManager.clear_buffer()
             # ep barrier
             paddle.distributed.barrier(self.parallel_config.ep_group)
-            # shutdown ep group
-            # paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
+            if shutdown_process_group:
+                # shutdown ep group
+                paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)

         paddle.device.cuda.empty_cache()
         # step2: release model weight
@@ -149,11 +151,14 @@ class DynamicWeightManager:
         if self.parallel_config.tensor_parallel_size > 1:
             # tp barrier
             paddle.distributed.barrier(self.parallel_config.tp_group)
-            # paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
+            if shutdown_process_group:
+                paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
         if self.parallel_config.enable_expert_parallel:
             paddle.distributed.barrier(self.parallel_config.ep_group)
-            # paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
-        # paddle.distributed.shutdown_process_group()
+            if shutdown_process_group:
+                paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
+        if shutdown_process_group:
+            paddle.distributed.shutdown_process_group()
         self._update_shared_status(pid, ModelWeightsStatus.CLEARED)

     def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
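
For reference, a minimal usage sketch of the new optional flags. Only clear_parameters/update_parameters and their new keyword arguments come from this diff; how the manager instance is obtained is an assumption and not shown here.

    # Sketch only: "manager" stands for an existing DynamicWeightManager
    # instance; its construction is outside the scope of this diff.

    # Default after this fix: release weights but keep the NCCL process
    # groups alive, so later collectives stay stable.
    manager.clear_parameters(pid=0)
    manager.update_parameters(pid=0)

    # Opt-in: tear the groups down on clear and restart them on update.
    manager.clear_parameters(pid=0, shutdown_process_group=True)
    manager.update_parameters(pid=0, restart_process_group=True)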

fastdeploy/worker/worker_process.py

@@ -418,6 +418,7 @@ class PaddleDisWorkerProc:
         num_running_requests = 0
         tp_rank = self.local_rank % tp_size

+        # TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one
         self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             # run eplb
@@ -459,7 +460,7 @@ class PaddleDisWorkerProc:
             else:
                 paddle.distributed.barrier(self.parallel_config.tp_group)
             if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
-                logger.debug(
+                logger.info(
                     f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
                 )
                 from fastdeploy.rl.dynamic_weight_manager import (
@@ -473,10 +474,12 @@ class PaddleDisWorkerProc:
                     self.worker.model_runner,
                     self.parallel_config.engine_worker_queue_port,
                 )
-                logger.debug(f"current task queue data: {self.task_queue.num_tasks()}")
+                logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
                 self.task_queue.clear_data()
                 self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
-                logger.debug(f"Rank: {self.local_rank} has updated or cleared parameters.")
+                logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
                 while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
                     time.sleep(0.01)
             if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")