[feat] support prefix cache clearing when /clear_load_weight is called (#4008)

* [feat] support clearing prefix cache (cherry-picked from release/2.1)

* [fix] fix ipc suffix, use port instead

* [fix] fix prefix caching not enabled

* [fix] fix key/value_cache_scales indent

* [fix] fix ep group all-reduce

* [fix] fix clear/update lock not working when workers > 1

* [chore] add preemption triggered info log

* [fix] fix code style

* [fix] fix max_num_seqs config

* [fix] do not force enable_prefix_caching=False in dynamic loading

* [fix] fix ci

* Revert "[fix] fix ci"

This reverts commit 0bc6d55cc8.

* [fix] initialize available_gpu_block_num with max_gpu_block_num

* [fix] fix config splitwise_role

* [fix] fix clearing caches synchronization and add more logs

* [chore] print cache_ready_signal in log

* [fix] fix scheduler_config.splitwise_role

* [fix] fix cache_messager cache_ready_signal create=True

* [fix] stop cache messager from launching in mixed deployment
Author: 李泳桦
Date: 2025-09-28 19:42:53 +08:00
Committed by: GitHub
Parent: 59313ed7f9
Commit: 6265f4385f
20 changed files with 697 additions and 213 deletions
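The endpoint named in the commit title is the trigger for everything shown below. Purely as an illustrative client-side sketch (the host, port, HTTP method, and response handling here are assumptions, not taken from this commit), calling it from Python could look like:

import requests

# Hypothetical address of a running FastDeploy api server; adjust to your deployment.
BASE_URL = "http://127.0.0.1:8000"

def clear_load_weight() -> None:
    """Ask the server to release model weights and, with this commit, the prefix cache."""
    resp = requests.post(f"{BASE_URL}/clear_load_weight", timeout=60)
    resp.raise_for_status()
    print(resp.text)

if __name__ == "__main__":
    clear_load_weight()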

All hunks shown below touch the DynamicWeightManager class (this page captures one of the 20 changed files).

@@ -25,6 +25,7 @@ from paddle import nn
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
+from fastdeploy.inter_communicator import ModelWeightsStatus
 
 
 class DynamicWeightManager:
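The hunks that follow replace hard-coded status integers with names from the ModelWeightsStatus import added above. Judging only from the values this diff swaps out (0, 1, -1, -2), a minimal sketch of what such a class could contain is below; the real definition in fastdeploy.inter_communicator may differ.

class ModelWeightsStatus:
    # Illustrative values, inferred from the integers replaced in this diff.
    NORMAL = 0     # weights loaded, engine serving normally
    UPDATING = 1   # a new checkpoint is being loaded
    CLEARING = -1  # weights (and prefix cache) are being released
    CLEARED = -2   # release finished, engine waits for the next update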
@@ -143,12 +144,11 @@ class DynamicWeightManager:
         if self.parallel_config.tensor_parallel_size > 1:
             # tp barrier
             paddle.distributed.barrier(self.parallel_config.tp_group)
             # shutdown tp group
             paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
-        # step3: update model weight signal
-        # step4: release kv cache in the runner
-        self._update_shared_status(pid, -2)
+        if self.parallel_config.enable_expert_parallel:
+            paddle.distributed.barrier(self.parallel_config.ep_group)
+            paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
+        self._update_shared_status(pid, ModelWeightsStatus.CLEARED)
 
     def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
         """Update model parameters from given state dictionary."""
@@ -184,8 +184,7 @@ class DynamicWeightManager:
             paddle.distributed.barrier(self.parallel_config.ep_group)
 
         if not self.first_load:
-            self._update_shared_status(pid, 0)
+            self._update_shared_status(pid, ModelWeightsStatus.NORMAL)
         self.first_load = False
 
     def _get_gpu_id(self) -> int:
@@ -252,25 +251,19 @@ class DynamicWeightManager:
         """
         check model weights status
         """
-        is_stop = 0
-        while model_weights_status.value[0] != 0:
-            if model_weights_status.value[0] == 1:
+        logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
+        while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
+            if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
                 logger.info("infer engine stopped! start to load new checkpoint...")
                 model_runner.update_parameters(pid)
-            elif model_weights_status.value[0] == -1:
+                while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
+                    time.sleep(0.01)
+                logger.info("finished loading new checkpoint")
+            elif model_weights_status.value[0] == ModelWeightsStatus.CLEARING:
                 logger.info("infer engine stopped! start to clear checkpoint...")
+                model_runner.clear_requests()
                 model_runner.clear_parameters(pid)
-
-            while True:
-                if model_weights_status.value[0] == 0:
-                    logger.info("finished loading new checkpoint")
-                    break
-                elif is_stop == 1 or (model_weights_status.value[0] == -2 and is_stop == 0):
-                    if is_stop == 0:
-                        logger.info("finished clearing checkpoint")
-                        is_stop = 1
-                    time.sleep(0.001)
-                    break
-                else:
-                    time.sleep(0.001)
+                while model_weights_status.value[0] != ModelWeightsStatus.CLEARED:
+                    time.sleep(0.01)
+                logger.info("finished clearing checkpoint")
+            time.sleep(0.01)
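The rewritten loop above is the worker side of a small handshake: some controller flips the shared status to UPDATING or CLEARING, the worker acts on it, and the status eventually settles back to NORMAL or CLEARED. A hedged sketch of the controller side follows (the function name, timeout, and signal object are illustrative; in FastDeploy the actual trigger is the /clear_load_weight handler):

import time

NORMAL, UPDATING, CLEARING, CLEARED = 0, 1, -1, -2  # values inferred from this diff

def request_clear(model_weights_status, timeout_s: float = 60.0) -> bool:
    """Controller side: request a clear, then wait for the worker's CLEARED acknowledgement."""
    model_weights_status.value[0] = CLEARING
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        if model_weights_status.value[0] == CLEARED:
            return True
        time.sleep(0.01)
    return False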