[Bug fix] Fix prefix cache in V1 (#3715)

* [Bug fix] Fix prefix cache in V1

* fix code style
Author: chenjian
Date: 2025-08-31 21:29:33 +08:00
Committed by: GitHub
Parent: bed09ae8f8
Commit: 465065cd19

2 changed files with 76 additions and 9 deletions

File 1 of 2: PrefixCacheManager

@@ -510,7 +510,7 @@ class PrefixCacheManager:
                 self.metrics.req_count += 1
                 input_ids = task.prompt_token_ids
                 req_id = task.request_id
-                logger.info(f"request_block_ids: start to allocate blocks for req_id {req_id}")
+                logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
                 input_token_num = len(input_ids)
                 common_block_ids = []
                 # 1. match block
@@ -542,7 +542,9 @@ class PrefixCacheManager:
cpu_recv_block_ids=[], cpu_recv_block_ids=[],
) )
else: else:
raise Exception("Not enough GPU memory to allocate cache for matched CPU Cache") raise Exception(
"request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache"
)
# record request cache info # record request cache info
self.cache_info[req_id] = (match_block_node, input_ids) self.cache_info[req_id] = (match_block_node, input_ids)
@@ -564,11 +566,14 @@
                if self.metrics.req_count % 10000 == 0:
                    self.metrics.reset_metrics()
                logger.info(
-                    f"request_block_ids: request block for req_id {req_id}: common_block_ids {common_block_ids}"
+                    f"request_match_blocks: request block for req_id {req_id}: common_block_ids {common_block_ids}"
                )
+                # set leaf node temporarily, then update it in update_cache_blocks
+                self.req_leaf_map[req_id] = match_block_node
+                self.leaf_req_map[match_block_node].add(req_id)
                return common_block_ids, matched_token_num, hit_info
            except Exception as e:
-                logger.error(f"request_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
+                logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
                raise e

    def request_block_ids(self, task, block_size, dec_token_num, *args):
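Note: the two added map writes are what later lets _free_blocks find its way back into the radix tree for requests that matched cached blocks but were never actually prefilled. A minimal standalone sketch of this bookkeeping pattern (record_match, forget, and the plain dicts are illustrative stand-ins, not FastDeploy's API):

    from collections import defaultdict

    req_leaf_map = {}                 # request id -> deepest matched tree node
    leaf_req_map = defaultdict(set)   # tree node -> requests currently pinning it

    def record_match(req_id, match_node):
        # Temporary assignment: the real code replaces it in
        # update_cache_blocks once the request's own blocks are inserted.
        req_leaf_map[req_id] = match_node
        leaf_req_map[match_node].add(req_id)

    def forget(req_id):
        node = req_leaf_map.pop(req_id)
        leaf_req_map[node].discard(req_id)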
@@ -725,6 +730,41 @@
                logger.error(f"release_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
                raise e

+    def free_nodes_directly(self, node):
+        with self.request_release_lock:
+            try:
+                total_gpu_free_count = 0
+                while True:
+                    if node in self.gpu_lru_leaf_heap:
+                        self.gpu_lru_leaf_heap.remove(node)
+                        self.gpu_lru_leaf_set.remove(node)
+                    if node.shared_count == 0 and node.is_gpu_leaf_node:  # reclaim directly
+                        self._handle_free_gpu_node_without_cpu(node)
+                        logger.info(f"free_nodes_directly: node {node}")
+                        total_gpu_free_count += 1
+                        cur_node = node
+                        node = node.parent
+                        if cur_node.hash_value in node.children:
+                            del node.children[cur_node.hash_value]
+                        if not node.children:
+                            if node in self.gpu_lru_leaf_set:
+                                continue
+                            if (
+                                node != self.radix_tree_root
+                                and node.shared_count == 0
+                                and node.is_gpu_leaf_node
+                                and node.is_persistent is False
+                            ):
+                                heapq.heappush(self.gpu_lru_leaf_heap, node)
+                                self.gpu_lru_leaf_set.add(node)
+                            else:
+                                break
+                    else:
+                        break
+            except Exception as e:
+                logger.error(f"free_nodes_directly: error: {type(e)} {e}")
+                raise e
+
    def _handle_free_gpu_node_without_cpu(self, node):
        """
        GPU node eviction
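The upward walk in free_nodes_directly is easier to see on a toy tree. The sketch below keeps only the core pattern: free an unshared leaf, detach it from its parent, and repeat until a shared node or a node with other children stops the walk. Node here is a stand-in, not FastDeploy's BlockNode, and the LRU-heap bookkeeping is omitted:

    class Node:
        def __init__(self, name, parent=None, shared_count=0):
            self.name, self.parent, self.shared_count = name, parent, shared_count
            self.children = {}
            if parent is not None:
                parent.children[name] = self

    def free_chain(leaf):
        # Walk leaf -> root, reclaiming each node no request still shares.
        node = leaf
        while node.parent is not None and node.shared_count == 0 and not node.children:
            print(f"freeing {node.name}")
            del node.parent.children[node.name]
            node = node.parent

    root = Node("root")
    a = Node("a", root, shared_count=1)   # still shared: reclamation stops here
    b = Node("b", a)
    c = Node("c", b)
    free_chain(c)                         # frees c, then b, then stops at a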
@@ -1068,6 +1108,15 @@
                node.req_id_set.add(req_id)
            node = node.parent

+    def decrease_request_share_count(self, req_id):
+        """
+        Decrease node shared count
+        """
+        node, input_ids = self.cache_info[req_id]
+        while node != self.radix_tree_root:
+            node.decrement_shared_count()
+            node = node.parent
+
    def build_path(
        self,
        req_id,
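decrease_request_share_count is the counterpart to the share counts taken during matching: every node on a request's matched path holds one count per request, so releasing the request walks the same path back toward the root. Continuing the toy Node model from the sketch above:

    def decrease_share(leaf, root):
        # Mirrors the loop in decrease_request_share_count: decrement each
        # ancestor's count so the chain becomes reclaimable by free_chain.
        node = leaf
        while node is not root:
            node.shared_count -= 1
            node = node.parent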

File 2 of 2: ResourceManagerV1

@@ -118,8 +118,8 @@ class ResourceManagerV1(ResourceManager):
                preempted_req = self.running.pop()
                preempted_req.status = RequestStatus.PREEMPTED
                preempted_req.num_computed_tokens = 0
-                preempted_req.prefill_block_num = 0
                self._free_blocks(preempted_req)
+                preempted_req.prefill_block_num = None
                self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
                preempted_reqs.append(preempted_req)
                scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
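Note on the reordering above: _free_blocks slices the block table by prefill_block_num, so zeroing it before the call made the old code recycle every block in the table, including leading blocks still referenced by the prefix cache; resetting to None afterwards also routes the next release through the new direct-tree path. A sketch of just the slicing, assuming prefill_block_num counts the leading blocks owned by the prefix tree (values invented):

    blocks = ["b0", "b1", "b2", "b3"]      # b0, b1 still referenced by the cache
    prefill_block_num = 2

    old_recycled = blocks[0:]                   # reset to 0 first: all four recycled (the bug)
    new_recycled = blocks[prefill_block_num:]   # free first: only b2, b3 recycled
    assert new_recycled == ["b2", "b3"]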
@@ -336,6 +336,7 @@
                if self.config.cache_config.enable_prefix_caching:
                    success = self.get_prefix_cached_blocks(request)
                    if not success:
+                        self._free_blocks(request)
                        break

                num_new_tokens = self._get_num_new_tokens(request, token_budget)
@@ -358,16 +359,24 @@
                        self.stop_flags[allocated_position] = False
                        self.req_dict[request.request_id] = allocated_position
                    else:
+                        if self.config.cache_config.enable_prefix_caching:
+                            self._free_blocks(request)
                        break
                elif request.status == RequestStatus.PREEMPTED:
                    request.need_prefill_tokens = (
                        request.num_total_tokens
                    )  # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct
+                    if self.config.cache_config.enable_prefix_caching:
+                        success = self.get_prefix_cached_blocks(request)
+                        if not success:
+                            self._free_blocks(request)
+                            break
                    num_new_tokens = self._get_num_new_tokens(request, token_budget)
                    num_new_block = self.get_new_block_nums(request, num_new_tokens)
                    # Allocate blocks to prefill
                    if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
-                        request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
+                        if not request.get("skip_allocate", False):
+                            request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
                        self.waiting.popleft()
                        self.running.append(request)
                        scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
@@ -375,6 +384,8 @@
                        request.num_computed_tokens += num_new_tokens
                        request.status = RequestStatus.RUNNING
                    else:
+                        if self.config.cache_config.enable_prefix_caching:
+                            self._free_blocks(request)
                        break
                else:
                    llm_logger.error("Unknown request status type")
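The additions at @@ -336, @@ -358, and @@ -375 all close the same leak: once get_prefix_cached_blocks succeeds, the request holds share counts on the matched blocks, so every early exit out of scheduling must call _free_blocks or those blocks stay pinned forever. The guard pattern in isolation (the manager methods and signature here are illustrative, not the real scheduler API):

    def schedule_with_cleanup(request, manager, enable_prefix_caching):
        # Any failure after a successful cache match must release the
        # request's hold on the matched blocks before bailing out.
        if enable_prefix_caching:
            if not manager.get_prefix_cached_blocks(request):
                manager._free_blocks(request)      # roll back a partial match
                return False
        if not manager.can_allocate_gpu_blocks(request):
            if enable_prefix_caching:
                manager._free_blocks(request)      # roll back the successful match too
            return False
        return True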
@@ -430,7 +441,7 @@
            main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)

            if matched_token_num == request.prompt_token_ids_len:
-                request.num_computed_tokens = matched_token_num - 1
+                request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
                request.skip_allocate = True
            else:
                request.num_computed_tokens = matched_token_num
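On the full-hit change: with skip_allocate set, no fresh block is allocated for this request, so the tokens left "uncomputed" must fit inside blocks that were already matched. Rolling back a whole block_size keeps num_computed_tokens block-aligned, whereas the old matched_token_num - 1 left a partially-computed last block that still needed an allocation the skip_allocate path never performs. A worked example (block_size and prompt length are assumed values):

    block_size = 64
    matched = 128             # whole prompt found in cache

    old = matched - 1         # 127: not block-aligned, last block dangling
    new = matched - block_size  # 64: exactly one already-cached block is recomputed
    assert new % block_size == 0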
@@ -448,8 +459,15 @@
    def _free_blocks(self, request: Request):
        if self.config.cache_config.enable_prefix_caching:
            # TODO(chengyanfu): support cache ouput blocks for prefix caching
-            self.cache_manager.release_block_ids_async(request)
-            self.cache_manager.recycle_gpu_blocks(request.block_tables[request.prefill_block_num :])
+            if request.get("prefill_block_num", None) is None:
+                leaf_node = self.cache_manager.req_leaf_map[request.request_id]
+                self.cache_manager.decrease_request_share_count(request.request_id)
+                self.cache_manager.free_nodes_directly(leaf_node)
+                self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cache_info[0] :])
+            else:
+                self.cache_manager.release_block_ids_async(request)
+                self.cache_manager.recycle_gpu_blocks(request.block_tables[request.prefill_block_num :])
        else:
            self.cache_manager.recycle_gpu_blocks(request.block_tables)
        request.block_tables = []
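The rewritten _free_blocks distinguishes two lifecycles: prefill_block_num is only set once prefill blocks are actually allocated (and reset to None on preemption above), so None doubles as a marker that the request only ever matched cache blocks and must be unwound through the radix tree directly. A simplified restatement of the branch (bodies abbreviated, attribute names as in the diff):

    def free_blocks_sketch(request, cache_manager):
        if request.prefill_block_num is None:
            # Matched cached blocks but never prefilled: drop this request's
            # share counts, then reclaim the now-unshared leaf chain.
            leaf = cache_manager.req_leaf_map[request.request_id]
            cache_manager.decrease_request_share_count(request.request_id)
            cache_manager.free_nodes_directly(leaf)
        else:
            # Prefilled at least once: release matched blocks asynchronously
            # through the cache manager's normal path.
            cache_manager.release_block_ids_async(request)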