[feat] support prefix cache clearing when /clear_load_weight is called (#4008)

* [feat] support clearing prefix cache (cherry-picked from release/2.1)

* [fix] fix ipc suffix, use port instead

* [fix] fix prefix caching not enabled

* [fix] fix key/value_cache_scales indent

* [fix] fix ep group all-reduce

* [fix] fix clear/update lock not working when workers > 1

* [chore] add preemption triggered info log

* [fix] fix code style

* [fix] fix max_num_seqs config

* [fix] do not force enable_prefix_caching=False in dynamic loading

* [fix] fix ci

* Revert "[fix] fix ci"

This reverts commit 0bc6d55cc8.

* [fix] initialize available_gpu_block_num with max_gpu_block_num

* [fix] fix config splitwise_role

* [fix] fix clearing caches synchronization and add more logs

* [chore] print cache_ready_signal in log

* [fix] fix scheduler_config.splitwise_role

* [fix] fix cache_messager cache_ready_signal create=True

* [fix] stop cache messager from launching in mixed deployment
Author: 李泳桦
Date: 2025-09-28 19:42:53 +08:00
Committed by: GitHub
Parent: 59313ed7f9
Commit: 6265f4385f
20 changed files with 697 additions and 213 deletions
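For orientation before the diff: the flow this PR wires up is that /clear_load_weight flips a shared model-weights status flag, each worker rank notices it in its event loop, releases or reloads parameters, and, because cached KV blocks would otherwise point at stale weights, the prefix cache is cleared as well. The sketch below is only a mental model with hypothetical clear_parameters/update_parameters/clear hooks, not FastDeploy's actual API:

    CLEAR_REQUESTED = -1   # matches the "[-1:clear, 1:update]" convention logged in the diff
    UPDATE_REQUESTED = 1
    NORMAL = 0

    def handle_weights_signal(weights_status, model_runner, prefix_cache):
        """Hypothetical per-step check, not the real worker_process.py loop."""
        if weights_status[0] == CLEAR_REQUESTED:
            model_runner.clear_parameters()   # drop weights
            prefix_cache.clear()              # cached KV blocks are now stale
        elif weights_status[0] == UPDATE_REQUESTED:
            model_runner.update_parameters()
        weights_status[0] = NORMAL            # acknowledge, as the diff does with ModelWeightsStatus.NORMAL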


@@ -42,7 +42,7 @@ from fastdeploy.config import (
 )
 from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
 from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
-from fastdeploy.inter_communicator import IPCSignal
+from fastdeploy.inter_communicator import ExistTaskStatus, IPCSignal, ModelWeightsStatus
 from fastdeploy.model_executor.layers.quantization import parse_quant_config
 from fastdeploy.platforms import current_platform
 from fastdeploy.scheduler import SchedulerConfig
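The new imports pull in two status holders whose definitions live in fastdeploy.inter_communicator and are not shown in this diff. Judging from the integer literals they replace in the hunks below (0 for normal/empty, 1 for exist/update, -1 for clear in the log message), they plausibly amount to something like this assumed sketch:

    from enum import IntEnum

    class ModelWeightsStatus(IntEnum):
        # assumed values, inferred from the 0 / 1 / -1 literals replaced below
        NORMAL = 0
        UPDATING = 1
        CLEARING = -1

    class ExistTaskStatus(IntEnum):
        EMPTY = 0
        EXIST = 1

IntEnum keeps these interchangeable with the raw integers that are still written into the shared int32 signal arrays.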
@@ -183,7 +183,7 @@ class PaddleDisWorkerProc:
 name="launched_expert_service_signal",
 array=launched_expert_service_signal_data,
 dtype=np.int32,
-suffix=self.parallel_config.engine_pid,
+suffix=self.parallel_config.engine_worker_queue_port,
 create=False,
 )
 while self.launched_expert_service_signal.value[self.local_rank % self.max_chips_per_node] == 0:
@@ -200,7 +200,7 @@ class PaddleDisWorkerProc:
 name="worker_ready_signal",
 array=workers_ready,
 dtype=np.int32,
-suffix=self.parallel_config.engine_pid,
+suffix=self.parallel_config.engine_worker_queue_port,
 create=False,
 )
 self.worker_ready_signal.value[self.local_rank % self.max_chips_per_node] = 1
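This hunk and the previous one switch the IPC suffix from engine_pid to engine_worker_queue_port, so the engine and its workers derive the same segment name without having to share a PID. The real IPCSignal implementation is in fastdeploy.inter_communicator and is not part of this diff; as a rough mental model only, a named cross-process int32 flag array keyed by name plus suffix could be built like this:

    import numpy as np
    from multiprocessing import shared_memory

    def attach_signal(name: str, suffix: int, length: int, create: bool):
        """Sketch only: the caller keeps `shm` alive and closes/unlinks it later."""
        nbytes = length * np.dtype(np.int32).itemsize
        shm = shared_memory.SharedMemory(name=f"{name}.{suffix}", create=create, size=nbytes)
        return shm, np.ndarray((length,), dtype=np.int32, buffer=shm.buf)

    # e.g. each rank marks itself ready, as worker_ready_signal does above:
    # shm, ready = attach_signal("worker_ready_signal", 8123, length=8, create=False)
    # ready[local_rank % 8] = 1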
@@ -279,8 +279,8 @@ class PaddleDisWorkerProc:
 local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
 self.model_weights_signal = np.zeros([1], dtype=np.int32)
 while True:
-if local_rank == 0:
-if self.model_weights_status.value[0] != 0:
+if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
+if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
 self.model_weights_signal[0] = int(self.model_weights_status.value[0])
 if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
 self.model_weights_signal[0] = self._broadcast_model_weights_signal(
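With expert parallelism enabled, only one rank reads the shared status, and the result is broadcast over the EP group so every rank takes the same clear/update decision in the same step. The call above is cut off by the hunk, so its real signature is not visible; a hypothetical helper with the same intent (assuming an already-initialized Paddle communication group) might look like:

    import numpy as np
    import paddle
    import paddle.distributed as dist

    def broadcast_model_weights_signal(signal: np.ndarray, src: int, group) -> int:
        """Hypothetical stand-in for _broadcast_model_weights_signal."""
        buf = paddle.to_tensor(signal.astype("int32"))   # rank `src` holds the value read from shared memory
        dist.broadcast(buf, src=src, group=group)        # every rank in `group` receives it
        return int(buf.numpy()[0])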
@@ -306,7 +306,7 @@ class PaddleDisWorkerProc:
 if self.nnode > 1 and self.parallel_config.tensor_parallel_size > self.max_chips_per_node:
 self.task_queue.read_finish_flag.set(1)
 else:
-self.exist_task_signal.value[0] = 1
+self.exist_task_signal.value[0] = ExistTaskStatus.EXIST
 if self.parallel_config.tensor_parallel_size > 1:
 # Synchronize the signal for other workers
@@ -317,7 +317,7 @@ class PaddleDisWorkerProc:
 paddle.distributed.barrier(self.parallel_config.ep_group)
 else:
 paddle.distributed.barrier(self.parallel_config.tp_group)
-if self.model_weights_signal[0] != 0:
+if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
 logger.info(
 f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
 )
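The barrier choice above keeps all ranks that share the weights in lockstep before acting on the signal: the expert-parallel group when EP is involved, otherwise the tensor-parallel group. A condensed sketch of that selection, assuming the guard mirrors the enable_expert_parallel check earlier in this loop:

    import paddle.distributed as dist

    def sync_before_weight_change(parallel_config):
        # Sketch: wait on whichever process group spans the ranks that must
        # observe the same ModelWeightsStatus before weights are touched.
        group = (parallel_config.ep_group
                 if parallel_config.enable_expert_parallel
                 else parallel_config.tp_group)
        dist.barrier(group)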
@@ -332,17 +332,17 @@ class PaddleDisWorkerProc:
 self.worker.model_runner,
 self.parallel_config.engine_worker_queue_port,
 )
-self.model_weights_signal[0] = 0
+self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
 logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
-if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
+if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
 logger.info(f"Rank: {self.local_rank} Detected new requests.")
 self.insert_step = True
 tasks, read_finish = self.task_queue.get_tasks()
 if read_finish:
 # Ensure that every worker get the task
-self.exist_task_signal.value[0] = 0
+self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
 self.task_queue.read_finish_flag.set(0)
 req_dicts = []
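The exist-task handshake in this hunk works as follows: the engine sets the shared flag to EXIST, every worker drains the task queue, and once read_finish confirms that all workers received the batch the flag is reset to EMPTY so it is not re-read. A stripped-down sketch of that contract with stand-in types (not the real TaskQueue/IPCSignal):

    import queue

    EMPTY, EXIST = 0, 1   # mirrors ExistTaskStatus in this diff

    def poll_tasks(exist_flag, task_queue: "queue.Queue", read_finish: bool):
        """Stand-in polling step; returns the drained tasks, or [] when idle."""
        if exist_flag[0] != EXIST:
            return []
        tasks = []
        while not task_queue.empty():
            tasks.append(task_queue.get_nowait())
        if read_finish:            # every worker has the batch; safe to reset
            exist_flag[0] = EMPTY
        return tasks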
@@ -418,25 +418,14 @@ class PaddleDisWorkerProc:
 name="get_profile_block_num",
 array=get_profile_block_num,
 dtype=np.int32,
-suffix=self.parallel_config.engine_pid,
+suffix=self.parallel_config.engine_worker_queue_port,
 create=False,
 )
 self.get_profile_block_num_signal.value[0] = num_blocks_local
 else:
 num_blocks_local = self.fd_config.parallel_config.total_block_num
 logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
-# wait engine launch cache_manager
-if self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed":
-launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
-self.launched_cache_manager_signal = IPCSignal(
-name="launched_cache_manager_signal",
-array=launched_cache_manager_signal_data,
-dtype=np.int32,
-suffix=self.parallel_config.engine_pid,
-create=False,
-)
-while np.any(self.launched_cache_manager_signal.value[0] <= 0):
-time.sleep(0.01)
 # 4. init kv_cache with accurate num_blocks
 self.worker.initialize_cache(num_gpu_blocks=num_blocks_local)
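Two things happen in this hunk: the profiled KV block budget is still published to the engine through get_profile_block_num (now suffixed with the worker-queue port), while the wait on launched_cache_manager_signal is removed because, per the last commit message above, the cache messager no longer launches in mixed deployment. A hedged sketch of the remaining publish step, reusing the IPCSignal constructor exactly as it appears in this hunk (the engine side is assumed to create the signal and poll it until it turns positive):

    import numpy as np
    from fastdeploy.inter_communicator import IPCSignal  # same import as this file

    def publish_profiled_block_num(num_blocks_local: int, engine_worker_queue_port: int) -> None:
        """Sketch of the worker-side publish; requires the engine-created signal."""
        buf = np.zeros([1], dtype=np.int32)
        sig = IPCSignal(
            name="get_profile_block_num",
            array=buf,
            dtype=np.int32,
            suffix=engine_worker_queue_port,
            create=False,
        )
        sig.value[0] = num_blocks_local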
@@ -488,7 +477,7 @@ class PaddleDisWorkerProc:
 name="loaded_model_signal",
 array=loaded_model_signal_data,
 dtype=np.int32,
-suffix=self.parallel_config.engine_pid,
+suffix=self.parallel_config.engine_worker_queue_port,
 create=False,
 )
 if self.ranks > 1: