Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-11-02 20:54:03 +08:00)
[feat] support prefix cache clearing when /clear_load_weight is called (#4091)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [feat] support clearing prefix cache (cherry-picked from release/2.1)
* [fix] fix ipc suffix, use port instead
* [fix] fix prefix caching not enabled
* [fix] fix code style
* [fix] wait for rank0 to update weight status
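The hunks below replace the raw status integers 0, 1, -1, and -2 with named constants imported from fastdeploy.inter_communicator. A minimal sketch of the mapping implied by those replacements, assuming plain integer class attributes (the actual ModelWeightsStatus definition in FastDeploy may differ, e.g. it could be an IntEnum):

    # Sketch only: values inferred from the literals this diff replaces.
    # The real ModelWeightsStatus is defined in fastdeploy.inter_communicator.
    class ModelWeightsStatus:
        NORMAL = 0     # weights loaded, engine serving   (replaces literal 0)
        UPDATING = 1   # a new checkpoint is being loaded (replaces literal 1)
        CLEARING = -1  # a clear request was received     (replaces literal -1)
        CLEARED = -2   # weights have been freed          (replaces literal -2)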
@@ -25,6 +25,7 @@ from paddle import nn
 from paddleformers.utils.log import logger

 from fastdeploy.config import FDConfig
+from fastdeploy.inter_communicator import ModelWeightsStatus


 class DynamicWeightManager:
@@ -59,6 +60,7 @@ class DynamicWeightManager:

     def update_parameters(self, pid: int = 0) -> None:
         """Core method to update model parameters based on strategy."""
         logger.info(f"start update paramaters: suffix={pid} rank={self.rank}")
         start_time = time.perf_counter()
+        paddle.device.cuda.empty_cache()

@@ -106,7 +108,7 @@ class DynamicWeightManager:

     def clear_parameters(self, pid: int = 0) -> None:
         """Clear all model parameters and free memory."""
-        logger.info("start clear paramaters")
+        logger.info(f"start clear paramaters: suffix={pid} rank={self.rank}")
         paddle.device.cuda.empty_cache()
         for param in self.model.state_dict().values():
             param._clear_data()
@@ -119,7 +121,7 @@ class DynamicWeightManager:
             paddle.distributed.barrier(self.parallel_config.ep_group)
             paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
         paddle.distributed.shutdown_process_group()
-        self._update_shared_status(pid, -2)
+        self._update_shared_status(pid, ModelWeightsStatus.CLEARED)

     def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
         """Update model parameters from given state dictionary."""
@@ -150,7 +152,7 @@ class DynamicWeightManager:
         if self.parallel_config.tensor_parallel_size > 1:
             paddle.distributed.barrier(self.parallel_config.tp_group)
         if not self.first_load:
-            self._update_shared_status(pid, 0)
+            self._update_shared_status(pid, ModelWeightsStatus.NORMAL)
         self.first_load = False

     def _get_gpu_id(self) -> int:
@@ -217,20 +219,20 @@ class DynamicWeightManager:
         """
         check model weights status
         """
         logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
         is_stop = 0
-        while model_weights_status.value[0] != 0:
-            if model_weights_status.value[0] == 1:
+        while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
+            if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
                 logger.info("infer engine stopped! start to load new checkpoint...")
                 model_runner.update_parameters(pid)
-            elif model_weights_status.value[0] == -1:
+            elif model_weights_status.value[0] == ModelWeightsStatus.CLEARING:
                 logger.info("infer engine stopped! start to clear checkpoint...")
                 model_runner.clear_parameters(pid)

             while True:
-                if model_weights_status.value[0] == 0:
+                if model_weights_status.value[0] == ModelWeightsStatus.NORMAL:
                     logger.info("finished loading new checkpoint")
                     break
-                elif is_stop == 1 or (model_weights_status.value[0] == -2 and is_stop == 0):
+                elif is_stop == 1 or (model_weights_status.value[0] == ModelWeightsStatus.CLEARED and is_stop == 0):
                     if is_stop == 0:
                         logger.info("finished clearing checkpoint")
                         is_stop = 1

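For reference, this feature is driven by the /clear_load_weight endpoint named in the commit title. A hedged client-side sketch in Python, assuming the API server listens on localhost:8000 and that the endpoint accepts GET; only the path comes from this commit, while the host, port, and HTTP method are assumptions:

    # Hypothetical trigger for weight clearing; adjust host, port, and method
    # to your deployment. Only the /clear_load_weight path is taken from the
    # commit title.
    import requests

    resp = requests.get("http://localhost:8000/clear_load_weight", timeout=60)
    print(resp.status_code, resp.text)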