[feat] support prefix cache clearing when /clear_load_weight is called (#4008)

* [feat] support clearing prefix cache (cherry-picked from release/2.1)

* [fix] fix ipc suffix, use port instead

* [fix] fix prefix caching not enabled

* [fix] fix key/value_cache_scales indent

* [fix] fix ep group all-reduce

* [fix] fix clear/update lock not working when workers > 1

* [chore] add preemption triggered info log

* [fix] fix code style

* [fix] fix max_num_seqs config

* [fix] do not force enable_prefix_caching=False in dynamic loading

* [fix] fix ci

* Revert "[fix] fix ci"

This reverts commit 0bc6d55cc8.

* [fix] initialize available_gpu_block_num with max_gpu_block_num

* [fix] fix config splitwise_role

* [fix] fix clearing caches synchronization and add more logs

* [chore] print cache_ready_signal in log

* [fix] fix scheduler_config.splitwise_role

* [fix] fix cache_messager cache_ready_signal create=True

* [fix] stop cache messager from launching in mixed deployment
Author: 李泳桦
Date: 2025-09-28 19:42:53 +08:00
Committed by: GitHub
Parent: 59313ed7f9
Commit: 6265f4385f
20 changed files with 697 additions and 213 deletions
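For reviewers: the clear/update handshake below is driven entirely through the new prefix_tree_status IPC signal. A minimal engine-side sketch of the sequence this diff implements (hypothetical helper name and timeout; the real trigger is wired into the /clear_load_weight path), using only the IPCSignal/PrefixTreeStatus API visible in the hunks below:

# Hypothetical sketch: ask the cache manager to drop its prefix tree before
# weights are cleared, then wait for the acknowledgement.
import time
import numpy as np
from fastdeploy.inter_communicator import IPCSignal, PrefixTreeStatus

def request_prefix_cache_clear(engine_worker_queue_port: str, timeout_s: float = 30.0):
    signal = IPCSignal(
        name="prefix_tree_status",
        array=np.zeros([1], dtype=np.int32),
        dtype=np.int32,
        suffix=engine_worker_queue_port,
        create=False,  # attach; the owning process created it elsewhere
    )
    signal.value[0] = PrefixTreeStatus.CLEARING  # observed by clear_prefix_cache()
    deadline = time.time() + timeout_s
    while signal.value[0] != PrefixTreeStatus.CLEARED:
        if time.time() > deadline:
            raise TimeoutError("prefix cache was not cleared in time")
        time.sleep(0.01)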


@@ -31,7 +31,7 @@ import numpy as np
 from fastdeploy import envs
 from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
 from fastdeploy.cache_manager.cache_metrics import CacheMetrics
-from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
+from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.utils import get_logger
@@ -71,6 +71,7 @@ class PrefixCacheManager:
         else:
             self.num_gpu_blocks = self.cache_config.prefill_kvcache_block_num
         self.num_cpu_blocks = self.cache_config.num_cpu_blocks
         self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))
         if self.num_cpu_blocks > 0:
             self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1))
@@ -78,6 +79,7 @@ class PrefixCacheManager:
             self.cpu_free_block_list = []
         heapq.heapify(self.gpu_free_block_list)
         heapq.heapify(self.cpu_free_block_list)
         self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
         self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
@@ -111,6 +113,10 @@ class PrefixCacheManager:
             + f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}"
         )
         main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks)
+        main_process_metrics.available_gpu_block_num.set(self.num_gpu_blocks)
+        main_process_metrics.available_gpu_resource.set(1.0)

     @property
     def available_gpu_resource(self):
         return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0
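The two added gauges are consistent with the property that follows: right after initialization every block is free, so the reported fraction is exactly 1.0. A trivial standalone check (hypothetical block count, not part of the diff):

# All blocks free right after init -> resource fraction is exactly 1.0
num_gpu_blocks = 1024
gpu_free_block_list = list(range(num_gpu_blocks - 1, -1, -1))
assert len(gpu_free_block_list) / num_gpu_blocks == 1.0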
@@ -123,6 +129,7 @@
         pod_ip,
         engine_worker_queue_port,
         pid_suffix,
+        create_cache_tensor,
     ):
         """
         launch_cache_manager function used to initialize the cache manager.
@@ -133,7 +140,7 @@
             name="cache_task_broadcast_signal",
             array=broadcast_cache_task_flag_array,
             dtype=np.int32,
-            suffix=pid_suffix,
+            suffix=engine_worker_queue_port,
             create=True,
         )
@@ -151,17 +158,18 @@
         py_path = os.path.join(current_dir_path, filename)

         cache_messager_processes = []
-        cache_messager_processes = self.launch_cache_messager(
-            cache_config,
-            tensor_parallel_size,
-            device_ids,
-            pod_ip,
-            engine_worker_queue_port,
-            pid_suffix,
-        )
-        if cache_messager_processes is None:
-            raise RuntimeError("Launch cache messager failed")
-            return []
+        if self.enable_splitwise:
+            cache_messager_processes = self.launch_cache_messager(
+                cache_config,
+                tensor_parallel_size,
+                device_ids,
+                pod_ip,
+                engine_worker_queue_port,
+                pid_suffix,
+            )
+            if cache_messager_processes is None:
+                raise RuntimeError("Launch cache messager failed")
+                return []

         if (
             hasattr(cache_config.model_cfg, "num_key_value_heads")
@@ -173,20 +181,41 @@
         else:
             kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
         kv_num_head = max(1, kv_num_head)

         cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
         self.cache_ready_signal = IPCSignal(
             name="cache_ready_signal",
             array=cache_ready_signal_data,
             dtype=np.int32,
-            suffix=pid_suffix,
-            create=True,
+            suffix=engine_worker_queue_port,
+            create=False,
         )
+
+        swap_space_ready_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
+        self.swap_space_ready_signal = IPCSignal(
+            name="swap_space_ready_signal",
+            array=swap_space_ready_data,
+            dtype=np.int32,
+            suffix=engine_worker_queue_port,
+            create=False,
+        )
+
+        prefix_tree_status = np.zeros([1], dtype=np.int32)
+        self.prefix_tree_status_signal = IPCSignal(
+            name="prefix_tree_status",
+            array=prefix_tree_status,
+            dtype=np.int32,
+            suffix=engine_worker_queue_port,
+            create=False,
+        )
+
         # Run command to launch cache transfer managers
+        logger.info(f"create_cache_tensor: {create_cache_tensor}")
         log_dir = envs.FD_LOG_DIR
         cache_manager_processes = []
         for i in range(tensor_parallel_size):
             launch_cmd = (
                 "FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
                 + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+                + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
                 + f" {sys.executable} {py_path}"
                 + f" --device_id {int(device_ids[i])}"
                 + f" --rank {i}"
@@ -209,23 +238,33 @@
                 + f" --local_data_parallel_id {self.local_data_parallel_id}"
                 + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
                 + f" --speculative_config '{self.speculative_config.to_json_string()}'"
+                + (" --create_cache_tensor" if create_cache_tensor else "")
                 + f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
             )
             logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
             cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))

-        # Wait for cache initialization to complete
-        logger.info("Waiting for cache transfer manager ready...")
+        logger.info("PrefixCacheManager is waiting for kv cache to be initialized.")
         while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
             time.sleep(1)
+        if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
+            while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size:
+                time.sleep(1)
+
         exit_code = cache_manager_processes[-1].poll()
         if exit_code is None:
             logger.info("Launch cache transfer manager successful")
         else:
             logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
+
+        # Start additional threads
         if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
             logger.info("Enable hierarchical cache.")
             self._enable_cpu_cache()
+        threading.Thread(target=self.recv_data_transfer_result).start()
+        if cache_config.enable_prefix_caching:
+            threading.Thread(target=self.clear_prefix_cache, daemon=True).start()
+
+        all_cache_processes = cache_messager_processes + cache_manager_processes
+        return all_cache_processes
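A note on the repeated create=True -> create=False flips: when workers > 1, only one process may own creation of a given shared-memory signal and every other process must attach to it, which is also why the suffix moves from the per-process pid_suffix to the shared engine_worker_queue_port. A minimal sketch of that contract, assuming IPCSignal keys one shared array by name plus suffix (hypothetical port and world size):

import numpy as np
from fastdeploy.inter_communicator import IPCSignal

port = "8021"  # hypothetical engine_worker_queue_port
# Owner side: allocates the shared segment exactly once.
owner = IPCSignal(name="cache_ready_signal", array=np.zeros([4], dtype=np.int32),
                  dtype=np.int32, suffix=port, create=True)
# Attacher side (PrefixCacheManager above): create=False so every process
# observes the same array; re-creating here would race with the owner, and
# the readiness loop could never see the workers' writes.
viewer = IPCSignal(name="cache_ready_signal", array=np.zeros([4], dtype=np.int32),
                   dtype=np.int32, suffix=port, create=False)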
@@ -253,7 +292,7 @@
             array=cache_ready_signal_data,
             dtype=np.int32,
             suffix=pid_suffix,
-            create=True,
+            create=False,
         )

         py_path = os.path.join(current_dir_path, filename)
@@ -286,6 +325,7 @@
             )
             logger.info(f"Launch cache messager, command:{launch_cmd}")
             cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))

+        logger.info("Waiting for cache ready...")
         while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
             time.sleep(1)
@@ -317,23 +357,9 @@
         self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))

         main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks)
+        main_process_metrics.available_gpu_block_num.set(self.num_gpu_blocks)
+        main_process_metrics.available_gpu_resource.set(1.0)

     def _enable_cpu_cache(self):
         """
         _enable_cpu_cache function used to enable cpu cache.
         """
-        # ipc_cache_queue_port = self.cache_config.cache_queue_port
-        # self.cache_task_queue = CacheQueueManager(
-        #     rank=0,
-        #     mp_num=tensor_parallel_size,
-        #     port=ipc_cache_queue_port,
-        # )
-        # Start a listener thread for data transfer task results
-        self.transfer_recv_thread = threading.Thread(target=self.recv_data_transfer_result)
-        self.transfer_recv_thread.start()

     def can_allocate_gpu_blocks(self, num_blocks: int):
         """
         Check if num_blocks gpu blocks can be allocated.
@@ -1377,3 +1403,70 @@
             except Exception as e:
                 logger.warning(f"recv_data_transfer_result: error: {e}, {str(traceback.format_exc())}")
                 raise e
+
+    def reset(self):
+        """
+        Reset the RadixTree.
+        """
+        if len(self.node_map) == 0:
+            return
+
+        logger.info("Resetting the RadixTree!")
+
+        # wait for swap tasks to finish
+        if self.gpu_free_task_future is not None:
+            self.gpu_free_task_future.result()
+            self.gpu_free_task_future = None
+        for event in list(self.task_swapping_event.values()):
+            event.wait()
+        self.task_swapping_event.clear()
+
+        # clear node map
+        self.node_map.clear()
+        self.req_leaf_map.clear()
+        self.leaf_req_map.clear()
+        self.unfilled_req_block_map.clear()
+        self.cache_info.clear()
+
+        # reset gpu cache data structure
+        self.gpu_lru_leaf_heap.clear()
+        self.gpu_lru_leaf_set.clear()
+
+        # reset cpu cache data structure
+        self.cpu_lru_leaf_heap.clear()
+        self.cpu_lru_leaf_set.clear()
+
+        # reset gpu/cpu free block list
+        self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))
+        if self.num_cpu_blocks > 0:
+            self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1))
+        else:
+            self.cpu_free_block_list = []
+        heapq.heapify(self.gpu_free_block_list)
+        heapq.heapify(self.cpu_free_block_list)
+
+        # reset node/tree
+        self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
+        self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
+
+        # reset metrics
+        self.metrics.reset_metrics()
+        main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
+        main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
+
+    def clear_prefix_cache(self):
+        """
+        If the model weights status is updating or clearing, reset prefix cache tree
+        """
+        logger.info("Start a thread to clear prefix cache when model weights are cleared.")
+        prefix_tree_status_signal = self.prefix_tree_status_signal
+        while True:
+            if prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARING:
+                self.reset()
+                prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARED
+                logger.info("Prefix cache tree is cleared.")
+            if prefix_tree_status_signal.value[0] == PrefixTreeStatus.UPDATING:
+                prefix_tree_status_signal.value[0] = PrefixTreeStatus.NORMAL
+                logger.info("Prefix cache tree is updated.")
+            time.sleep(0.01)
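The free-list layout that reset() rebuilds matches the constructor exactly: block ids are laid out in descending order and then min-heapified, so the lowest id is handed out first. A quick standalone check of that invariant (hypothetical block counts, not part of the diff):

import heapq

num_gpu_blocks, num_cpu_blocks = 4, 2
gpu_free_block_list = list(range(num_gpu_blocks - 1, -1, -1))  # [3, 2, 1, 0]
heapq.heapify(gpu_free_block_list)
assert heapq.heappop(gpu_free_block_list) == 0  # lowest block id allocated first
assert len(gpu_free_block_list) == num_gpu_blocks - 1

The clearing thread itself is a daemon that polls prefix_tree_status every 10 ms, so a CLEARING request is acknowledged as CLEARED only after reset() has drained in-flight swap tasks.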