Files
FastDeploy/fastdeploy/cache_manager/prefix_cache_manager.py
chenjian 918ccdb123 [Feature] Support pd ep deployment with yiyan adapter (#4029)
* [Feature] Support mixed deployment with yiyan adapter in release2.2

* fix metrics

* add unit test

* add unit test

* add unit test

* Support pd ep deployment with yiyan adapter

* Support pd ep deployment with yiyan adapter

* refactor cache messager

* support scheduler v1 in PD

* suppport pd v1 + chunk prefill

* suppport pd v1 + chunk prefill

* add eplb

* support eplb

* support eplb

* support eplb

* support v1

* fix

* fix

* fix bug

* remove eplb support

* support prefix cache in P

* fix bug

* fix bug

* support one stop in V1

* fix bug

* fix ci

* fix ci

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
2025-09-22 16:41:38 +08:00

1380 lines
57 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import heapq
import os
import subprocess
import sys
import threading
import time
import traceback
import uuid
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from threading import Event, Lock
import numpy as np
from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
class PrefixCacheManager:
"""
PrefixCacheManager is used to manage the prefix tree and the cache.
"""
def __init__(
self,
config,
tensor_parallel_size,
splitwise_role="mixed",
local_data_parallel_id=0,
):
"""
initialize the PrefixCacheManager
"""
self.metrics = CacheMetrics()
if splitwise_role != "mixed":
self.enable_splitwise = 1
else:
self.enable_splitwise = 0
self.splitwise_role = splitwise_role
self.cache_config = config.cache_config
self.speculative_config = config.speculative_config
self.local_data_parallel_id = local_data_parallel_id
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.num_gpu_blocks = self.cache_config.total_block_num
else:
self.num_gpu_blocks = self.cache_config.prefill_kvcache_block_num
self.num_cpu_blocks = self.cache_config.num_cpu_blocks
self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))
if self.num_cpu_blocks > 0:
self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1))
else:
self.cpu_free_block_list = []
heapq.heapify(self.gpu_free_block_list)
heapq.heapify(self.cpu_free_block_list)
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
# gpu cache data structure
self.gpu_lru_leaf_heap = []
self.gpu_lru_leaf_set = set()
# cpu cache data structure
self.cpu_lru_leaf_heap = []
self.cpu_lru_leaf_set = set()
# swap in/out data structure
self.request_release_lock = Lock()
self.task_swapping_event = {}
self.node_map = {}
self.req_leaf_map = {} # {request_id: leaf node}
self.leaf_req_map = defaultdict(set)
self.unfilled_req_block_map = defaultdict(list)
self.cache_info = {}
self.executor_pool = ThreadPoolExecutor(max_workers=1)
self.free_gpu_executor_pool = ThreadPoolExecutor(max_workers=1)
self.free_cpu_executor_pool = ThreadPoolExecutor(max_workers=1)
self.gpu_free_task_future = None
self.cache_status_lock = Lock()
logger.info(
f"num_gpu_blocks_server_owned {self.num_gpu_blocks} num_cpu_blocks "
+ f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}"
)
@property
def available_gpu_resource(self):
return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0
def launch_cache_manager(
self,
cache_config,
tensor_parallel_size,
device_ids,
pod_ip,
engine_worker_queue_port,
pid_suffix,
):
"""
launch_cache_manager function used to initialize the cache manager.
"""
broadcast_cache_task_flag_array = np.zeros([1], dtype=np.int32)
self.shm_cache_task_flag_broadcast = IPCSignal(
name="cache_task_broadcast_signal",
array=broadcast_cache_task_flag_array,
dtype=np.int32,
suffix=pid_suffix,
create=True,
)
self.cache_task_queue = EngineCacheQueue(
address=(pod_ip, cache_config.cache_queue_port),
authkey=b"cache_queue_service",
is_server=False,
num_client=tensor_parallel_size,
client_id=0,
local_data_parallel_id=self.local_data_parallel_id,
)
current_dir_path = os.path.split(os.path.abspath(__file__))[0]
filename = "cache_transfer_manager.py"
py_path = os.path.join(current_dir_path, filename)
cache_messager_processes = []
cache_messager_processes = self.launch_cache_messager(
cache_config,
tensor_parallel_size,
device_ids,
pod_ip,
engine_worker_queue_port,
pid_suffix,
)
if cache_messager_processes is None:
raise RuntimeError("Launch cache messager failed")
return []
if (
hasattr(cache_config.model_cfg, "num_key_value_heads")
and hasattr(cache_config.model_cfg, "num_key_value_heads")
and cache_config.model_cfg.num_key_value_heads is not None
and int(cache_config.model_cfg.num_key_value_heads) > 0
):
kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
else:
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
kv_num_head = max(1, kv_num_head)
cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=pid_suffix,
create=True,
)
log_dir = envs.FD_LOG_DIR
cache_manager_processes = []
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}"
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --head_dim {cache_config.model_cfg.head_dim}"
+ f" --kv_num_head {kv_num_head}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --cache_queue_port {cache_config.cache_queue_port}"
+ f" --enable_splitwise {int(self.enable_splitwise)}"
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_gpu_blocks {cache_config.total_block_num}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
+ f" --block_size {cache_config.block_size}"
+ f" --engine_pid {pid_suffix}"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
# 等待cache初始化完毕
logger.info("Waiting for cache transfer manager ready...")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
exit_code = cache_manager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache transfer manager successful")
else:
logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
logger.info("Enable hierarchical cache.")
self._enable_cpu_cache()
all_cache_processes = cache_messager_processes + cache_manager_processes
return all_cache_processes
def launch_cache_messager(
self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix
):
"""
launch_cache_messager function used to initialize the cache messager.
"""
current_dir_path = os.path.split(os.path.abspath(__file__))[0]
filename = "cache_messager.py"
if (
hasattr(cache_config.model_cfg, "num_key_value_heads")
and hasattr(cache_config.model_cfg, "num_key_value_heads")
and cache_config.model_cfg.num_key_value_heads is not None
and int(cache_config.model_cfg.num_key_value_heads) > 0
):
kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
else:
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=pid_suffix,
create=True,
)
py_path = os.path.join(current_dir_path, filename)
log_dir = envs.FD_LOG_DIR
cache_messager_processes = []
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}"
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --head_dim {cache_config.model_cfg.head_dim}"
+ f" --kv_num_head {kv_num_head}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --pod_ip {pod_ip}"
+ f" --cache_queue_port {cache_config.cache_queue_port}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_gpu_blocks {cache_config.total_block_num}"
+ f" --block_size {cache_config.block_size}"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --engine_pid {pid_suffix}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1"
)
logger.info(f"Launch cache messager, command:{launch_cmd}")
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
logger.info("Waiting for cache ready...")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
exit_code = cache_messager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache messager successful")
else:
logger.info("Launch cache messager failed, see launch_cache_messager.log for more information")
cache_messager_processes = None
return cache_messager_processes
def update_cache_config(self, cache_config):
"""
update cache config
"""
self.cache_config = cache_config
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.num_gpu_blocks = cache_config.total_block_num
self.gpu_free_block_list = list(
range(self.num_gpu_blocks - 1, -1, -1)
) # All gpu blocks are managed by cache manager
else:
self.num_gpu_blocks = cache_config.prefill_kvcache_block_num
self.gpu_free_block_list = list(
range(self.num_gpu_blocks - 1, -1, -1)
) # Only block table divided for prefill managed by server
heapq.heapify(self.gpu_free_block_list)
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks)
main_process_metrics.available_gpu_resource.set(1.0)
def _enable_cpu_cache(self):
"""
_enable_cpu_cache function used to enable cpu cache.
"""
# ipc_cache_queue_port = self.cache_config.cache_queue_port
# self.cache_task_queue = CacheQueueManager(
# rank=0,
# mp_num=tensor_parallel_size,
# port=ipc_cache_queue_port,
# )
# 开启获取传输任务结果的监听线程
self.transfer_recv_thread = threading.Thread(target=self.recv_data_transfer_result)
self.transfer_recv_thread.start()
def can_allocate_gpu_blocks(self, num_blocks: int):
"""
Check if num_blocks gpu blocks can be allocated.
"""
if len(self.gpu_free_block_list) < num_blocks:
if self.cache_config.enable_prefix_caching:
self.free_block_ids(num_blocks)
if len(self.gpu_free_block_list) < num_blocks:
return False
else:
return True
else:
return True
def allocate_gpu_blocks(self, num_blocks):
"""
allocate gpu blocks.
"""
assert num_blocks <= len(
self.gpu_free_block_list
), f"gpu free block num: {len(self.gpu_free_block_list)} < needed number {num_blocks}"
allocated_block_ids = [heapq.heappop(self.gpu_free_block_list) for i in range(num_blocks)]
logger.info(
f"allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
)
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
return allocated_block_ids
def recycle_gpu_blocks(self, gpu_block_ids):
"""
recycle gpu blocks.
"""
logger.info(
f"recycle_gpu_blocks: {gpu_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
)
if isinstance(gpu_block_ids, list):
for gpu_block_id in gpu_block_ids:
heapq.heappush(self.gpu_free_block_list, gpu_block_id)
else:
heapq.heappush(self.gpu_free_block_list, gpu_block_ids)
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
def allocate_cpu_blocks(self, num_blocks):
"""
allocate cpu blocks.
"""
assert num_blocks <= len(
self.cpu_free_block_list
), f"cpu free block num: {len(self.cpu_free_block_list)} < needed number {num_blocks}"
allocated_block_ids = [heapq.heappop(self.cpu_free_block_list) for i in range(num_blocks)]
logger.info(
f"allocate_cpu_blocks: {allocated_block_ids}, len(self.cpu_free_block_list) {len(self.cpu_free_block_list)}"
)
return allocated_block_ids
def recycle_cpu_blocks(self, cpu_block_ids):
"""
recycle cpu blocks.
"""
logger.info(
f"recycle_cpu_blocks: {cpu_block_ids}, len(self.cpu_free_block_list) {len(self.cpu_free_block_list)}"
)
if isinstance(cpu_block_ids, list):
for cpu_block_id in cpu_block_ids:
heapq.heappush(self.cpu_free_block_list, cpu_block_id)
else:
heapq.heappush(self.cpu_free_block_list, cpu_block_ids)
def issue_swap_task(
self,
transfer_task_id,
swap_node_ids,
gpu_block_ids,
cpu_block_ids,
event_type,
is_sync=True,
):
"""
start data swap task
args:
transfer_task_id: transfer task id
swap_node_ids: to swap node id list
gpu_block_ids: to swap gpu block id list
cpu_block_ids: to swap cpu block id list
event_type: CacheStatus.SWAP2GPU or CacheStatus.SWAP2CPU
is_sync: bool, whether to wait for the result of the swap task
"""
self.task_swapping_event[transfer_task_id] = Event()
self.cache_task_queue.put_transfer_task(
(
swap_node_ids,
gpu_block_ids,
cpu_block_ids,
event_type,
transfer_task_id,
)
)
if is_sync:
self.sync_swap_task(transfer_task_id)
def sync_swap_task(self, transfer_task_id):
"""
sync swap task
"""
self.task_swapping_event[transfer_task_id].wait()
del self.task_swapping_event[transfer_task_id]
def _check_validity(self, req_id, match_gpu_blocks_num, expected_block_num):
"""
check enough gpu memory to allocate cache
"""
if expected_block_num - match_gpu_blocks_num > len(self.gpu_free_block_list):
msg = (
f"request_block_ids: request block for req_id {req_id} failed. "
+ f"matched gpu block num: {match_gpu_blocks_num} require extra gpu block num: "
+ f"{expected_block_num - match_gpu_blocks_num} > free block num: {len(self.gpu_free_block_list)}"
)
logger.info(msg)
raise Exception("Not enough GPU memory to allocate cache")
def _prepare_cpu_cache(
self,
req_id,
swap_node_ids,
gpu_recv_block_ids,
cpu_recv_block_ids,
match_cpu_block_ids,
):
"""
将cpu cache转移到GPU
"""
transfer_task_id = req_id
need_transfer_task_gpu_block_ids = []
need_transfer_task_cpu_block_ids = []
for tmp_gpu_block_id in gpu_recv_block_ids:
need_transfer_task_gpu_block_ids.append(tmp_gpu_block_id)
for tmp_cpu_block_id in match_cpu_block_ids:
need_transfer_task_cpu_block_ids.append(tmp_cpu_block_id)
assert len(need_transfer_task_gpu_block_ids) == len(need_transfer_task_cpu_block_ids)
logger.info(f"request_block_ids: req_id {req_id} issue_swap_task transfer_task_id {transfer_task_id}")
self.issue_swap_task(
transfer_task_id,
swap_node_ids,
need_transfer_task_gpu_block_ids,
need_transfer_task_cpu_block_ids,
CacheStatus.SWAP2GPU,
True,
)
def _prepare_cache(
self,
req_id,
input_ids,
block_size,
expected_block_num,
match_gpu_block_ids,
match_cpu_block_ids,
match_node_ids,
):
"""
prepare cache for request
"""
match_gpu_blocks_num = len(match_gpu_block_ids)
match_cpu_blocks_num = len(match_cpu_block_ids)
matched_block_num = match_gpu_blocks_num + match_cpu_blocks_num
cpu_recv_block_ids = []
gpu_recv_block_ids = []
gpu_extra_block_ids = []
# allocate gpu cache for matched cpu blocks
if match_cpu_blocks_num > 0:
gpu_recv_block_ids = self.allocate_gpu_blocks(match_cpu_blocks_num)
# allocate gpu cache
gpu_extra_block_num = expected_block_num - matched_block_num
if gpu_extra_block_num > 0:
gpu_extra_block_ids = self.allocate_gpu_blocks(gpu_extra_block_num)
if len(gpu_recv_block_ids) > 0:
self._prepare_cpu_cache(
req_id,
match_node_ids,
gpu_recv_block_ids,
cpu_recv_block_ids,
match_cpu_block_ids,
)
return gpu_recv_block_ids, gpu_extra_block_ids
def get_required_block_num(self, input_token_num, block_size):
"""
get required block num by input token num and block size
"""
return (input_token_num + block_size - 1) // block_size
def update_cache_blocks(self, task, block_size, num_computed_tokens):
"""
update cache blocks for a task.
# TODO(chengyanfu): support async update
Parameters:
- task: Task
- block_size: Size per block (in tokens)
"""
try:
req_id = task.request_id
block_tables = task.block_tables
last_node, num_cached_tokens = self.cache_info[req_id]
if isinstance(task.prompt_token_ids, np.ndarray):
prompt_token_ids = task.prompt_token_ids.tolist()
else:
prompt_token_ids = task.prompt_token_ids
input_ids = prompt_token_ids + task.output_token_ids
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
left_input_ids = input_ids[num_cached_tokens:can_cache_computed_tokens]
gpu_extra_block_ids = block_tables[num_cached_tokens // block_size :]
if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later
self.leaf_req_map[last_node].remove(req_id)
with self.request_release_lock:
current_time = time.time()
leaf_node = self.build_path(
req_id=req_id,
current_time=current_time,
input_ids=input_ids,
left_input_ids=left_input_ids,
gpu_block_ids=gpu_extra_block_ids,
block_size=block_size,
last_node=last_node,
reverved_dec_block_num=0,
)
self.req_leaf_map[req_id] = leaf_node
self.leaf_req_map[leaf_node].add(req_id)
self.cache_info[req_id] = (leaf_node, can_cache_computed_tokens)
task.cached_block_num = can_cache_computed_tokens // block_size
except Exception as e:
logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
def request_match_blocks(self, task, block_size, *args):
"""
get match blocks info for a task.
This is a synchronous interface. If CPU-to-GPU data transfer occurs,
it will block until synchronization completes.
Callers requiring asynchronous behavior should invoke this via a thread pool.
Note: This function may allocate GPU blocks for matched CPU Cache
Parameters:
- task: Task dictionary
- block_size: Size per block (in tokens)
Returns:
- common_block_ids: List of matched shared blocks
- unique_block_ids: List of exclusively allocated blocks
"""
with self.request_release_lock:
try:
hit_info = {}
hit_info["gpu_cache_blocks"] = 0
hit_info["cpu_cache_blocks"] = 0
self.metrics.req_count += 1
if isinstance(task.prompt_token_ids, np.ndarray):
prompt_token_ids = task.prompt_token_ids.tolist()
else:
prompt_token_ids = task.prompt_token_ids
input_ids = prompt_token_ids + task.output_token_ids
req_id = task.request_id
logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
input_token_num = len(input_ids)
common_block_ids = []
# 1. match block
(
match_gpu_block_ids,
match_cpu_block_ids,
swap_node_ids,
match_block_node,
gpu_match_token_num,
cpu_match_token_num,
) = self.match_block(req_id, input_ids, block_size)
# update matched node info
self._update_matched_node_info(req_id, match_block_node, current_time=time.time())
# 2. prepare cache
# allocate gpu cache for matched cpu blocks
gpu_recv_block_ids = []
match_cpu_blocks_num = len(match_cpu_block_ids)
if self.can_allocate_gpu_blocks(num_blocks=match_cpu_blocks_num):
if match_cpu_blocks_num > 0:
gpu_recv_block_ids = self.allocate_gpu_blocks(match_cpu_blocks_num)
if len(gpu_recv_block_ids) > 0:
self._prepare_cpu_cache(
req_id=req_id,
swap_node_ids=swap_node_ids,
gpu_recv_block_ids=gpu_recv_block_ids,
match_cpu_block_ids=match_cpu_block_ids,
cpu_recv_block_ids=[],
)
else:
raise Exception(
"request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache"
)
# 3. update metrics
matched_token_num = gpu_match_token_num + cpu_match_token_num
common_block_ids = match_gpu_block_ids + gpu_recv_block_ids
if matched_token_num > 0:
self.metrics.hit_req_count += 1
self.metrics.calculate_hit_metrics(
req_id,
cpu_match_token_num,
gpu_match_token_num,
input_token_num,
)
hit_info["gpu_cache_blocks"] = gpu_match_token_num // block_size
hit_info["cpu_cache_blocks"] = cpu_match_token_num // block_size
self.metrics._update_history_hit_metrics()
if self.metrics.req_count % 10000 == 0:
self.metrics.reset_metrics()
logger.info(
f"request_match_blocks: request block for req_id {req_id}: common_block_ids {common_block_ids}"
)
# set leaf node temporarily, then update it in update_cache_blocks
self.req_leaf_map[req_id] = match_block_node
self.leaf_req_map[match_block_node].add(req_id)
# record request cache info
self.cache_info[req_id] = (match_block_node, matched_token_num)
task.cached_block_num = matched_token_num // block_size
return common_block_ids, matched_token_num, hit_info
except Exception as e:
logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
raise e
def request_block_ids(self, task, block_size, dec_token_num, *args):
"""
Allocate blocks for a task.
This is a synchronous interface. If CPU-to-GPU data transfer occurs,
it will block until synchronization completes.
Callers requiring asynchronous behavior should invoke this via a thread pool.
Parameters:
- task: Task dictionary
- block_size: Size per block (in tokens)
- dec_token_num: Number of tokens reserved for decoding on the server side
Returns:
- common_block_ids: List of matched shared blocks
- unique_block_ids: List of exclusively allocated blocks
"""
with self.request_release_lock:
try:
hit_info = {}
hit_info["gpu_cache_blocks"] = 0
hit_info["cpu_cache_blocks"] = 0
self.metrics.req_count += 1
input_ids = task.prompt_token_ids
req_id = task.request_id
logger.info(f"request_block_ids: start to allocate blocks for req_id {req_id}")
input_token_num = len(input_ids)
common_block_ids = []
unique_block_ids = []
# 1. match block
(
match_gpu_block_ids,
match_cpu_block_ids,
swap_node_ids,
match_block_node,
gpu_match_token_num,
cpu_match_token_num,
) = self.match_block(req_id, input_ids, block_size)
match_gpu_blocks_num = len(match_gpu_block_ids)
matched_token_num_in_cpu_and_gpu = gpu_match_token_num + cpu_match_token_num
# check enough gpu memory to allocate cache
block_num = (input_token_num + block_size - 1 + dec_token_num) // block_size
self._check_validity(req_id, match_gpu_blocks_num, block_num)
# update matched node info
current_time = time.time()
self._update_matched_node_info(req_id, match_block_node, current_time)
# 2. prepare cache
(
gpu_recv_block_ids,
gpu_extra_block_ids,
) = self._prepare_cache(
req_id,
input_ids,
block_size,
block_num,
match_gpu_block_ids,
match_cpu_block_ids,
swap_node_ids,
)
# update matched token num
matched_block_num = gpu_match_token_num + cpu_match_token_num
common_block_ids = match_gpu_block_ids + gpu_recv_block_ids
unique_block_ids = gpu_extra_block_ids
dec_block_num = dec_token_num // block_size
left_input_ids = input_ids[matched_token_num_in_cpu_and_gpu:] # 没在前缀树中的token
gpu_build_path_block_ids = []
gpu_build_path_block_ids = gpu_extra_block_ids
leaf_node = self.build_path(
req_id,
current_time,
input_ids,
left_input_ids,
gpu_build_path_block_ids,
block_size,
match_block_node,
dec_block_num,
)
self.req_leaf_map[req_id] = leaf_node
self.leaf_req_map[leaf_node].add(req_id)
# 3. update metrics
if matched_block_num > 0:
self.metrics.hit_req_count += 1
self.metrics.calculate_hit_metrics(
req_id,
cpu_match_token_num,
gpu_match_token_num,
input_token_num,
)
hit_info["gpu_cache_blocks"] = gpu_match_token_num // block_size
hit_info["cpu_cache_blocks"] = cpu_match_token_num // block_size
self.metrics._update_history_hit_metrics()
if self.metrics.req_count % 10000 == 0:
self.metrics.reset_metrics()
logger.info(
f"request_block_ids: request block for req_id {req_id}: common_block_ids "
+ f"{common_block_ids}, unique_block_ids {unique_block_ids}"
)
return common_block_ids, unique_block_ids, hit_info
except Exception as e:
logger.error(f"request_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
def release_block_ids_async(self, task):
"""
async release block ids
"""
return self.executor_pool.submit(self.release_block_ids, task)
def free_block_ids(self, need_block_num):
self.free_block_ids_async(need_block_num)
while (self.gpu_free_task_future is not None) and (not self.gpu_free_task_future.done()):
time.sleep(0.001)
def release_block_ids(self, task):
"""
release block ids
"""
with self.request_release_lock:
try:
req_id = task.request_id
leaf_node = self.req_leaf_map.pop(req_id)
if leaf_node in self.leaf_req_map:
self.leaf_req_map[leaf_node].remove(req_id)
if not (self.leaf_req_map[leaf_node]):
del self.leaf_req_map[leaf_node]
node = leaf_node
while node != self.radix_tree_root:
if req_id in node.req_id_set:
node.req_id_set.remove(req_id)
node.decrement_shared_count()
node = node.parent
if req_id in self.cache_info:
del self.cache_info[req_id]
logger.info(f"release_block_ids: req_id {req_id} leaf_node {leaf_node}")
if leaf_node == self.radix_tree_root:
self.recycle_gpu_blocks(self.unfilled_req_block_map[req_id])
del self.unfilled_req_block_map[req_id]
return
if leaf_node in self.gpu_lru_leaf_set:
return
if leaf_node.shared_count == 0 and leaf_node.is_gpu_leaf_node and leaf_node.is_persistent is False:
self.gpu_lru_leaf_set.add(leaf_node)
heapq.heappush(self.gpu_lru_leaf_heap, leaf_node)
logger.info(
f"release_block_ids: req_id {req_id} has been finished, "
+ f"current gpu_lru_leaf_heap length {len(self.gpu_lru_leaf_heap)}"
)
return
except Exception as e:
logger.error(f"release_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
def free_nodes_directly(self, node):
with self.request_release_lock:
try:
total_gpu_free_count = 0
while True:
if node in self.gpu_lru_leaf_heap:
self.gpu_lru_leaf_heap.remove(node)
self.gpu_lru_leaf_set.remove(node)
if node.shared_count == 0 and node.is_gpu_leaf_node: # 直接回收
self._handle_free_gpu_node_without_cpu(node)
logger.info(f"free_nodes_directly: node {node}")
total_gpu_free_count += 1
cur_node = node
node = node.parent
if cur_node.hash_value in node.children:
del node.children[cur_node.hash_value]
if not node.children:
if node in self.gpu_lru_leaf_set:
continue
if (
node != self.radix_tree_root
and node.shared_count == 0
and node.is_gpu_leaf_node
and node.is_persistent is False
):
heapq.heappush(self.gpu_lru_leaf_heap, node)
self.gpu_lru_leaf_set.add(node)
else:
break
else:
break
except Exception as e:
logger.error(f"free_nodes_directly: error: {type(e)} {e}")
raise e
def _handle_free_gpu_node_without_cpu(self, node):
"""
GPU node eviction
"""
node.cache_status = CacheStatus.CPU
self.node_id_pool.append(node.node_id)
if node.node_id in self.node_map:
del self.node_map[node.node_id]
logger.info(f"free_block_ids_async: free node {node}")
self.recycle_gpu_blocks(node.reverved_dec_block_ids)
node.reverved_dec_block_ids = []
self.recycle_gpu_blocks(node.block_id)
def _handle_free_gpu_node_with_cpu(
self,
node,
hash_value_input_ids_map,
hash_value_depth_map,
need_recycle_gpu_block_ids,
hash_value_gpu_block_ids_map,
hash_value_swap_node_ids_map,
):
"""
GPU node eviction in hierarchical cache layers
"""
self.recycle_gpu_blocks(node.reverved_dec_block_ids)
node.reverved_dec_block_ids = []
need_recycle_gpu_block_ids.append(node.block_id)
hash_value_gpu_block_ids_map[node.input_hash_value].append(node.block_id)
hash_value_swap_node_ids_map[node.input_hash_value].append(node.node_id)
def _evict_cache_async(
self,
future,
total_gpu_free_count,
hash_value_gpu_block_ids_map,
hash_value_block_ids_map,
hash_value_swap_node_ids_map,
hash_value_input_ids_map,
hash_value_depth_map,
):
"""
evict cache async (GPU --> CPU)
"""
if future is not None:
future.result()
transfer_task_id = str(uuid.uuid4())
swap_node_ids = []
need_transfer_task_gpu_block_ids = []
need_transfer_task_cpu_block_ids = []
cpu_block_ids = self.allocate_cpu_blocks(total_gpu_free_count)
for input_hash_value in hash_value_gpu_block_ids_map.keys():
need_transfer_task_gpu_block_ids.extend(reversed(hash_value_gpu_block_ids_map[input_hash_value]))
all_allocated_cpu_block_ids = []
for _ in reversed(hash_value_gpu_block_ids_map[input_hash_value]):
cpu_block_id_t = cpu_block_ids.pop(0)
all_allocated_cpu_block_ids.append(cpu_block_id_t)
need_transfer_task_cpu_block_ids.append(cpu_block_id_t)
swap_node_ids.extend(reversed(hash_value_swap_node_ids_map[input_hash_value]))
logger.info(
"free_block_ids_async: issue transfer task: "
+ f"transfer_task_id {transfer_task_id}: "
+ f"swap_node_ids {swap_node_ids} need_transfer_task_gpu_block_ids "
+ f"{need_transfer_task_gpu_block_ids}, need_transfer_task_cpu_block_ids "
+ f"{need_transfer_task_cpu_block_ids}, CacheStatus.SWAP2CPU"
)
self.issue_swap_task(
transfer_task_id,
swap_node_ids,
need_transfer_task_gpu_block_ids,
need_transfer_task_cpu_block_ids,
CacheStatus.SWAP2CPU,
True,
)
logger.info(
"free_block_ids_async: after free, " + f"len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
)
def free_block_ids_async(self, need_block_num):
"""
free block ids async
args
need_query_block_num: max number of gpu blocks to free
"""
with self.request_release_lock:
if self.gpu_free_task_future is not None:
if not self.gpu_free_task_future.done():
return
else:
self.gpu_free_task_future.result()
self.gpu_free_task_future = None
try:
need_recycle_gpu_block_ids = []
hash_value_input_ids_map = {}
hash_value_block_ids_map = defaultdict(list)
hash_value_depth_map = {}
hash_value_swap_node_ids_map = defaultdict(list)
hash_value_gpu_block_ids_map = defaultdict(list)
total_gpu_free_count = 0
while True:
if len(self.gpu_lru_leaf_heap) == 0:
break
if total_gpu_free_count >= need_block_num:
break
node = heapq.heappop(self.gpu_lru_leaf_heap)
self.gpu_lru_leaf_set.remove(node)
if (
not self.cache_config.enable_hierarchical_cache
or self.cache_config.num_cpu_blocks < need_block_num
):
if node.shared_count == 0 and node.is_gpu_leaf_node: # 直接回收
self._handle_free_gpu_node_without_cpu(node)
total_gpu_free_count += 1
cur_node = node
node = node.parent
if cur_node.hash_value in node.children:
del node.children[cur_node.hash_value]
if not node.children:
if node in self.gpu_lru_leaf_set:
continue
if (
node != self.radix_tree_root
and node.shared_count == 0
and node.is_gpu_leaf_node
and node.is_persistent is False
):
heapq.heappush(self.gpu_lru_leaf_heap, node)
self.gpu_lru_leaf_set.add(node)
else:
continue
else:
if node.shared_count == 0 and node.is_gpu_leaf_node:
node.cache_status = CacheStatus.SWAP2CPU
else:
continue
self._handle_free_gpu_node_with_cpu(
node,
hash_value_input_ids_map,
hash_value_depth_map,
need_recycle_gpu_block_ids,
hash_value_gpu_block_ids_map,
hash_value_swap_node_ids_map,
)
total_gpu_free_count += 1
node = node.parent
if node in self.gpu_lru_leaf_set:
continue
if (
node != self.radix_tree_root
and node.shared_count == 0
and node.is_gpu_leaf_node
and node.is_persistent is False
):
heapq.heappush(self.gpu_lru_leaf_heap, node)
self.gpu_lru_leaf_set.add(node)
# swap cache to cpu
if hash_value_gpu_block_ids_map:
cpu_free_future = None
if total_gpu_free_count > len(self.cpu_free_block_list):
cpu_free_count = total_gpu_free_count
if cpu_free_count < need_block_num:
cpu_free_count = need_block_num
cpu_free_future = self.free_cpu_executor_pool.submit(self.free_cpu_block_ids, cpu_free_count)
self.gpu_free_task_future = self.free_gpu_executor_pool.submit(
self._evict_cache_async,
cpu_free_future,
total_gpu_free_count,
hash_value_gpu_block_ids_map,
hash_value_block_ids_map,
hash_value_swap_node_ids_map,
hash_value_input_ids_map,
hash_value_depth_map,
)
else:
self.gpu_free_task_future = None
except Exception as e:
logger.error(f"free_block_ids_async: error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
def free_cpu_block_ids(self, need_block_num):
"""
Evict CPU blocks (at least need_block_num blocks)
Parameters:
- need_block_num: Number of CPU blocks required to evict
Returns:
- freed_block_num: Number of CPU blocks successfully evicted
"""
hash_value_block_ids_map = defaultdict(list)
total_cpu_free_count = 0
with self.request_release_lock:
while True:
if len(self.cpu_lru_leaf_heap) == 0:
break
if total_cpu_free_count >= need_block_num:
break
node = heapq.heappop(self.cpu_lru_leaf_heap)
self.cpu_lru_leaf_set.remove(node)
tmp_block_ids = []
if node.shared_count == 0 and node.cache_status == CacheStatus.CPU and node.is_cpu_leaf_node:
self.recycle_cpu_blocks(node.block_id)
hash_value_block_ids_map[node.input_hash_value].extend(reversed(tmp_block_ids))
logger.info(f"free_cpu_block_ids: free node {node}")
self.node_id_pool.append(node.node_id)
total_cpu_free_count += 1
if node.node_id in self.node_map:
del self.node_map[node.node_id]
cur_node = node
node = node.parent
if cur_node.hash_value in node.children:
del node.children[cur_node.hash_value]
if not node.children:
if node in self.cpu_lru_leaf_set:
continue
if (
node != self.radix_tree_root
and node.shared_count == 0
and node.is_cpu_leaf_node
and node.cache_status == CacheStatus.CPU
):
heapq.heappush(self.cpu_lru_leaf_heap, node)
self.cpu_lru_leaf_set.add(node)
logger.info(
"free_cpu_block_ids: after free, " + f"len(self.cpu_free_block_list) {len(self.cpu_free_block_list)}"
)
return total_cpu_free_count
def cal_block_hash(self, block):
"""
calculate hash value of a block
"""
return hash(tuple(block))
def match_block(self, req_id, input_ids, block_size):
"""
Args:
req_id: Task request ID
input_ids: Input token IDs
block_size: Size of each block
Returns:
match_gpu_block_ids: List of matched GPU block IDs
match_cpu_block_ids: List of matched CPU block IDs
swap_node_ids: List of node IDs requiring swap operations
match_block_node: Last matched node in the path
gpu_match_token_num: Number of tokens matched in GPU blocks
cpu_match_token_num: Number of tokens matched in CPU blocks
"""
total_token_num = len(input_ids)
current_match_node = self.radix_tree_root # 从根节点开始搜
match_gpu_block_ids = []
match_cpu_block_ids = []
match_node_ids = []
match_token_num = 0
cpu_match_token_num = 0
gpu_match_token_num = 0
swap_node_ids = []
matche_nodes = []
has_modified_gpu_lru_leaf_heap = False
has_modified_cpu_lru_leaf_heap = False
with self.cache_status_lock:
while match_token_num < total_token_num:
token_block = input_ids[match_token_num : match_token_num + block_size]
token_num = len(token_block)
if token_num != block_size:
break
hash_value = self.cal_block_hash(token_block)
if hash_value in current_match_node.children:
child = current_match_node.children[hash_value]
matche_nodes.append(child)
match_node_ids.append(child.node_id)
if child in self.gpu_lru_leaf_set:
self.gpu_lru_leaf_set.remove(child)
self.gpu_lru_leaf_heap.remove(child)
has_modified_gpu_lru_leaf_heap = True
elif child in self.cpu_lru_leaf_set:
self.cpu_lru_leaf_set.remove(child)
self.cpu_lru_leaf_heap.remove(child)
has_modified_cpu_lru_leaf_heap = True
if child.has_in_gpu:
match_gpu_block_ids.append(child.block_id)
gpu_match_token_num += block_size
else:
if child.cache_status == CacheStatus.SWAP2CPU:
logger.info(
f"match_block: req_id {req_id} matched node"
+ f" {child.node_id} which is being SWAP2CPU"
)
child.cache_status = CacheStatus.GPU
match_gpu_block_ids.append(child.block_id)
gpu_match_token_num += block_size
elif child.cache_status == CacheStatus.CPU:
child.cache_status = CacheStatus.SWAP2GPU
match_cpu_block_ids.append(child.block_id)
cpu_match_token_num += block_size
swap_node_ids.append(child.node_id)
match_token_num = match_token_num + block_size
current_match_node = child
else:
break
if has_modified_gpu_lru_leaf_heap:
heapq.heapify(self.gpu_lru_leaf_heap)
if has_modified_cpu_lru_leaf_heap:
heapq.heapify(self.cpu_lru_leaf_heap)
logger.info(f"match_block: req_id {req_id} matched nodes: {match_node_ids}")
return (
match_gpu_block_ids,
match_cpu_block_ids,
swap_node_ids,
current_match_node,
gpu_match_token_num,
cpu_match_token_num,
)
def _update_matched_node_info(self, req_id, last_node, current_time):
"""
Update the shared count and last used time of the matched nodes
"""
node = last_node
while node != self.radix_tree_root:
node.increment_shared_count()
node.last_used_time = current_time
node.req_id_set.add(req_id)
node = node.parent
def build_path(
self,
req_id,
current_time,
input_ids,
left_input_ids,
gpu_block_ids,
block_size,
last_node,
reverved_dec_block_num,
):
"""
Build path for blocks beyond the common prefix
Parameters:
- req_id: Request ID of the task
- left_input_ids: Remaining input tokens not found in the prefix tree
- gpu_block_ids: List of available GPU block IDs for new node allocation
- block_size: Token capacity per block
- last_node: Last successfully matched node
- reserved_dec_block_num: Number of blocks reserved for decoding
Returns:
- leaf_node: The constructed leaf node
"""
gpu_block_ids = gpu_block_ids.copy()
node = last_node
reverved_dec_block_ids = []
input_hash_value = self.cal_block_hash(input_ids)
token_num = len(left_input_ids)
if token_num == 0:
for i in range(reverved_dec_block_num):
reverved_dec_block_ids.append(gpu_block_ids.pop(0))
last_node.reverved_dec_block_ids.extend(reverved_dec_block_ids)
return last_node
node = last_node
unique_node_ids = []
new_last_node = last_node
has_unfilled_block = False
for i in range(0, token_num, block_size):
current_block = left_input_ids[i : i + block_size]
current_block_size = len(current_block) # 最后一个block可能没填满
if current_block_size != block_size:
has_unfilled_block = True
else:
hash_value = self.cal_block_hash(current_block)
allocated_block_id = gpu_block_ids.pop(0)
node_id = self.node_id_pool.pop()
unique_node_ids.append(node_id)
new_last_node = BlockNode(
node_id,
input_ids,
input_hash_value,
node.depth + 1,
allocated_block_id,
current_block_size,
hash_value,
current_time,
parent=node,
shared_count=1,
reverved_dec_block_ids=[],
)
new_last_node.req_id_set.add(req_id)
self.node_map[node_id] = new_last_node
node.children[hash_value] = new_last_node
node = new_last_node
if has_unfilled_block is True:
reverved_dec_block_ids.append(gpu_block_ids.pop(0))
for i in range(reverved_dec_block_num):
reverved_dec_block_ids.append(gpu_block_ids.pop(0))
if new_last_node == self.radix_tree_root:
self.unfilled_req_block_map[req_id] = reverved_dec_block_ids
else:
new_last_node.reverved_dec_block_ids.extend(reverved_dec_block_ids)
logger.info(f"build_path: allocate unique node ids {unique_node_ids} for req_id {req_id}")
return new_last_node
def _handle_swap_result(self, swap_node_id, task_gpu_block_id, task_cpu_block_id, event_type):
"""
handle swap resuha
"""
if swap_node_id is None:
return
with self.cache_status_lock:
if event_type.value == CacheStatus.SWAP2CPU.value:
gpu_block_id = task_gpu_block_id
cpu_block_id = task_cpu_block_id
node = self.node_map[swap_node_id]
if node.cache_status.value == CacheStatus.GPU.value:
logger.info(
f"recv_data_transfer_result: node {node.node_id} "
+ f"has been reused when SWAP2CPU, recycle cpu block id {cpu_block_id}"
)
self.recycle_cpu_blocks(cpu_block_id)
else:
node.cache_status = CacheStatus.CPU
node.block_id = cpu_block_id
if (
node != self.radix_tree_root
and node.shared_count == 0
and node.is_cpu_leaf_node
and node.cache_status == CacheStatus.CPU
):
if node not in self.cpu_lru_leaf_set:
heapq.heappush(self.cpu_lru_leaf_heap, node)
self.cpu_lru_leaf_set.add(node)
self.recycle_gpu_blocks(gpu_block_id)
logger.info(f"recv_data_transfer_result: after SWAP2CPU, node {node}")
elif event_type.value == CacheStatus.SWAP2GPU.value:
gpu_block_id = task_gpu_block_id
cpu_block_id = task_cpu_block_id
node = self.node_map[swap_node_id]
node.cache_status = CacheStatus.GPU
node.block_id = gpu_block_id
self.recycle_cpu_blocks(cpu_block_id)
logger.info(f"recv_data_transfer_result: after SWAP2GPU, node {node}")
else:
logger.warning(
f"recv_data_transfer_result: Get unexpected event type {event_type}"
+ ", only SWAP2CPU and SWAP2GPU supported"
)
def recv_data_transfer_result(self):
"""
recv data transfer result
"""
while True:
try:
data = self.cache_task_queue.get_transfer_done_signal()
if data is None:
time.sleep(0.001)
continue
(
swap_node_ids,
task_gpu_block_id,
task_cpu_block_id,
event_type,
transfer_task_id,
) = data
length = len(task_gpu_block_id)
for i in range(length):
self._handle_swap_result(
swap_node_ids[i],
task_gpu_block_id[i],
task_cpu_block_id[i],
event_type,
)
if transfer_task_id in self.task_swapping_event:
self.task_swapping_event[transfer_task_id].set()
logger.info(
f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: "
+ f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} "
+ f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
)
except Exception as e:
logger.warning(f"recv_data_transfer_result: error: {e}, {str(traceback.format_exc())}")
raise e