[Optimize] optimize prefix cache in release22 (#3889)

* optimize prefix cache in release22

* optimize prefix cache in release22

* fix worker

* fix

* fix

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Author: chenjian
Date: 2025-09-06 09:52:01 +08:00
Committed by: GitHub
Parent: 41cd3e24c9
Commit: 8d77c1cb51
4 changed files with 44 additions and 44 deletions


@@ -257,7 +257,12 @@ class PrefixCacheManager:
         Check if num_blocks gpu blocks can be allocated.
         """
         if len(self.gpu_free_block_list) < num_blocks:
-            return False
+            if self.cache_config.enable_prefix_caching:
+                self.free_block_ids(num_blocks)
+            if len(self.gpu_free_block_list) < num_blocks:
+                return False
+            else:
+                return True
         else:
             return True
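
The rewritten check no longer fails immediately when the free list is short: with prefix caching enabled it first tries to reclaim cached blocks through the new free_block_ids helper and then re-checks. A self-contained sketch of that fallback pattern (FreeListManager and its reclaimer callback are simplified stand-ins, not the real PrefixCacheManager API):

# Simplified stand-in for the allocation check above; not the FastDeploy class.
class FreeListManager:
    def __init__(self, free_blocks, prefix_caching_enabled, reclaimer):
        self.gpu_free_block_list = list(free_blocks)
        self.prefix_caching_enabled = prefix_caching_enabled
        self._reclaimer = reclaimer  # callable returning block ids evicted from the prefix cache

    def can_allocate_gpu_blocks(self, num_blocks):
        if len(self.gpu_free_block_list) < num_blocks:
            if self.prefix_caching_enabled:
                # try to evict reusable cache blocks before giving up
                self.gpu_free_block_list.extend(self._reclaimer(num_blocks))
            return len(self.gpu_free_block_list) >= num_blocks
        return True

manager = FreeListManager(free_blocks=[1, 2], prefix_caching_enabled=True, reclaimer=lambda n: [3, 4])
print(manager.can_allocate_gpu_blocks(3))  # True: two free blocks plus two reclaimed ones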
@@ -448,7 +453,7 @@ class PrefixCacheManager:
         """
         return (input_token_num + block_size - 1) // block_size
 
-    def update_cache_blocks(self, task, block_size):
+    def update_cache_blocks(self, task, block_size, num_computed_tokens):
         """
         update cache blocks for a task.
         # TODO(chengyanfu): support async update
@@ -459,12 +464,15 @@ class PrefixCacheManager:
         """
         try:
             req_id = task.request_id
-            num_cached_tokens = task.num_cached_tokens
             block_tables = task.block_tables
-            last_node, input_ids = self.cache_info[req_id]
-            left_input_ids = input_ids[num_cached_tokens:]
+            last_node, num_cached_tokens = self.cache_info[req_id]
+            input_ids = task.prompt_token_ids + task.output_token_ids
+            can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
+            left_input_ids = input_ids[num_cached_tokens:can_cache_computed_tokens]
             gpu_extra_block_ids = block_tables[num_cached_tokens // block_size :]
+            if req_id in self.leaf_req_map[last_node]:  # delete old leaf record, update later
+                self.leaf_req_map[last_node].remove(req_id)
 
             with self.request_release_lock:
                 current_time = time.time()
@@ -480,7 +488,8 @@ class PrefixCacheManager:
) )
self.req_leaf_map[req_id] = leaf_node self.req_leaf_map[req_id] = leaf_node
self.leaf_req_map[leaf_node].add(req_id) self.leaf_req_map[leaf_node].add(req_id)
self.cache_info[req_id] = (leaf_node, input_ids) self.cache_info[req_id] = (leaf_node, can_cache_computed_tokens)
task.cached_block_num = can_cache_computed_tokens // block_size
except Exception as e: except Exception as e:
logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}") logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e raise e
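
With the new signature, only tokens up to the last full block boundary of num_computed_tokens are inserted into the radix tree, and cache_info now stores that boundary instead of the raw input ids. A quick worked example of the boundary arithmetic (block size and token counts below are arbitrary):

# Arbitrary example values, not taken from the commit.
block_size = 64
num_computed_tokens = 200
num_cached_tokens = 64  # prefix already present in the radix tree

can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
cached_block_num = can_cache_computed_tokens // block_size

print(can_cache_computed_tokens)  # 192 -> only whole blocks are cached
print((num_cached_tokens, can_cache_computed_tokens))  # (64, 192) -> span of newly cached tokens
print(cached_block_num)  # 3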
@@ -508,7 +517,7 @@ class PrefixCacheManager:
         hit_info["gpu_cache_blocks"] = 0
         hit_info["cpu_cache_blocks"] = 0
         self.metrics.req_count += 1
-        input_ids = task.prompt_token_ids
+        input_ids = task.prompt_token_ids + task.output_token_ids
         req_id = task.request_id
         logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
         input_token_num = len(input_ids)
@@ -546,9 +555,6 @@ class PrefixCacheManager:
                     "request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache"
                 )
 
-            # record request cache info
-            self.cache_info[req_id] = (match_block_node, input_ids)
-
             # 3. update metrics
             matched_token_num = gpu_match_token_num + cpu_match_token_num
             common_block_ids = match_gpu_block_ids + gpu_recv_block_ids
@@ -571,6 +577,9 @@ class PrefixCacheManager:
             # set leaf node temporarily, then update it in update_cache_blocks
             self.req_leaf_map[req_id] = match_block_node
             self.leaf_req_map[match_block_node].add(req_id)
+            # record request cache info
+            self.cache_info[req_id] = (match_block_node, matched_token_num)
+            task.cached_block_num = matched_token_num // block_size
             return common_block_ids, matched_token_num, hit_info
         except Exception as e:
             logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
@@ -687,6 +696,11 @@ class PrefixCacheManager:
         """
         return self.executor_pool.submit(self.release_block_ids, task)
 
+    def free_block_ids(self, need_block_num):
+        self.free_block_ids_async(need_block_num)
+        while (self.gpu_free_task_future is not None) and (not self.gpu_free_task_future.done()):
+            time.sleep(0.001)
+
     def release_block_ids(self, task):
         """
         release block ids
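
free_block_ids is a blocking wrapper: it kicks off the existing asynchronous free path and then polls the pending future roughly every millisecond until it completes. The same wait-on-future pattern in isolation (a generic sketch using concurrent.futures, not the FastDeploy executor):

import concurrent.futures
import time

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

def free_async(need_block_num):
    time.sleep(0.01)  # placeholder for the real block-recycling work
    return need_block_num

def free_blocking(need_block_num):
    future = executor.submit(free_async, need_block_num)
    while not future.done():  # short-sleep polling, as in the method above
        time.sleep(0.001)
    return future.result()

print(free_blocking(8))  # 8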
@@ -1108,15 +1122,6 @@ class PrefixCacheManager:
                 node.req_id_set.add(req_id)
                 node = node.parent
 
-    def decrease_request_share_count(self, req_id):
-        """
-        Decrease node shared count
-        """
-        node, input_ids = self.cache_info[req_id]
-        while node != self.radix_tree_root:
-            node.decrement_shared_count()
-            node = node.parent
-
     def build_path(
         self,
         req_id,


@@ -527,9 +527,8 @@ class EngineSevice:
                 self.cfg.max_prefill_batch,
             )
 
-            self.resource_manager.check_and_free_block_tables()
             tasks = self.scheduler.get_requests(
-                available_blocks=self.resource_manager.available_block_num(),
+                available_blocks=self.cfg.cache_config.max_block_num_per_seq,
                 block_size=self.cfg.cache_config.block_size,
                 reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
                 max_num_batched_tokens=self.cfg.max_model_len,


@@ -84,7 +84,6 @@ class ResourceManagerV1(ResourceManager):
         return len(request.block_tables) * self.config.cache_config.block_size
 
     def get_new_block_nums(self, request: Request, num_new_tokens: int):
-        self.check_and_free_block_tables()
         return (
             request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
         ) // self.config.cache_config.block_size - len(request.block_tables)
@@ -119,7 +118,7 @@ class ResourceManagerV1(ResourceManager):
                     preempted_req.status = RequestStatus.PREEMPTED
                     preempted_req.num_computed_tokens = 0
                     self._free_blocks(preempted_req)
-                    preempted_req.prefill_block_num = None
+                    preempted_req.cached_block_num = 0
                     self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
                     preempted_reqs.append(preempted_req)
                     scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
@@ -282,14 +281,6 @@ class ResourceManagerV1(ResourceManager):
                 if request.num_computed_tokens >= request.need_prefill_tokens:  # to be decoding
                     if request.num_total_tokens > request.need_prefill_tokens:  # has generated tokens
                         request.num_computed_tokens = request.num_total_tokens - 1
-                    else:  # prefill finished
-                        if (
-                            self.config.cache_config.enable_prefix_caching
-                            and request.get("prefill_block_num", None) is None
-                        ):
-                            # update prefill cache blocks for prefix caching
-                            request.prefill_block_num = len(request.block_tables)
-                            self.cache_manager.update_cache_blocks(request, self.config.cache_config.block_size)
                     if (
                         self.allocated_slots(request) - request.num_total_tokens
                         <= self.config.cache_config.prealloc_dec_block_slot_num_threshold
@@ -339,6 +330,10 @@ class ResourceManagerV1(ResourceManager):
                     scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                     token_budget -= num_new_tokens
                     request.num_computed_tokens += num_new_tokens
+                    if self.config.cache_config.enable_prefix_caching:
+                        self.cache_manager.update_cache_blocks(
+                            request, self.config.cache_config.block_size, request.num_computed_tokens
+                        )
             req_index += 1
         # schedule the WAITING requests.
         if not preempted_reqs:
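
Because update_cache_blocks is now driven by the running num_computed_tokens, the cacheable prefix grows step by step during chunked prefill instead of being recorded only once after prefill finishes. An illustration with made-up chunk sizes:

# Made-up chunk sizes showing how the cached block count advances per scheduling step.
block_size = 64
num_computed_tokens = 0
for chunk in (96, 96, 60):
    num_computed_tokens += chunk
    cached_block_num = (num_computed_tokens - num_computed_tokens % block_size) // block_size
    print(num_computed_tokens, cached_block_num)
# 96 1
# 192 3
# 252 3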
@@ -371,6 +366,10 @@ class ResourceManagerV1(ResourceManager):
                        request.schedule_start_time = time.time()
                    token_budget -= num_new_tokens
                    request.num_computed_tokens += num_new_tokens
+                   if self.config.cache_config.enable_prefix_caching:
+                       self.cache_manager.update_cache_blocks(
+                           request, self.config.cache_config.block_size, request.num_computed_tokens
+                       )
                    request.status = RequestStatus.RUNNING
                    main_process_metrics.num_requests_waiting.dec(1)
                    main_process_metrics.num_requests_running.inc(1)
@@ -403,6 +402,10 @@ class ResourceManagerV1(ResourceManager):
                        scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                        token_budget -= num_new_tokens
                        request.num_computed_tokens += num_new_tokens
+                       if self.config.cache_config.enable_prefix_caching:
+                           self.cache_manager.update_cache_blocks(
+                               request, self.config.cache_config.block_size, request.num_computed_tokens
+                           )
                        request.status = RequestStatus.RUNNING
                        main_process_metrics.num_requests_waiting.dec(1)
                        main_process_metrics.num_requests_running.inc(1)
@@ -447,7 +450,7 @@ class ResourceManagerV1(ResourceManager):
             matched_block_num = len(common_block_ids)
             no_cache_block_num = self.cache_manager.get_required_block_num(
-                request.prompt_token_ids_len - matched_token_num,
+                request.need_prefill_tokens - matched_token_num,
                 self.config.cache_config.block_size,
             )
@@ -463,7 +466,7 @@ class ResourceManagerV1(ResourceManager):
             main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num)
             main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)
 
-            if matched_token_num == request.prompt_token_ids_len:
+            if matched_token_num == request.need_prefill_tokens:
                 request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
                 request.skip_allocate = True
             else:
@@ -481,16 +484,8 @@ class ResourceManagerV1(ResourceManager):
     def _free_blocks(self, request: Request):
         if self.config.cache_config.enable_prefix_caching:
-            # TODO(chengyanfu): support cache ouput blocks for prefix caching
-            if request.get("prefill_block_num", None) is None:
-                leaf_node = self.cache_manager.req_leaf_map[request.request_id]
-                self.cache_manager.decrease_request_share_count(request.request_id)
-                self.cache_manager.free_nodes_directly(leaf_node)
-                self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cache_info[0] :])
-            else:
-                self.cache_manager.release_block_ids_async(request)
-                self.cache_manager.recycle_gpu_blocks(request.block_tables[request.prefill_block_num :])
+            self.cache_manager.release_block_ids(request)
+            self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cached_block_num :])
         else:
             self.cache_manager.recycle_gpu_blocks(request.block_tables)
         request.block_tables = []
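
_free_blocks now leaves the first cached_block_num entries of the block table under the cache manager's control and recycles only the tail, mirroring the cached_block_num bookkeeping added in update_cache_blocks. A small slice illustration with arbitrary block ids:

# Arbitrary block ids; only the tail beyond the cached prefix goes back to the free list.
block_tables = [10, 11, 12, 13, 14]
cached_block_num = 3

to_recycle = block_tables[cached_block_num:]
print(to_recycle)  # [13, 14]; blocks 10-12 stay under prefix-cache management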


@@ -1347,6 +1347,7 @@ class GPUModelRunner(ModelRunnerBase):
         if (
             not self.cache_config.enable_chunked_prefill
             or self.guided_backend is None
+            or model_forward_batch is None
             or envs.ENABLE_V1_KVCACHE_SCHEDULER
         ):
             return skip_idx_list
@@ -1549,7 +1550,7 @@ class GPUModelRunner(ModelRunnerBase):
         """
         Add cache for guided decoding.
         """
-        if self.guided_backend is None:
+        if self.guided_backend is None or model_forward_batch is None:
             return
         for request in model_forward_batch: