mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 17:17:14 +08:00
[feat] support prefix cache clearing when /clear_load_weight
is called (#4008)
* [feat] support clearing prefix cache (cherry-picked from release/2.1)
* [fix] fix ipc suffix, use port instead
* [fix] fix prefix caching not enabled
* [fix] fix key/value_cache_scales indent
* [fix] fix ep group all-reduce
* [fix] fix clear/update lock not working when workers > 1
* [chore] add preemption triggered info log
* [fix] fix code style
* [fix] fix max_num_seqs config
* [fix] do not force enable_prefix_caching=False in dynamic loading
* [fix] fix ci
* Revert "[fix] fix ci"
This reverts commit 0bc6d55cc8
.
* [fix] initialize available_gpu_block_num with max_gpu_block_num
* [fix] fix config splitwise_role
* [fix] fix clearing caches synchronization and add more logs
* [chore] print cache_ready_signal in log
* [fix] fix scheduler_config.splitwise_role
* [fix] fix cache_messager cache_ready_signal create=True
* [fix] stop cache messager from launching in mixed deployment
This commit is contained in:
@@ -31,7 +31,7 @@ import numpy as np
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
|
||||
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
@@ -71,6 +71,7 @@ class PrefixCacheManager:
|
||||
else:
|
||||
self.num_gpu_blocks = self.cache_config.prefill_kvcache_block_num
|
||||
self.num_cpu_blocks = self.cache_config.num_cpu_blocks
|
||||
|
||||
self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))
|
||||
if self.num_cpu_blocks > 0:
|
||||
self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1))
|
||||
@@ -78,6 +79,7 @@ class PrefixCacheManager:
|
||||
self.cpu_free_block_list = []
|
||||
heapq.heapify(self.gpu_free_block_list)
|
||||
heapq.heapify(self.cpu_free_block_list)
|
||||
|
||||
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
|
||||
|
||||
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
|
||||
@@ -111,6 +113,10 @@ class PrefixCacheManager:
|
||||
+ f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}"
|
||||
)
|
||||
|
||||
main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks)
|
||||
main_process_metrics.available_gpu_block_num.set(self.num_gpu_blocks)
|
||||
main_process_metrics.available_gpu_resource.set(1.0)
|
||||
|
||||
@property
|
||||
def available_gpu_resource(self):
|
||||
return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0
|
||||
@@ -123,6 +129,7 @@ class PrefixCacheManager:
|
||||
pod_ip,
|
||||
engine_worker_queue_port,
|
||||
pid_suffix,
|
||||
create_cache_tensor,
|
||||
):
|
||||
"""
|
||||
launch_cache_manager function used to initialize the cache manager.
|
||||
@@ -133,7 +140,7 @@ class PrefixCacheManager:
|
||||
name="cache_task_broadcast_signal",
|
||||
array=broadcast_cache_task_flag_array,
|
||||
dtype=np.int32,
|
||||
suffix=pid_suffix,
|
||||
suffix=engine_worker_queue_port,
|
||||
create=True,
|
||||
)
|
||||
|
||||
@@ -151,17 +158,18 @@ class PrefixCacheManager:
|
||||
py_path = os.path.join(current_dir_path, filename)
|
||||
|
||||
cache_messager_processes = []
|
||||
cache_messager_processes = self.launch_cache_messager(
|
||||
cache_config,
|
||||
tensor_parallel_size,
|
||||
device_ids,
|
||||
pod_ip,
|
||||
engine_worker_queue_port,
|
||||
pid_suffix,
|
||||
)
|
||||
if cache_messager_processes is None:
|
||||
raise RuntimeError("Launch cache messager failed")
|
||||
return []
|
||||
if self.enable_splitwise:
|
||||
cache_messager_processes = self.launch_cache_messager(
|
||||
cache_config,
|
||||
tensor_parallel_size,
|
||||
device_ids,
|
||||
pod_ip,
|
||||
engine_worker_queue_port,
|
||||
pid_suffix,
|
||||
)
|
||||
if cache_messager_processes is None:
|
||||
raise RuntimeError("Launch cache messager failed")
|
||||
return []
|
||||
|
||||
if (
|
||||
hasattr(cache_config.model_cfg, "num_key_value_heads")
|
||||
@@ -173,20 +181,41 @@ class PrefixCacheManager:
|
||||
else:
|
||||
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
|
||||
kv_num_head = max(1, kv_num_head)
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
|
||||
self.cache_ready_signal = IPCSignal(
|
||||
name="cache_ready_signal",
|
||||
array=cache_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=pid_suffix,
|
||||
create=True,
|
||||
suffix=engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
swap_space_ready_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
|
||||
self.swap_space_ready_signal = IPCSignal(
|
||||
name="swap_space_ready_signal",
|
||||
array=swap_space_ready_data,
|
||||
dtype=np.int32,
|
||||
suffix=engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
prefix_tree_status = np.zeros([1], dtype=np.int32)
|
||||
self.prefix_tree_status_signal = IPCSignal(
|
||||
name="prefix_tree_status",
|
||||
array=prefix_tree_status,
|
||||
dtype=np.int32,
|
||||
suffix=engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
|
||||
# Run command to launch cache transfer managers
|
||||
logger.info(f"create_cache_tensor: {create_cache_tensor}")
|
||||
log_dir = envs.FD_LOG_DIR
|
||||
cache_manager_processes = []
|
||||
for i in range(tensor_parallel_size):
|
||||
launch_cmd = (
|
||||
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
|
||||
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
|
||||
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
|
||||
+ f" {sys.executable} {py_path}"
|
||||
+ f" --device_id {int(device_ids[i])}"
|
||||
+ f" --rank {i}"
|
||||
@@ -209,23 +238,33 @@ class PrefixCacheManager:
|
||||
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
|
||||
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
|
||||
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
|
||||
+ (" --create_cache_tensor" if create_cache_tensor else "")
|
||||
+ f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
|
||||
)
|
||||
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
|
||||
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
|
||||
# 等待cache初始化完毕
|
||||
logger.info("Waiting for cache transfer manager ready...")
|
||||
|
||||
logger.info("PrefixCacheManager is waiting for kv cache to be initialized.")
|
||||
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
|
||||
time.sleep(1)
|
||||
|
||||
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
|
||||
while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size:
|
||||
time.sleep(1)
|
||||
|
||||
exit_code = cache_manager_processes[-1].poll()
|
||||
if exit_code is None:
|
||||
logger.info("Launch cache transfer manager successful")
|
||||
else:
|
||||
logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
|
||||
|
||||
# Start additional threads
|
||||
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
|
||||
logger.info("Enable hierarchical cache.")
|
||||
self._enable_cpu_cache()
|
||||
threading.Thread(target=self.recv_data_transfer_result).start()
|
||||
if cache_config.enable_prefix_caching:
|
||||
threading.Thread(target=self.clear_prefix_cache, daemon=True).start()
|
||||
|
||||
all_cache_processes = cache_messager_processes + cache_manager_processes
|
||||
return all_cache_processes
|
||||
|
||||
@@ -253,7 +292,7 @@ class PrefixCacheManager:
|
||||
array=cache_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=pid_suffix,
|
||||
create=True,
|
||||
create=False,
|
||||
)
|
||||
|
||||
py_path = os.path.join(current_dir_path, filename)
|
||||
@@ -286,6 +325,7 @@ class PrefixCacheManager:
|
||||
)
|
||||
logger.info(f"Launch cache messager, command:{launch_cmd}")
|
||||
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
|
||||
|
||||
logger.info("Waiting for cache ready...")
|
||||
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
|
||||
time.sleep(1)
|
||||
@@ -317,23 +357,9 @@ class PrefixCacheManager:
|
||||
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
|
||||
|
||||
main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks)
|
||||
main_process_metrics.available_gpu_block_num.set(self.num_gpu_blocks)
|
||||
main_process_metrics.available_gpu_resource.set(1.0)
|
||||
|
||||
def _enable_cpu_cache(self):
|
||||
"""
|
||||
_enable_cpu_cache function used to enable cpu cache.
|
||||
"""
|
||||
|
||||
# ipc_cache_queue_port = self.cache_config.cache_queue_port
|
||||
# self.cache_task_queue = CacheQueueManager(
|
||||
# rank=0,
|
||||
# mp_num=tensor_parallel_size,
|
||||
# port=ipc_cache_queue_port,
|
||||
# )
|
||||
# 开启获取传输任务结果的监听线程
|
||||
self.transfer_recv_thread = threading.Thread(target=self.recv_data_transfer_result)
|
||||
self.transfer_recv_thread.start()
|
||||
|
||||
def can_allocate_gpu_blocks(self, num_blocks: int):
|
||||
"""
|
||||
Check if num_blocks gpu blocks can be allocated.
|
||||
@@ -1377,3 +1403,70 @@ class PrefixCacheManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"recv_data_transfer_result: error: {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the RadixTree.
|
||||
"""
|
||||
|
||||
if len(self.node_map) == 0:
|
||||
return
|
||||
|
||||
logger.info("Resetting the RadixTree!")
|
||||
|
||||
# wait for swap tasks to finish
|
||||
if self.gpu_free_task_future is not None:
|
||||
self.gpu_free_task_future.result()
|
||||
self.gpu_free_task_future = None
|
||||
for event in list(self.task_swapping_event.values()):
|
||||
event.wait()
|
||||
self.task_swapping_event.clear()
|
||||
|
||||
# clear node map
|
||||
self.node_map.clear()
|
||||
self.req_leaf_map.clear()
|
||||
self.leaf_req_map.clear()
|
||||
self.unfilled_req_block_map.clear()
|
||||
self.cache_info.clear()
|
||||
|
||||
# reset gpu cache data structure
|
||||
self.gpu_lru_leaf_heap.clear()
|
||||
self.gpu_lru_leaf_set.clear()
|
||||
|
||||
# reset cpu cache data structure
|
||||
self.cpu_lru_leaf_heap.clear()
|
||||
self.cpu_lru_leaf_set.clear()
|
||||
|
||||
# reset gpu/cpu free block list
|
||||
self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))
|
||||
if self.num_cpu_blocks > 0:
|
||||
self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1))
|
||||
else:
|
||||
self.cpu_free_block_list = []
|
||||
heapq.heapify(self.gpu_free_block_list)
|
||||
heapq.heapify(self.cpu_free_block_list)
|
||||
|
||||
# reset node/tree
|
||||
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
|
||||
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
|
||||
|
||||
# reset metrics
|
||||
self.metrics.reset_metrics()
|
||||
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
|
||||
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
|
||||
|
||||
def clear_prefix_cache(self):
|
||||
"""
|
||||
If the model weights status is updating or clearing, reset prefix cache tree
|
||||
"""
|
||||
logger.info("Start a thread to clear prefix cache when model weights are cleared.")
|
||||
prefix_tree_status_signal = self.prefix_tree_status_signal
|
||||
while True:
|
||||
if prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARING:
|
||||
self.reset()
|
||||
prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARED
|
||||
logger.info("Prefix cache tree is cleared.")
|
||||
if prefix_tree_status_signal.value[0] == PrefixTreeStatus.UPDATING:
|
||||
prefix_tree_status_signal.value[0] = PrefixTreeStatus.NORMAL
|
||||
logger.info("Prefix cache tree is updated.")
|
||||
time.sleep(0.01)
|
||||
|
Reference in New Issue
Block a user