mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-14 04:44:00 +08:00
[feat] support prefix cache clearing when /clear_load_weight
is called (#4091)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [feat] support clearing prefix cache (cherry-picked from release/2.1) * [fix] fix ipc suffix, use port instead * [fix] fix prefix caching not enabled * [fix] fix code style * [fix] wait for rank0 to update weight status
This commit is contained in:
@@ -98,8 +98,8 @@ class CacheMessager:
|
||||
cache_v = []
|
||||
self.messager = {}
|
||||
for layer_idx in range(self.num_layers):
|
||||
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}.device{gpu_id}"]
|
||||
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}.device{gpu_id}"]
|
||||
cache_k.append(key_cache)
|
||||
cache_v.append(val_cache)
|
||||
cache_k_ptr_list.append(key_cache.data_ptr())
|
||||
|
@@ -16,21 +16,27 @@
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import gc
|
||||
import json
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.cache_manager.cache_data import CacheStatus
|
||||
from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
cuda_host_alloc,
|
||||
cuda_host_free,
|
||||
set_data_ipc,
|
||||
share_external_data,
|
||||
swap_cache_all_layers,
|
||||
unset_data_ipc,
|
||||
)
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
@@ -93,6 +99,7 @@ def parse_args():
|
||||
help="speculative config",
|
||||
)
|
||||
parser.add_argument("--local_data_parallel_id", type=int, default=0)
|
||||
parser.add_argument("--create_cache_tensor", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
@@ -110,7 +117,6 @@ class CacheTransferManager:
|
||||
|
||||
device = args.device_id
|
||||
rank = args.rank
|
||||
paddle.set_device(f"gpu:{device}")
|
||||
self.gpu_cache_kvs = {}
|
||||
self.cpu_cache_kvs = {}
|
||||
self.gpu_cache_k_tensors = []
|
||||
@@ -126,6 +132,7 @@ class CacheTransferManager:
|
||||
self.n_ranks = args.mp_num
|
||||
self.rank = rank
|
||||
self.device = device
|
||||
self.engine_pid = args.engine_pid
|
||||
|
||||
address = (args.pod_ip, args.cache_queue_port)
|
||||
self.cache_task_queue = EngineCacheQueue(
|
||||
@@ -136,70 +143,27 @@ class CacheTransferManager:
|
||||
local_data_parallel_id=args.local_data_parallel_id,
|
||||
)
|
||||
|
||||
self.num_cpu_blocks = args.num_cpu_blocks
|
||||
|
||||
cache_type = args.cache_dtype
|
||||
for i in range(args.num_layers + self.num_extra_layers):
|
||||
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
|
||||
|
||||
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
|
||||
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[
|
||||
num_gpu_blocks,
|
||||
args.kv_num_head,
|
||||
args.block_size,
|
||||
args.head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
|
||||
|
||||
set_data_ipc(
|
||||
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
|
||||
f"key_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
set_data_ipc(
|
||||
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
|
||||
f"value_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
|
||||
logger.info(f"device :{self.device}")
|
||||
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
|
||||
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
|
||||
|
||||
paddle.set_device("cpu")
|
||||
self.k_dst_ptrs = []
|
||||
self.v_dst_ptrs = []
|
||||
for i in range(args.num_layers + self.num_extra_layers):
|
||||
self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
|
||||
args.num_cpu_blocks * args.bytes_per_layer_per_block
|
||||
)
|
||||
self.k_dst_ptrs.append(self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"])
|
||||
self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"] = cuda_host_alloc(
|
||||
args.num_cpu_blocks * args.bytes_per_layer_per_block
|
||||
)
|
||||
self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
|
||||
self.cache_ready_signal = IPCSignal(
|
||||
name="cache_ready_signal",
|
||||
array=cache_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=args.engine_pid,
|
||||
suffix=self.engine_pid,
|
||||
create=False,
|
||||
)
|
||||
self.cache_ready_signal.value[self.rank] = 1
|
||||
swap_space_ready_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
|
||||
self.swap_space_ready_signal = IPCSignal(
|
||||
name="swap_space_ready_signal",
|
||||
array=swap_space_ready_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.engine_pid,
|
||||
create=False,
|
||||
)
|
||||
|
||||
self.num_cpu_blocks = args.num_cpu_blocks
|
||||
|
||||
self._init_cpu_cache(args)
|
||||
self._init_gpu_cache(args)
|
||||
|
||||
paddle.set_device(f"gpu:{device}")
|
||||
if args.enable_splitwise:
|
||||
@@ -232,6 +196,72 @@ class CacheTransferManager:
|
||||
create=False,
|
||||
)
|
||||
|
||||
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
|
||||
|
||||
def _init_gpu_cache(self, args):
|
||||
|
||||
if not args.create_cache_tensor:
|
||||
logger.info("Waiting for runners to create kv cache.")
|
||||
while self.cache_ready_signal.value[self.rank] != 1:
|
||||
time.sleep(1)
|
||||
logger.info("OK! Stop waiting.")
|
||||
|
||||
logger.info("Initializing kv cache for all layers.")
|
||||
paddle.set_device(f"gpu:{self.device}")
|
||||
for i in range(args.num_layers + self.num_extra_layers):
|
||||
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
|
||||
cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
|
||||
key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
|
||||
val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
|
||||
|
||||
if args.create_cache_tensor:
|
||||
logger.info(f"..creating kv cache for layer {i}: {cache_shape}")
|
||||
key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
|
||||
val_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
|
||||
set_data_ipc(key_cache, key_name)
|
||||
set_data_ipc(val_cache, val_name)
|
||||
else:
|
||||
logger.info(f"..attaching kv cache for layer {i}: {cache_shape}")
|
||||
key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
|
||||
val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
|
||||
key_cache = share_external_data(key_cache, key_name, cache_shape)
|
||||
val_cache = share_external_data(val_cache, val_name, cache_shape)
|
||||
|
||||
self.gpu_cache_kvs[key_name] = key_cache
|
||||
self.gpu_cache_kvs[val_name] = val_cache
|
||||
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
|
||||
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
|
||||
|
||||
if args.create_cache_tensor:
|
||||
logger.info("✅ kv cache is ready!")
|
||||
self.cache_ready_signal.value[self.rank] = 1
|
||||
|
||||
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
|
||||
logger.info(f"device :{self.device}")
|
||||
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
|
||||
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
|
||||
|
||||
def _init_cpu_cache(self, args):
|
||||
if args.num_cpu_blocks == 0:
|
||||
logger.info("💡 no swap space (cpu cache) is specified.")
|
||||
self.swap_space_ready_signal.value[self.rank] = 1
|
||||
return
|
||||
logger.info("Initializing swap space (cpu cache) for all layers.")
|
||||
paddle.set_device("cpu")
|
||||
self.k_dst_ptrs = []
|
||||
self.v_dst_ptrs = []
|
||||
for i in range(args.num_layers + self.num_extra_layers):
|
||||
key_name = f"key_caches_{i}_rank{self.rank}"
|
||||
val_name = f"value_caches_{i}_rank{self.rank}"
|
||||
need_to_allocate_bytes = args.num_cpu_blocks * args.bytes_per_layer_per_block
|
||||
logger.info(f"..creating cpu cache for layer {i}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB")
|
||||
self.cpu_cache_kvs[key_name] = cuda_host_alloc(need_to_allocate_bytes)
|
||||
self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
|
||||
self.cpu_cache_kvs[val_name] = cuda_host_alloc(need_to_allocate_bytes)
|
||||
self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
|
||||
logger.info("✅ swap space (cpu cache) is ready!")
|
||||
self.swap_space_ready_signal.value[self.rank] = 1
|
||||
|
||||
def _do_swap_to_cpu_task(
|
||||
self,
|
||||
swap_node_ids,
|
||||
@@ -429,6 +459,67 @@ class CacheTransferManager:
|
||||
transfer_task_id,
|
||||
)
|
||||
|
||||
def clear_or_update_caches(self, args):
|
||||
logger.info("Start a thread to clear/restore kv cache when model weights are cleared/updated.")
|
||||
logger.info(f"FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}")
|
||||
kv_cache_status = np.zeros([1], dtype=np.int32)
|
||||
kv_cache_status_signal = IPCSignal(
|
||||
name="kv_cache_status",
|
||||
array=kv_cache_status,
|
||||
dtype=np.int32,
|
||||
suffix=self.engine_pid,
|
||||
create=False,
|
||||
)
|
||||
while True:
|
||||
if kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
|
||||
try:
|
||||
if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
|
||||
paddle.set_device("cpu")
|
||||
for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
|
||||
cuda_host_free(ptrs)
|
||||
self.cpu_cache_kvs.clear()
|
||||
self.k_dst_ptrs.clear()
|
||||
self.v_dst_ptrs.clear()
|
||||
gc.collect()
|
||||
# reset swap_space_ready_signal
|
||||
self.swap_space_ready_signal.value[self.rank] = 0
|
||||
while np.sum(self.swap_space_ready_signal.value) != 0:
|
||||
time.sleep(0.1)
|
||||
|
||||
paddle.set_device(f"gpu:{self.device}")
|
||||
for name, tensor in self.gpu_cache_kvs.items():
|
||||
unset_data_ipc(tensor, name, True, False)
|
||||
self.gpu_cache_kvs.clear()
|
||||
self.gpu_cache_k_tensors.clear()
|
||||
self.gpu_cache_v_tensors.clear()
|
||||
# reset cache_ready_signal
|
||||
self.cache_ready_signal.value[self.rank] = 0
|
||||
if np.sum(self.cache_ready_signal.value) == 0:
|
||||
time.sleep(0.1)
|
||||
|
||||
kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clear caches: {e}")
|
||||
|
||||
elif kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
|
||||
try:
|
||||
if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
|
||||
self._init_cpu_cache(args)
|
||||
while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
|
||||
time.sleep(0.1)
|
||||
|
||||
self._init_gpu_cache(args)
|
||||
while np.sum(self.cache_ready_signal.value) != args.mp_num:
|
||||
time.sleep(0.1)
|
||||
|
||||
kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to restore caches: {e}")
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
|
@@ -31,7 +31,7 @@ import numpy as np
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
|
||||
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
@@ -71,6 +71,7 @@ class PrefixCacheManager:
|
||||
else:
|
||||
self.num_gpu_blocks = self.cache_config.prefill_kvcache_block_num
|
||||
self.num_cpu_blocks = self.cache_config.num_cpu_blocks
|
||||
|
||||
self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))
|
||||
if self.num_cpu_blocks > 0:
|
||||
self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1))
|
||||
@@ -78,6 +79,7 @@ class PrefixCacheManager:
|
||||
self.cpu_free_block_list = []
|
||||
heapq.heapify(self.gpu_free_block_list)
|
||||
heapq.heapify(self.cpu_free_block_list)
|
||||
|
||||
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
|
||||
|
||||
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
|
||||
@@ -123,6 +125,7 @@ class PrefixCacheManager:
|
||||
pod_ip,
|
||||
engine_worker_queue_port,
|
||||
pid_suffix,
|
||||
create_cache_tensor,
|
||||
):
|
||||
"""
|
||||
launch_cache_manager function used to initialize the cache manager.
|
||||
@@ -133,7 +136,7 @@ class PrefixCacheManager:
|
||||
name="cache_task_broadcast_signal",
|
||||
array=broadcast_cache_task_flag_array,
|
||||
dtype=np.int32,
|
||||
suffix=pid_suffix,
|
||||
suffix=engine_worker_queue_port,
|
||||
create=True,
|
||||
)
|
||||
|
||||
@@ -160,20 +163,41 @@ class PrefixCacheManager:
|
||||
else:
|
||||
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
|
||||
kv_num_head = max(1, kv_num_head)
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
|
||||
self.cache_ready_signal = IPCSignal(
|
||||
name="cache_ready_signal",
|
||||
array=cache_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=pid_suffix,
|
||||
create=True,
|
||||
suffix=engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
swap_space_ready_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
|
||||
self.swap_space_ready_signal = IPCSignal(
|
||||
name="swap_space_ready_signal",
|
||||
array=swap_space_ready_data,
|
||||
dtype=np.int32,
|
||||
suffix=engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
prefix_tree_status = np.zeros([1], dtype=np.int32)
|
||||
self.prefix_tree_status_signal = IPCSignal(
|
||||
name="prefix_tree_status",
|
||||
array=prefix_tree_status,
|
||||
dtype=np.int32,
|
||||
suffix=engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
|
||||
# Run command to launch cache transfer managers
|
||||
logger.info(f"create_cache_tensor: {create_cache_tensor}")
|
||||
log_dir = envs.FD_LOG_DIR
|
||||
cache_manager_processes = []
|
||||
for i in range(tensor_parallel_size):
|
||||
launch_cmd = (
|
||||
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
|
||||
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
|
||||
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
|
||||
+ f" {sys.executable} {py_path}"
|
||||
+ f" --device_id {int(device_ids[i])}"
|
||||
+ f" --rank {i}"
|
||||
@@ -196,23 +220,33 @@ class PrefixCacheManager:
|
||||
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
|
||||
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
|
||||
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
|
||||
+ (" --create_cache_tensor" if create_cache_tensor else "")
|
||||
+ f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
|
||||
)
|
||||
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
|
||||
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
|
||||
# 等待cache初始化完毕
|
||||
logger.info("Waiting for cache transfer manager ready...")
|
||||
|
||||
logger.info("PrefixCacheManager is waiting for kv cache to be initialized.")
|
||||
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
|
||||
time.sleep(1)
|
||||
|
||||
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
|
||||
while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size:
|
||||
time.sleep(1)
|
||||
|
||||
exit_code = cache_manager_processes[-1].poll()
|
||||
if exit_code is None:
|
||||
logger.info("Launch cache transfer manager successful")
|
||||
else:
|
||||
logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
|
||||
|
||||
# Start additional threads
|
||||
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
|
||||
logger.info("Enable hierarchical cache.")
|
||||
self._enable_cpu_cache()
|
||||
threading.Thread(target=self.recv_data_transfer_result).start()
|
||||
if cache_config.enable_prefix_caching:
|
||||
threading.Thread(target=self.clear_prefix_cache, daemon=True).start()
|
||||
|
||||
return cache_manager_processes
|
||||
|
||||
def update_cache_config(self, cache_config):
|
||||
@@ -237,21 +271,6 @@ class PrefixCacheManager:
|
||||
main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks)
|
||||
main_process_metrics.available_gpu_resource.set(1.0)
|
||||
|
||||
def _enable_cpu_cache(self):
|
||||
"""
|
||||
_enable_cpu_cache function used to enable cpu cache.
|
||||
"""
|
||||
|
||||
# ipc_cache_queue_port = self.cache_config.cache_queue_port
|
||||
# self.cache_task_queue = CacheQueueManager(
|
||||
# rank=0,
|
||||
# mp_num=tensor_parallel_size,
|
||||
# port=ipc_cache_queue_port,
|
||||
# )
|
||||
# 开启获取传输任务结果的监听线程
|
||||
self.transfer_recv_thread = threading.Thread(target=self.recv_data_transfer_result)
|
||||
self.transfer_recv_thread.start()
|
||||
|
||||
def can_allocate_gpu_blocks(self, num_blocks: int):
|
||||
"""
|
||||
Check if num_blocks gpu blocks can be allocated.
|
||||
@@ -1295,3 +1314,70 @@ class PrefixCacheManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"recv_data_transfer_result: error: {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the RadixTree.
|
||||
"""
|
||||
|
||||
if len(self.node_map) == 0:
|
||||
return
|
||||
|
||||
logger.info("Resetting the RadixTree!")
|
||||
|
||||
# wait for swap tasks to finish
|
||||
if self.gpu_free_task_future is not None:
|
||||
self.gpu_free_task_future.result()
|
||||
self.gpu_free_task_future = None
|
||||
for event in list(self.task_swapping_event.values()):
|
||||
event.wait()
|
||||
self.task_swapping_event.clear()
|
||||
|
||||
# clear node map
|
||||
self.node_map.clear()
|
||||
self.req_leaf_map.clear()
|
||||
self.leaf_req_map.clear()
|
||||
self.unfilled_req_block_map.clear()
|
||||
self.cache_info.clear()
|
||||
|
||||
# reset gpu cache data structure
|
||||
self.gpu_lru_leaf_heap.clear()
|
||||
self.gpu_lru_leaf_set.clear()
|
||||
|
||||
# reset cpu cache data structure
|
||||
self.cpu_lru_leaf_heap.clear()
|
||||
self.cpu_lru_leaf_set.clear()
|
||||
|
||||
# reset gpu/cpu free block list
|
||||
self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))
|
||||
if self.num_cpu_blocks > 0:
|
||||
self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1))
|
||||
else:
|
||||
self.cpu_free_block_list = []
|
||||
heapq.heapify(self.gpu_free_block_list)
|
||||
heapq.heapify(self.cpu_free_block_list)
|
||||
|
||||
# reset node/tree
|
||||
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
|
||||
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
|
||||
|
||||
# reset metrics
|
||||
self.metrics.reset_metrics()
|
||||
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
|
||||
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
|
||||
|
||||
def clear_prefix_cache(self):
|
||||
"""
|
||||
If the model weights status is updating or clearing, reset prefix cache tree
|
||||
"""
|
||||
logger.info("Start a thread to clear prefix cache when model weights are cleared.")
|
||||
prefix_tree_status_signal = self.prefix_tree_status_signal
|
||||
while True:
|
||||
if prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARING:
|
||||
self.reset()
|
||||
prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARED
|
||||
logger.info("Prefix cache tree is cleared.")
|
||||
if prefix_tree_status_signal.value[0] == PrefixTreeStatus.UPDATING:
|
||||
prefix_tree_status_signal.value[0] = PrefixTreeStatus.NORMAL
|
||||
logger.info("Prefix cache tree is updated.")
|
||||
time.sleep(0.01)
|
||||
|
Reference in New Issue
Block a user