[BugFix] fix instability after clearing weight (#5487)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled

* [BugFix] fix instability after clearing weight

* [chore] add todo
This commit is contained in:
Yonghua Li
2025-12-11 09:58:18 +08:00
committed by GitHub
parent bcde798098
commit 7019afbb86
2 changed files with 23 additions and 15 deletions

View File

@@ -62,17 +62,18 @@ class DynamicWeightManager:
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
self.state_dict[name] = param
def update_parameters(self, pid: int = 0) -> None:
def update_parameters(self, pid: int = 0, restart_process_group=False) -> None:
"""Core method to update model parameters based on strategy."""
start_time = time.perf_counter()
paddle.device.cuda.empty_cache()
# step1 : restart paddle process group
# if not self.first_load:
# paddle.distributed.restart_process_group()
# paddle.distributed.restart_process_group(self.parallel_config.tp_group)
# if self.parallel_config.enable_expert_parallel:
# paddle.distributed.restart_process_group(self.parallel_config.ep_group)
if not self.first_load:
if restart_process_group:
paddle.distributed.restart_process_group()
paddle.distributed.restart_process_group(self.parallel_config.tp_group)
if self.parallel_config.enable_expert_parallel:
paddle.distributed.restart_process_group(self.parallel_config.ep_group)
# step2 : recreat deepep buffer when enable expert parallel
if self.parallel_config.enable_expert_parallel and not self.first_load:
@@ -132,7 +133,7 @@ class DynamicWeightManager:
self._update_model_from_state(state_dict, "raw")
logger.info(f"IPC update parameters completed from file: {self.ipc_path}")
def clear_parameters(self, pid: int = 0) -> None:
def clear_parameters(self, pid: int = 0, shutdown_process_group=False) -> None:
"""Clear all model parameters and free memory."""
logger.info("start clear paramaters")
@@ -144,8 +145,9 @@ class DynamicWeightManager:
DeepEPBufferManager.clear_buffer()
# ep barrier
paddle.distributed.barrier(self.parallel_config.ep_group)
# shutdown ep group
# paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
if shutdown_process_group:
# shutdown ep group
paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
paddle.device.cuda.empty_cache()
# step2: release model weight
@@ -158,11 +160,14 @@ class DynamicWeightManager:
if self.parallel_config.tensor_parallel_size > 1:
# tp barrier
paddle.distributed.barrier(self.parallel_config.tp_group)
# paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
if shutdown_process_group:
paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
if self.parallel_config.enable_expert_parallel:
paddle.distributed.barrier(self.parallel_config.ep_group)
# paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
# paddle.distributed.shutdown_process_group()
if shutdown_process_group:
paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
if shutdown_process_group:
paddle.distributed.shutdown_process_group()
self._update_shared_status(pid, ModelWeightsStatus.CLEARED)
def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):

View File

@@ -418,6 +418,7 @@ class PaddleDisWorkerProc:
num_running_requests = 0
tp_rank = self.local_rank % tp_size
# TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one
self.model_weights_signal = np.zeros([1], dtype=np.int32)
while True:
# run eplb
@@ -459,7 +460,7 @@ class PaddleDisWorkerProc:
else:
paddle.distributed.barrier(self.parallel_config.tp_group)
if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
logger.debug(
logger.info(
f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
)
from fastdeploy.rl.dynamic_weight_manager import (
@@ -473,10 +474,12 @@ class PaddleDisWorkerProc:
self.worker.model_runner,
self.parallel_config.engine_worker_queue_port,
)
logger.debug(f"current task queue data: {self.task_queue.num_tasks()}")
logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
self.task_queue.clear_data()
self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
logger.debug(f"Rank: {self.local_rank} has updated or cleared parameters.")
logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
time.sleep(0.01)
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
logger.info(f"Rank: {self.local_rank} Detected new requests.")