[optimize] Optimize prefix caching in v1 release/2.1 (#3823)

* [optimize] Optimize prefix caching in v1

* [optimize] Optimize prefix caching in v1

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
chenjian
2025-09-04 19:25:02 +08:00
committed by GitHub
parent c2f5c99b1e
commit ffec66097c
3 changed files with 49 additions and 41 deletions

View File

@@ -256,7 +256,11 @@ class PrefixCacheManager:
Check if num_blocks gpu blocks can be allocated.
"""
if len(self.gpu_free_block_list) < num_blocks:
return False
self.free_block_ids(num_blocks)
if len(self.gpu_free_block_list) < num_blocks:
return False
else:
return True
else:
return True
@@ -447,7 +451,7 @@ class PrefixCacheManager:
"""
return (input_token_num + block_size - 1) // block_size
def update_cache_blocks(self, task, block_size):
def update_cache_blocks(self, task, block_size, num_computed_tokens):
"""
update cache blocks for a task.
# TODO(chengyanfu): support async update
@@ -458,12 +462,15 @@ class PrefixCacheManager:
"""
try:
req_id = task.request_id
num_cached_tokens = task.num_cached_tokens
block_tables = task.block_tables
last_node, input_ids = self.cache_info[req_id]
left_input_ids = input_ids[num_cached_tokens:]
last_node, num_cached_tokens = self.cache_info[req_id]
input_ids = task.prompt_token_ids + task.output_token_ids
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
left_input_ids = input_ids[num_cached_tokens:can_cache_computed_tokens]
gpu_extra_block_ids = block_tables[num_cached_tokens // block_size :]
if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later
self.leaf_req_map[last_node].remove(req_id)
with self.request_release_lock:
current_time = time.time()
@@ -479,7 +486,8 @@ class PrefixCacheManager:
)
self.req_leaf_map[req_id] = leaf_node
self.leaf_req_map[leaf_node].add(req_id)
self.cache_info[req_id] = (leaf_node, input_ids)
self.cache_info[req_id] = (leaf_node, can_cache_computed_tokens)
task.cached_block_num = can_cache_computed_tokens // block_size
except Exception as e:
logger.error(f"update_cache_blocks, error: {type(e)} {e}")
raise e
@@ -541,10 +549,9 @@ class PrefixCacheManager:
cpu_recv_block_ids=[],
)
else:
raise Exception("request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache")
# record request cache info
self.cache_info[req_id] = (match_block_node, input_ids)
raise Exception(
"request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache"
)
# 3. update metrics
matched_token_num = gpu_match_token_num + cpu_match_token_num
@@ -568,6 +575,9 @@ class PrefixCacheManager:
# set leaf node temporarily, then update it in update_cache_blocks
self.req_leaf_map[req_id] = match_block_node
self.leaf_req_map[match_block_node].add(req_id)
# record request cache info
self.cache_info[req_id] = (match_block_node, matched_token_num)
task.cached_block_num = matched_token_num // block_size
return common_block_ids, matched_token_num, hit_info
except Exception as e:
logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
@@ -726,6 +736,7 @@ class PrefixCacheManager:
except Exception as e:
logger.error(f"release_block_ids: error: {type(e)} {e}")
raise e
def free_nodes_directly(self, node):
"""
Recycle nodes by a query directly.
@@ -848,6 +859,11 @@ class PrefixCacheManager:
"free_block_ids_async: after free, " + f"len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
)
def free_block_ids(self, need_block_num):
    """
    Synchronously reclaim GPU cache blocks.

    Delegates the actual reclamation to ``free_block_ids_async`` and then
    blocks the calling thread, polling every 1 ms until the background
    free task (``self.gpu_free_task_future``) reports done.

    Args:
        need_block_num: number of GPU blocks the caller needs freed;
            passed through to ``free_block_ids_async``.

    NOTE(review): assumes ``free_block_ids_async`` sets
    ``self.gpu_free_task_future`` (or leaves it ``None`` when there is
    nothing to free) — confirm against that method's implementation.
    """
    self.free_block_ids_async(need_block_num)
    # Busy-wait with a short sleep rather than future.result() so a
    # None future (no pending free task) falls through immediately.
    while (self.gpu_free_task_future is not None) and (not self.gpu_free_task_future.done()):
        time.sleep(0.001)
def free_block_ids_async(self, need_block_num):
"""
free block ids async
@@ -1106,15 +1122,6 @@ class PrefixCacheManager:
node.last_used_time = current_time
node.req_id_set.add(req_id)
node = node.parent
def decrease_request_share_count(self, req_id):
    """
    Decrease the shared count of every cache node used by a request.

    Walks the radix tree from the leaf node recorded for ``req_id`` in
    ``self.cache_info`` up to (but excluding) ``self.radix_tree_root``,
    decrementing each node's shared counter along the way.

    Args:
        req_id: request identifier; must be a key in ``self.cache_info``
            (raises KeyError otherwise).
    """
    # cache_info stores (leaf_node, input_ids); input_ids is not needed here.
    node, input_ids = self.cache_info[req_id]
    while node != self.radix_tree_root:
        node.decrement_shared_count()
        node = node.parent
def build_path(
self,