From 22cab724e8db8842907f5ab889817eb1a1aab207 Mon Sep 17 00:00:00 2001
From: kevin
Date: Thu, 31 Jul 2025 19:29:19 +0800
Subject: [PATCH] [Feature] block scheduler v1 support prefix caching (#3061)

* block scheduler v1 support prefix cache

* update code

* update code

* fix code bug

* add timeout time

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
---
 .../cache_manager/prefix_cache_manager.py     | 137 +++++++++++++++++-
 fastdeploy/engine/engine.py                   |   2 +
 .../engine/sched/resource_manager_v1.py       |  64 +++++++-
 scripts/run_ci.sh                             |   2 +-
 4 files changed, 198 insertions(+), 7 deletions(-)

diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py
index dd191c87f..0ac34ad6a 100644
--- a/fastdeploy/cache_manager/prefix_cache_manager.py
+++ b/fastdeploy/cache_manager/prefix_cache_manager.py
@@ -93,6 +93,7 @@ class PrefixCacheManager:
         self.req_leaf_map = {}  # {request_id: leaf node}
         self.leaf_req_map = defaultdict(set)
         self.unfilled_req_block_map = defaultdict(list)
+        self.cache_info = {}

         self.executor_pool = ThreadPoolExecutor(max_workers=1)
         self.free_gpu_executor_pool = ThreadPoolExecutor(max_workers=1)
@@ -425,6 +426,135 @@ class PrefixCacheManager:

         return gpu_recv_block_ids, gpu_extra_block_ids

+    def get_required_block_num(self, input_token_num, block_size):
+        """
+        Get the required block num for the given input token num and block size.
+        """
+        return (input_token_num + block_size - 1) // block_size
+
+    def update_cache_blocks(self, task, block_size):
+        """
+        Update cache blocks for a task.
+        # TODO(chengyanfu): support async update
+
+        Parameters:
+        - task: Task object
+        - block_size: Size per block (in tokens)
+        """
+        try:
+            req_id = task.request_id
+            num_cached_tokens = task.num_cached_tokens
+            block_tables = task.block_tables
+
+            last_node, input_ids = self.cache_info[req_id]
+            left_input_ids = input_ids[num_cached_tokens:]
+            gpu_extra_block_ids = block_tables[num_cached_tokens // block_size :]
+
+            with self.request_release_lock:
+                current_time = time.time()
+                leaf_node = self.build_path(
+                    req_id=req_id,
+                    current_time=current_time,
+                    input_ids=input_ids,
+                    left_input_ids=left_input_ids,
+                    gpu_block_ids=gpu_extra_block_ids,
+                    block_size=block_size,
+                    last_node=last_node,
+                    reverved_dec_block_num=0,
+                )
+                self.req_leaf_map[req_id] = leaf_node
+                self.leaf_req_map[leaf_node].add(req_id)
+                self.cache_info[req_id] = (leaf_node, input_ids)
+        except Exception as e:
+            logger.error(f"update_cache_blocks, error: {type(e)} {e}")
+            raise e
+
+    def request_match_blocks(self, task, block_size, *args):
+        """
+        Get matched block info for a task.
+        This is a synchronous interface. If CPU-to-GPU data transfer occurs,
+        it will block until synchronization completes.
+        Callers requiring asynchronous behavior should invoke this via a thread pool.
+
+        Note: this function may allocate GPU blocks for matched CPU cache.
+
+        Parameters:
+        - task: Task object
+        - block_size: Size per block (in tokens)
+
+        Returns:
+        - common_block_ids: List of matched shared blocks
+        - matched_token_num: Number of prompt tokens served from the cache
+        - hit_info: Numbers of GPU and CPU cache blocks hit
+        """
+        with self.request_release_lock:
+            try:
+                hit_info = {}
+                hit_info["gpu_cache_blocks"] = 0
+                hit_info["cpu_cache_blocks"] = 0
+                self.metrics.req_count += 1
+                input_ids = task.prompt_token_ids
+                req_id = task.request_id
+                logger.info(f"request_match_blocks: start to match blocks for req_id {req_id}")
+                input_token_num = len(input_ids)
+                common_block_ids = []
+                # 1. match block
+                (
+                    match_gpu_block_ids,
+                    match_cpu_block_ids,
+                    swap_node_ids,
+                    match_block_node,
+                    gpu_match_token_num,
+                    cpu_match_token_num,
+                ) = self.match_block(req_id, input_ids, block_size)
+
+                # update matched node info
+                self._update_matched_node_info(req_id, match_block_node, current_time=time.time())
+
+                # 2. prepare cache
+                # allocate gpu cache for matched cpu blocks
+                gpu_recv_block_ids = []
+                match_cpu_blocks_num = len(match_cpu_block_ids)
+                if self.can_allocate_gpu_blocks(num_blocks=match_cpu_blocks_num):
+                    if match_cpu_blocks_num > 0:
+                        gpu_recv_block_ids = self.allocate_gpu_blocks(match_cpu_blocks_num)
+                    if len(gpu_recv_block_ids) > 0:
+                        self._prepare_cpu_cache(
+                            req_id=req_id,
+                            swap_node_ids=swap_node_ids,
+                            gpu_recv_block_ids=gpu_recv_block_ids,
+                            match_cpu_block_ids=match_cpu_block_ids,
+                            cpu_recv_block_ids=[],
+                        )
+                else:
+                    raise Exception("Not enough GPU memory to allocate cache for matched CPU cache")
+
+                # record request cache info
+                self.cache_info[req_id] = (match_block_node, input_ids)
+
+                # 3. update metrics
+                matched_token_num = gpu_match_token_num + cpu_match_token_num
+                common_block_ids = match_gpu_block_ids + gpu_recv_block_ids
+                if matched_token_num > 0:
+                    self.metrics.hit_req_count += 1
+                    self.metrics.calculate_hit_metrics(
+                        req_id,
+                        cpu_match_token_num,
+                        gpu_match_token_num,
+                        input_token_num,
+                    )
+                hit_info["gpu_cache_blocks"] = gpu_match_token_num // block_size
+                hit_info["cpu_cache_blocks"] = cpu_match_token_num // block_size
+                self.metrics._update_history_hit_metrics()
+                if self.metrics.req_count % 10000 == 0:
+                    self.metrics.reset_metrics()
+                logger.info(
+                    f"request_match_blocks: matched blocks for req_id {req_id}: common_block_ids {common_block_ids}"
+                )
+                return common_block_ids, matched_token_num, hit_info
+            except Exception as e:
+                logger.error(f"request_match_blocks: error: {type(e)} {e}")
+                raise e
+
     def request_block_ids(self, task, block_size, dec_token_num, *args):
         """
         Allocate blocks for a task.
@@ -463,12 +593,10 @@ class PrefixCacheManager:
                 cpu_match_token_num,
             ) = self.match_block(req_id, input_ids, block_size)
             match_gpu_blocks_num = len(match_gpu_block_ids)
-            match_cpu_blocks_num = len(match_cpu_block_ids)
-            matched_block_num = match_gpu_blocks_num + match_cpu_blocks_num
             matched_token_num_in_cpu_and_gpu = gpu_match_token_num + cpu_match_token_num
             # check enough gpu memory to allocate cache
             block_num = (input_token_num + block_size - 1 + dec_token_num) // block_size
-            self._check_validity(req_id, matched_block_num, block_num)
+            self._check_validity(req_id, match_gpu_blocks_num, block_num)
             # update matched node info
             current_time = time.time()
             self._update_matched_node_info(req_id, match_block_node, current_time)
@@ -557,6 +685,9 @@ class PrefixCacheManager:
                     node.decrement_shared_count()
                     node = node.parent

+            if req_id in self.cache_info:
+                del self.cache_info[req_id]
+
            logger.info(f"release_block_ids: req_id {req_id} leaf_node {leaf_node}")

            if leaf_node == self.radix_tree_root:
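The block accounting above is easy to sanity-check in isolation: get_required_block_num is plain ceiling division, and request_match_blocks can report hits in whole blocks because matched token counts are block-aligned in the radix tree. Below is a minimal standalone sketch, not part of the patch and not the FastDeploy API, that mirrors the patch's names for illustration only.

# Standalone sketch of the block math used in prefix_cache_manager.py above.

def get_required_block_num(input_token_num: int, block_size: int) -> int:
    # Ceiling division: a partially filled block still occupies a whole block.
    return (input_token_num + block_size - 1) // block_size

def hit_blocks(gpu_match_token_num: int, cpu_match_token_num: int, block_size: int) -> dict:
    # Matched token counts are block-aligned, so integer division
    # recovers the number of whole cached blocks on each device.
    return {
        "gpu_cache_blocks": gpu_match_token_num // block_size,
        "cpu_cache_blocks": cpu_match_token_num // block_size,
    }

if __name__ == "__main__":
    block_size = 64
    assert get_required_block_num(1, block_size) == 1    # 1 token still needs 1 block
    assert get_required_block_num(64, block_size) == 1   # exactly one block
    assert get_required_block_num(65, block_size) == 2   # spills into a second block
    print(hit_blocks(gpu_match_token_num=128, cpu_match_token_num=64, block_size=block_size))
    # -> {'gpu_cache_blocks': 2, 'cpu_cache_blocks': 1}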
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 4fd075d4b..999d6b056 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -373,6 +373,8 @@ class LLMEngine:
                 int(self.resource_manager.available_batch()),
                 self.cfg.max_prefill_batch,
             )
+
+            self.resource_manager.check_and_free_block_tables()
             tasks = self.scheduler.get_requests(
                 available_blocks=self.resource_manager.available_block_num(),
                 block_size=self.cfg.cache_config.block_size,
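The two added lines change ordering rather than any single call: block tables held by finished requests are reclaimed before the scheduler reads available_block_num(), so the capacity snapshot passed to get_requests is not stale. A toy illustration of that ordering follows; the class and counters are hypothetical, not the real LLMEngine or ResourceManager API.

# Toy model of "reclaim before reporting capacity".

class ToyResourceManager:
    def __init__(self, total_blocks: int):
        self.free_blocks = total_blocks
        self.pending_release = 0  # blocks still held by finished requests

    def check_and_free_block_tables(self):
        # Reclaim blocks from finished requests before reporting capacity.
        self.free_blocks += self.pending_release
        self.pending_release = 0

    def available_block_num(self) -> int:
        return self.free_blocks

rm = ToyResourceManager(total_blocks=100)
rm.free_blocks, rm.pending_release = 40, 25

rm.check_and_free_block_tables()  # reclaim first...
print(rm.available_block_num())   # ...so the scheduler sees 65, not a stale 40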
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 051e985db..764e71de7 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -80,6 +80,7 @@ class ResourceManagerV1(ResourceManager):
         return len(request.block_tables) * self.config.cache_config.block_size

     def get_new_block_nums(self, request: Request, num_new_tokens: int):
+        self.check_and_free_block_tables()
         return (
             request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
         ) // self.config.cache_config.block_size - len(request.block_tables)
@@ -103,6 +104,7 @@ class ResourceManagerV1(ResourceManager):
                 preempted_req = self.running.pop()
                 preempted_req.status = RequestStatus.PREEMPTED
                 preempted_req.num_computed_tokens = 0
+                preempted_req.prefill_block_num = 0
                 self._free_blocks(preempted_req)
                 self.waiting.appendleft(preempted_req)
                 preempted_reqs.append(preempted_req)
@@ -212,6 +214,14 @@ class ResourceManagerV1(ResourceManager):
                     if request.num_computed_tokens >= request.need_prefill_tokens:  # to be decoding
                         if request.num_total_tokens > request.need_prefill_tokens:  # has generated tokens
                             request.num_computed_tokens = request.num_total_tokens - 1
+                        else:  # prefill finished
+                            if (
+                                self.config.cache_config.enable_prefix_caching
+                                and request.get("prefill_block_num", None) is None
+                            ):
+                                # update prefill cache blocks for prefix caching
+                                request.prefill_block_num = len(request.block_tables)
+                                self.cache_manager.update_cache_blocks(request, self.config.cache_config.block_size)
                         if (
                             self.allocated_slots(request) - request.num_total_tokens
                             <= self.config.cache_config.prealloc_dec_block_slot_num_threshold
@@ -271,11 +281,18 @@ class ResourceManagerV1(ResourceManager):
                    break
                request = self.waiting[0]
                if request.status == RequestStatus.WAITING:
+                    # Enable prefix caching
+                    if self.config.cache_config.enable_prefix_caching:
+                        success = self.get_prefix_cached_blocks(request)
+                        if not success:
+                            break
+
                    num_new_tokens = self._get_num_new_tokens(request, token_budget)
                    num_new_block = self.get_new_block_nums(request, num_new_tokens)
                    # Allocate blocks to prefill
                    if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
-                        request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
+                        if not request.get("skip_allocate", False):
+                            request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
                        self.waiting.popleft()
                        self.running.append(request)
                        scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
@@ -292,7 +309,9 @@ class ResourceManagerV1(ResourceManager):
                    else:
                        break
                elif request.status == RequestStatus.PREEMPTED:
-                    request.need_prefill_tokens = request.num_total_tokens  # Before the preempted task is rescheduled it has already been sent to the engine and no more tokens are output, so num_total_tokens is static and correct here
+                    request.need_prefill_tokens = (
+                        request.num_total_tokens
+                    )  # Before the preempted task is rescheduled it has already been sent to the engine and no more tokens are output, so num_total_tokens is static and correct here
                    num_new_tokens = self._get_num_new_tokens(request, token_budget)
                    num_new_block = self.get_new_block_nums(request, num_new_tokens)
                    # Allocate blocks to prefill
@@ -327,12 +346,51 @@ class ResourceManagerV1(ResourceManager):
                    break
        return self.real_bsz

+    def get_prefix_cached_blocks(self, request: Request):
+        """
+        Match cached blocks for the given request and record its prefix-cache information.
+        """
+        try:
+            cache_prepare_time = time.time()
+            (common_block_ids, matched_token_num, hit_info) = self.cache_manager.request_match_blocks(
+                request, self.config.cache_config.block_size
+            )
+
+            matched_block_num = len(common_block_ids)
+            no_cache_block_num = self.cache_manager.get_required_block_num(
+                request.prompt_token_ids_len - matched_token_num,
+                self.config.cache_config.block_size,
+            )
+
+            request.num_cached_tokens = matched_token_num
+            request.gpu_cache_token_num = hit_info["gpu_cache_blocks"] * self.config.cache_config.block_size
+            request.cpu_cache_token_num = hit_info["cpu_cache_blocks"] * self.config.cache_config.block_size
+            request.cache_info = (matched_block_num, no_cache_block_num)
+            request.block_tables = common_block_ids
+            request.skip_allocate = False
+
+            if matched_token_num == request.prompt_token_ids_len:
+                # Full prefix hit: recompute the last prompt token so one
+                # forward step still runs, and skip fresh block allocation.
+                request.num_computed_tokens = matched_token_num - 1
+                request.skip_allocate = True
+            else:
+                request.num_computed_tokens = matched_token_num
+            request.cache_prepare_time = time.time() - cache_prepare_time
+            return True
+        except Exception as e:
+            llm_logger.error(f"prefix match blocks error: {e}, waiting reschedule...")
+            return False
+
     def add_request(self, request: Request) -> None:
         self.waiting.append(request)
         self.requests[request.request_id] = request

     def _free_blocks(self, request: Request):
-        self.cache_manager.recycle_gpu_blocks(request.block_tables)
+        if self.config.cache_config.enable_prefix_caching:
+            # TODO(chengyanfu): support caching output blocks for prefix caching
+            self.cache_manager.release_block_ids_async(request)
+            self.cache_manager.recycle_gpu_blocks(request.block_tables[request.prefill_block_num :])
+        else:
+            self.cache_manager.recycle_gpu_blocks(request.block_tables)
         request.block_tables = []

     def finish_requests_async(self, request_ids: Union[str, Iterable[str]]):
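Two decisions in this file are worth seeing in miniature: on a full prefix hit, get_prefix_cached_blocks backs num_computed_tokens off by one so the model still runs a forward step to emit the first new token, and _free_blocks now keeps the prefilled prefix blocks in the radix tree while recycling only the decode-time tail. The sketch below mirrors both decisions with hypothetical standalone functions; it is an illustration, not the ResourceManagerV1 methods.

# Standalone sketches of the scheduling decisions above (hypothetical names).

def plan_prefill(prompt_len: int, matched_token_num: int) -> dict:
    """Mirror of the full-hit edge case in get_prefix_cached_blocks."""
    if matched_token_num == prompt_len:
        # Every prompt token is cached; recompute the last one so the model
        # still performs a forward step, and skip fresh block allocation.
        return {"num_computed_tokens": matched_token_num - 1, "skip_allocate": True}
    return {"num_computed_tokens": matched_token_num, "skip_allocate": False}

def split_blocks_on_free(block_tables: list, prefill_block_num: int) -> tuple:
    """Mirror of _free_blocks with prefix caching enabled."""
    kept_in_cache = block_tables[:prefill_block_num]  # stays in the radix tree
    recycled = block_tables[prefill_block_num:]       # decode blocks, freed now
    return kept_in_cache, recycled

print(plan_prefill(prompt_len=256, matched_token_num=256))
# -> {'num_computed_tokens': 255, 'skip_allocate': True}
print(split_blocks_on_free([10, 11, 12, 13, 14], prefill_block_num=3))
# -> ([10, 11, 12], [13, 14])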
"------------------------------------------------------------" set +e - timeout 360 python -m pytest --disable-warnings -sv "$file" + timeout 600 python -m pytest --disable-warnings -sv "$file" exit_code=$? set -e