Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[BugFix] fix instability after clearing weight (#5487)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [BugFix] fix instability after clearing weight
* [chore] add todo
@@ -62,17 +62,18 @@ class DynamicWeightManager:
             logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
             self.state_dict[name] = param
 
-    def update_parameters(self, pid: int = 0) -> None:
+    def update_parameters(self, pid: int = 0, restart_process_group=False) -> None:
         """Core method to update model parameters based on strategy."""
         start_time = time.perf_counter()
         paddle.device.cuda.empty_cache()
 
         # step1 : restart paddle process group
-        # if not self.first_load:
-        #     paddle.distributed.restart_process_group()
-        #     paddle.distributed.restart_process_group(self.parallel_config.tp_group)
-        #     if self.parallel_config.enable_expert_parallel:
-        #         paddle.distributed.restart_process_group(self.parallel_config.ep_group)
+        if not self.first_load:
+            if restart_process_group:
+                paddle.distributed.restart_process_group()
+                paddle.distributed.restart_process_group(self.parallel_config.tp_group)
+                if self.parallel_config.enable_expert_parallel:
+                    paddle.distributed.restart_process_group(self.parallel_config.ep_group)
 
         # step2 : recreat deepep buffer when enable expert parallel
         if self.parallel_config.enable_expert_parallel and not self.first_load:
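For orientation, below is a minimal, self-contained sketch of the gating logic this hunk introduces: on a reload (not the first load), the process groups are restarted only when the caller passes restart_process_group=True. The FakeParallelConfig class and the restart callable are illustrative stand-ins so the sketch runs without Paddle or a distributed launch; they are not FastDeploy APIs.

# Sketch of the flag-gated restart flow (assumption: stand-in types, not FastDeploy code).
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class FakeParallelConfig:
    enable_expert_parallel: bool = True
    tp_group: str = "tp"
    ep_group: str = "ep"


def maybe_restart_groups(
    first_load: bool,
    restart_process_group: bool,
    cfg: FakeParallelConfig,
    restart: Callable[[str], None],
) -> None:
    """Mirror of the gating above: skip on first load, honor the flag otherwise."""
    if not first_load and restart_process_group:
        restart("default")          # global group
        restart(cfg.tp_group)       # tensor-parallel group
        if cfg.enable_expert_parallel:
            restart(cfg.ep_group)   # expert-parallel group


calls: List[str] = []
maybe_restart_groups(False, True, FakeParallelConfig(), calls.append)
print(calls)  # ['default', 'tp', 'ep']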
@@ -132,7 +133,7 @@ class DynamicWeightManager:
         self._update_model_from_state(state_dict, "raw")
         logger.info(f"IPC update parameters completed from file: {self.ipc_path}")
 
-    def clear_parameters(self, pid: int = 0) -> None:
+    def clear_parameters(self, pid: int = 0, shutdown_process_group=False) -> None:
         """Clear all model parameters and free memory."""
 
         logger.info("start clear paramaters")
@@ -144,8 +145,9 @@ class DynamicWeightManager:
                 DeepEPBufferManager.clear_buffer()
                 # ep barrier
                 paddle.distributed.barrier(self.parallel_config.ep_group)
-                # shutdown ep group
-                # paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
+                if shutdown_process_group:
+                    # shutdown ep group
+                    paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
 
         paddle.device.cuda.empty_cache()
         # step2: release model weight
@@ -158,11 +160,14 @@ class DynamicWeightManager:
         if self.parallel_config.tensor_parallel_size > 1:
             # tp barrier
             paddle.distributed.barrier(self.parallel_config.tp_group)
-            # paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
+            if shutdown_process_group:
+                paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
         if self.parallel_config.enable_expert_parallel:
             paddle.distributed.barrier(self.parallel_config.ep_group)
-            # paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
-        # paddle.distributed.shutdown_process_group()
+            if shutdown_process_group:
+                paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
+        if shutdown_process_group:
+            paddle.distributed.shutdown_process_group()
         self._update_shared_status(pid, ModelWeightsStatus.CLEARED)
 
     def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
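The two new keywords appear intended to be used as a pair: if a clear shut the TP/EP groups down, the matching update should restart them. A hedged caller-side sketch follows, with a stub standing in for a real DynamicWeightManager so the example runs without Paddle; the calling sequence is an assumption about how an RL training controller might drive the cycle, not code from this repository.

# Stub illustrating the expected pairing of the new flags (not FastDeploy code).
class StubWeightManager:
    def clear_parameters(self, pid: int = 0, shutdown_process_group: bool = False) -> None:
        print(f"clear_parameters(pid={pid}, shutdown_process_group={shutdown_process_group})")

    def update_parameters(self, pid: int = 0, restart_process_group: bool = False) -> None:
        print(f"update_parameters(pid={pid}, restart_process_group={restart_process_group})")


manager = StubWeightManager()
teardown_groups = False  # default keeps the old behaviour: barriers only, no shutdown/restart

manager.clear_parameters(pid=0, shutdown_process_group=teardown_groups)
# ... trainer republishes the weights here ...
manager.update_parameters(pid=0, restart_process_group=teardown_groups)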
@@ -418,6 +418,7 @@ class PaddleDisWorkerProc:
         num_running_requests = 0
         tp_rank = self.local_rank % tp_size
 
+        # TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one
         self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             # run eplb
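The added TODO refers to the two places weight state is tracked: a cross-process model_weights_status held in shared memory, and a per-rank model_weights_signal kept as a one-element numpy array inside the worker loop. A rough, runnable sketch of that split is below; the status codes are assumptions loosely inferred from the "[-1:clear, 1:update]" log message in the next hunk, and mp.Value stands in for the real shared-memory wrapper.

# Illustrative only: FastDeploy wraps its shared status differently.
import multiprocessing as mp
import numpy as np

CLEARED, NORMAL, UPDATING = -1, 0, 1  # assumed codes, not taken from ModelWeightsStatus

# Cross-process view: the engine writes here when weights are cleared or updated.
model_weights_status = mp.Value("i", NORMAL)

# Per-rank view: the worker copies the shared value into a small numpy buffer
# and acts on it inside its serving loop.
model_weights_signal = np.zeros([1], dtype=np.int32)
model_weights_signal[0] = model_weights_status.value

if model_weights_signal[0] != NORMAL:
    print("weights need attention:", model_weights_signal[0])
else:
    print("weights normal, continue serving")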
@@ -459,7 +460,7 @@ class PaddleDisWorkerProc:
             else:
                 paddle.distributed.barrier(self.parallel_config.tp_group)
             if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
-                logger.debug(
+                logger.info(
                     f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
                 )
                 from fastdeploy.rl.dynamic_weight_manager import (
@@ -473,10 +474,12 @@ class PaddleDisWorkerProc:
                     self.worker.model_runner,
                     self.parallel_config.engine_worker_queue_port,
                 )
-                logger.debug(f"current task queue data: {self.task_queue.num_tasks()}")
+                logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
                 self.task_queue.clear_data()
                 self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
-                logger.debug(f"Rank: {self.local_rank} has updated or cleared parameters.")
+                logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
+                while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
+                    time.sleep(0.01)
 
             if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")
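The added while loop is the core of the fix: after handling a clear signal, the rank parks in a short sleep-poll until the shared status leaves CLEARED, so the serving loop cannot pick up new tasks against freed weights. Below is a self-contained sketch of that wait; it is not FastDeploy code, the scalar mp.Value stands in for the real shared-memory status array, and the status codes are placeholders loosely matching the log message above.

# Minimal sketch of the worker-side wait added by this commit.
import multiprocessing as mp
import threading
import time

CLEARED, NORMAL = -1, 0  # assumed codes, not ModelWeightsStatus itself

model_weights_status = mp.Value("i", CLEARED)


def reload_weights_later() -> None:
    # Stand-in for the trainer/engine republishing weights from another process.
    time.sleep(0.05)
    with model_weights_status.get_lock():
        model_weights_status.value = NORMAL


threading.Thread(target=reload_weights_later).start()

# The wait: poll until the weights are no longer in the cleared state.
while model_weights_status.value == CLEARED:
    time.sleep(0.01)

print("weights restored, worker loop resumes")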