[Optimize] optimize prefix cache in develop (#3890)

* optimize prefix cache in release22

* fix

* fix

* fix

* add ci for v1

* add unit test

---------

Co-authored-by: xiegegege <46314656+xiegegege@users.noreply.github.com>
Author: chenjian
Date: 2025-09-12 10:15:59 +08:00
Committed by: GitHub
Parent: 4859f40b20
Commit: 37f1632732
5 changed files with 271 additions and 43 deletions


@@ -98,7 +98,6 @@ class ResourceManagerV1(ResourceManager):
         return len(request.block_tables) * self.config.cache_config.block_size
 
     def get_new_block_nums(self, request: Request, num_new_tokens: int):
-        self.check_and_free_block_tables()
         block_num = (
             request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
         ) // self.config.cache_config.block_size - len(request.block_tables)
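
Note: get_new_block_nums is a ceiling division over the tokens the request will have computed after this step, minus the blocks it already holds. A minimal standalone sketch of that arithmetic (the free function and its names are illustrative, not the repo's API):

    def new_block_nums(num_computed_tokens: int, num_new_tokens: int,
                       block_size: int, num_allocated_blocks: int) -> int:
        """Extra blocks needed so computed + newly scheduled tokens all fit."""
        total_tokens = num_computed_tokens + num_new_tokens
        total_blocks = (total_tokens + block_size - 1) // block_size  # ceil division
        return total_blocks - num_allocated_blocks

    assert new_block_nums(0, 100, 64, 0) == 2    # 100 tokens -> two 64-token blocks
    assert new_block_nums(100, 28, 64, 2) == 0   # 128 tokens still fit in two blocks
    assert new_block_nums(100, 29, 64, 2) == 1   # the 129th token needs a third block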
@@ -137,7 +136,7 @@ class ResourceManagerV1(ResourceManager):
                 preempted_req.status = RequestStatus.PREEMPTED
                 preempted_req.num_computed_tokens = 0
                 self._free_blocks(preempted_req)
-                preempted_req.prefill_block_num = None
+                preempted_req.cached_block_num = 0
                 self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
                 preempted_reqs.append(preempted_req)
                 scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
@@ -300,14 +299,6 @@ class ResourceManagerV1(ResourceManager):
             if request.num_computed_tokens >= request.need_prefill_tokens:  # to be decoding
                 if request.num_total_tokens > request.need_prefill_tokens:  # has generated tokens
                     request.num_computed_tokens = request.num_total_tokens - 1
-                else:  # prefill finished
-                    if (
-                        self.config.cache_config.enable_prefix_caching
-                        and request.get("prefill_block_num", None) is None
-                    ):
-                        # update prefill cache blocks for prefix caching
-                        request.prefill_block_num = len(request.block_tables)
-                        self.cache_manager.update_cache_blocks(request, self.config.cache_config.block_size)
                 if (
                     self.allocated_slots(request) - request.num_total_tokens
                     <= self.config.cache_config.prealloc_dec_block_slot_num_threshold
@@ -357,6 +348,10 @@ class ResourceManagerV1(ResourceManager):
                     scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                     token_budget -= num_new_tokens
                     request.num_computed_tokens += num_new_tokens
+                    if self.config.cache_config.enable_prefix_caching:
+                        self.cache_manager.update_cache_blocks(
+                            request, self.config.cache_config.block_size, request.num_computed_tokens
+                        )
                 req_index += 1
         # schedule the WAITING requests.
         if not preempted_reqs:
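
Note: the call added here replaces the one-shot update deleted in the earlier hunk: instead of publishing cache blocks once when prefill finishes, the scheduler now advances a per-request watermark after every scheduled prefill chunk. A hedged sketch of the bookkeeping this implies; the real update_cache_blocks in the cache manager may differ:

    class _Req:
        """Toy stand-in for Request; only the fields the sketch needs."""
        def __init__(self, block_tables):
            self.block_tables = block_tables
            self.cached_block_num = 0  # blocks already published to the prefix tree

    def update_cache_blocks(req: _Req, block_size: int, num_computed_tokens: int) -> list:
        """Publish newly filled blocks to the prefix cache and advance the watermark."""
        filled = num_computed_tokens // block_size        # only full blocks are cacheable
        new_blocks = req.block_tables[req.cached_block_num:filled]
        req.cached_block_num = filled                     # next call starts after these
        return new_blocks                                 # caller would insert these into the tree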
# schedule the WAITING requests.
if not preempted_reqs:
@@ -371,6 +366,15 @@ class ResourceManagerV1(ResourceManager):
                 if request.status == RequestStatus.WAITING:
                     # Enable prefix caching
                     if self.config.cache_config.enable_prefix_caching:
+                        if (
+                            self.config.cache_config.enable_hierarchical_cache
+                            and self.cache_manager.num_cpu_blocks > 0
+                        ):
+                            if not self.cache_manager.can_allocate_gpu_blocks(
+                                (request.need_prefill_tokens + self.config.cache_config.block_size - 1)
+                                // self.config.cache_config.block_size
+                            ):  # avoid allocating blocks for hierarchical-cache matching, which could deadlock
+                                break
                         success = self.get_prefix_cached_blocks(request)
                         if not success:
                             self._free_blocks(request)
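
Note: the guard added above refuses to start hierarchical-cache matching unless the whole prompt could be held in GPU blocks, because matching pins GPU blocks for CPU-to-GPU transfers, and an over-committed pool would deadlock. The predicate, spelled out as a standalone sketch (hypothetical helper, not the repo's API):

    def safe_to_match(need_prefill_tokens: int, block_size: int, free_gpu_blocks: int) -> bool:
        """True if the full prompt fits in free GPU blocks, so prefix matching cannot stall."""
        needed = (need_prefill_tokens + block_size - 1) // block_size  # ceil division
        return free_gpu_blocks >= needed

    assert safe_to_match(129, 64, 3)        # needs 3 blocks, 3 free
    assert not safe_to_match(129, 64, 2)    # would pin blocks it cannot get -> skip matching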
@@ -389,6 +393,10 @@ class ResourceManagerV1(ResourceManager):
                         request.schedule_start_time = time.time()
                     token_budget -= num_new_tokens
                     request.num_computed_tokens += num_new_tokens
+                    if self.config.cache_config.enable_prefix_caching:
+                        self.cache_manager.update_cache_blocks(
+                            request, self.config.cache_config.block_size, request.num_computed_tokens
+                        )
                     request.status = RequestStatus.RUNNING
                     main_process_metrics.num_requests_waiting.dec(1)
                     main_process_metrics.num_requests_running.inc(1)
@@ -406,6 +414,15 @@ class ResourceManagerV1(ResourceManager):
                         request.num_total_tokens
                     )  # Before a preempted task is rescheduled it has already been sent to the engine and emits no more tokens, so num_total_tokens is static and correct here
                     if self.config.cache_config.enable_prefix_caching:
+                        if (
+                            self.config.cache_config.enable_hierarchical_cache
+                            and self.cache_manager.num_cpu_blocks > 0
+                        ):
+                            if not self.cache_manager.can_allocate_gpu_blocks(
+                                (request.need_prefill_tokens + self.config.cache_config.block_size - 1)
+                                // self.config.cache_config.block_size
+                            ):  # avoid allocating blocks for hierarchical-cache matching, which could deadlock
+                                break
                         success = self.get_prefix_cached_blocks(request)
                         if not success:
                             self._free_blocks(request)
@@ -421,6 +438,10 @@ class ResourceManagerV1(ResourceManager):
                     scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                     token_budget -= num_new_tokens
                     request.num_computed_tokens += num_new_tokens
+                    if self.config.cache_config.enable_prefix_caching:
+                        self.cache_manager.update_cache_blocks(
+                            request, self.config.cache_config.block_size, request.num_computed_tokens
+                        )
                     request.status = RequestStatus.RUNNING
                     main_process_metrics.num_requests_waiting.dec(1)
                     main_process_metrics.num_requests_running.inc(1)
@@ -516,7 +537,7 @@ class ResourceManagerV1(ResourceManager):
         matched_block_num = len(common_block_ids)
         no_cache_block_num = self.cache_manager.get_required_block_num(
-            request.prompt_token_ids_len - matched_token_num,
+            request.need_prefill_tokens - matched_token_num,
             self.config.cache_config.block_size,
         )
@@ -532,7 +553,7 @@ class ResourceManagerV1(ResourceManager):
         main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num)
         main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)
-        if matched_token_num == request.prompt_token_ids_len:
+        if matched_token_num == request.need_prefill_tokens:
             request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
             request.skip_allocate = True
         else:
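
Note: matching against need_prefill_tokens rather than prompt_token_ids_len also covers rescheduled (preempted) requests, whose prefill target includes already-generated tokens. On a full hit, the scheduler steps num_computed_tokens back by one block so the model still runs at least the last block and can produce next-token logits, and skip_allocate prevents allocating a fresh block for it. A hedged sketch of that adjustment:

    def adjust_full_match(matched: int, need_prefill: int, block_size: int):
        """On a full prefix hit, recompute the last block so a forward pass still happens."""
        if matched == need_prefill:            # every prefill token already cached
            return matched - block_size, True  # (num_computed_tokens, skip_allocate)
        return matched, False

    assert adjust_full_match(128, 128, 64) == (64, True)
    assert adjust_full_match(64, 128, 64) == (64, False)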
@@ -550,16 +571,8 @@ class ResourceManagerV1(ResourceManager):
     def _free_blocks(self, request: Request):
         if self.config.cache_config.enable_prefix_caching:
             # TODO(chengyanfu): support cache output blocks for prefix caching
-            if request.get("prefill_block_num", None) is None:
-                leaf_node = self.cache_manager.req_leaf_map[request.request_id]
-                self.cache_manager.decrease_request_share_count(request.request_id)
-                self.cache_manager.free_nodes_directly(leaf_node)
-                self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cache_info[0] :])
-            else:
-                self.cache_manager.release_block_ids_async(request)
-                self.cache_manager.recycle_gpu_blocks(request.block_tables[request.prefill_block_num :])
+            self.cache_manager.release_block_ids(request)
+            self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cached_block_num :])
         else:
             self.cache_manager.recycle_gpu_blocks(request.block_tables)
         request.block_tables = []
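
Note: with the incremental watermark, _free_blocks no longer needs the two-path logic keyed on prefill_block_num: blocks below cached_block_num are owned by the prefix tree (reference-counted there), and only the private tail is recycled directly. A toy model of that split, illustrative only:

    # Toy refcount model: prefix blocks are shared via the radix tree, the tail is private.
    shared_refcount = {7: 2, 8: 2}    # blocks 7 and 8 are cached and used by another request
    block_tables = [7, 8, 9, 10]      # this request's block table
    cached_block_num = 2              # watermark: first two blocks live in the prefix tree

    for b in block_tables[:cached_block_num]:
        shared_refcount[b] -= 1       # like release_block_ids: drop our reference only
    private_tail = block_tables[cached_block_num:]   # [9, 10] -> recycle_gpu_blocks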