mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 17:17:14 +08:00
[feat] support prefix cache clearing when /clear_load_weight is called (#4008)
* [feat] support clearing prefix cache (cherry-picked from release/2.1)
* [fix] fix ipc suffix, use port instead
* [fix] fix prefix caching not enabled
* [fix] fix key/value_cache_scales indent
* [fix] fix ep group all-reduce
* [fix] fix clear/update lock not working when workers > 1
* [chore] add preemption triggered info log
* [fix] fix code style
* [fix] fix max_num_seqs config
* [fix] do not force enable_prefix_caching=False in dynamic loading
* [fix] fix ci
* Revert "[fix] fix ci"
This reverts commit 0bc6d55cc8.
* [fix] initialize available_gpu_block_num with max_gpu_block_num
* [fix] fix config splitwise_role
* [fix] fix clearing caches synchronization and add more logs
* [chore] print cache_ready_signal in log
* [fix] fix scheduler_config.splitwise_role
* [fix] fix cache_messager cache_ready_signal create=True
* [fix] stop cache messager from launching in mixed deployment
This commit is contained in:
@@ -25,6 +25,7 @@ from paddle import nn
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.inter_communicator import ModelWeightsStatus
|
||||
|
||||
|
||||
class DynamicWeightManager:
|
||||
@@ -143,12 +144,11 @@ class DynamicWeightManager:
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
# tp barrier
|
||||
paddle.distributed.barrier(self.parallel_config.tp_group)
|
||||
# shutdown tp group
|
||||
paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
|
||||
|
||||
# step3: update model weight signal
|
||||
# step4: release kv cache in the runner
|
||||
self._update_shared_status(pid, -2)
|
||||
if self.parallel_config.enable_expert_parallel:
|
||||
paddle.distributed.barrier(self.parallel_config.ep_group)
|
||||
paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
|
||||
self._update_shared_status(pid, ModelWeightsStatus.CLEARED)
|
||||
|
||||
def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
|
||||
"""Update model parameters from given state dictionary."""
|
||||
@@ -184,8 +184,7 @@ class DynamicWeightManager:
|
||||
paddle.distributed.barrier(self.parallel_config.ep_group)
|
||||
|
||||
if not self.first_load:
|
||||
self._update_shared_status(pid, 0)
|
||||
|
||||
self._update_shared_status(pid, ModelWeightsStatus.NORMAL)
|
||||
self.first_load = False
|
||||
|
||||
def _get_gpu_id(self) -> int:
|
||||
@@ -252,25 +251,19 @@ class DynamicWeightManager:
|
||||
"""
|
||||
check model weights status
|
||||
"""
|
||||
is_stop = 0
|
||||
while model_weights_status.value[0] != 0:
|
||||
if model_weights_status.value[0] == 1:
|
||||
logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
|
||||
while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
|
||||
if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
|
||||
logger.info("infer engine stopped! start to load new checkpoint...")
|
||||
model_runner.update_parameters(pid)
|
||||
elif model_weights_status.value[0] == -1:
|
||||
while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
|
||||
time.sleep(0.01)
|
||||
logger.info("finished loading new checkpoint")
|
||||
elif model_weights_status.value[0] == ModelWeightsStatus.CLEARING:
|
||||
logger.info("infer engine stopped! start to clear checkpoint...")
|
||||
model_runner.clear_requests()
|
||||
model_runner.clear_parameters(pid)
|
||||
|
||||
while True:
|
||||
if model_weights_status.value[0] == 0:
|
||||
logger.info("finished loading new checkpoint")
|
||||
break
|
||||
elif is_stop == 1 or (model_weights_status.value[0] == -2 and is_stop == 0):
|
||||
if is_stop == 0:
|
||||
logger.info("finished clearing checkpoint")
|
||||
is_stop = 1
|
||||
time.sleep(0.001)
|
||||
break
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
while model_weights_status.value[0] != ModelWeightsStatus.CLEARED:
|
||||
time.sleep(0.01)
|
||||
logger.info("finished clearing checkpoint")
|
||||
time.sleep(0.01)
|
||||
|
Reference in New Issue
Block a user