diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py
index 83f059ea1..a0b110bde 100644
--- a/fastdeploy/cache_manager/prefix_cache_manager.py
+++ b/fastdeploy/cache_manager/prefix_cache_manager.py
@@ -257,7 +257,12 @@ class PrefixCacheManager:
         Check if num_blocks gpu blocks can be allocated.
         """
         if len(self.gpu_free_block_list) < num_blocks:
-            return False
+            if self.cache_config.enable_prefix_caching:
+                self.free_block_ids(num_blocks)
+                if len(self.gpu_free_block_list) < num_blocks:
+                    return False
+                else:
+                    return True
         else:
             return True
 
@@ -448,7 +453,7 @@ class PrefixCacheManager:
         """
         return (input_token_num + block_size - 1) // block_size
 
-    def update_cache_blocks(self, task, block_size):
+    def update_cache_blocks(self, task, block_size, num_computed_tokens):
         """
         update cache blocks for a task.
         # TODO(chengyanfu): support async update
@@ -459,12 +464,19 @@ class PrefixCacheManager:
         """
         try:
             req_id = task.request_id
-            num_cached_tokens = task.num_cached_tokens
             block_tables = task.block_tables
 
-            last_node, input_ids = self.cache_info[req_id]
-            left_input_ids = input_ids[num_cached_tokens:]
+            last_node, num_cached_tokens = self.cache_info[req_id]
+            if isinstance(task.prompt_token_ids, np.ndarray):
+                prompt_token_ids = task.prompt_token_ids.tolist()
+            else:
+                prompt_token_ids = task.prompt_token_ids
+            input_ids = prompt_token_ids + task.output_token_ids
+            can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
+            left_input_ids = input_ids[num_cached_tokens:can_cache_computed_tokens]
             gpu_extra_block_ids = block_tables[num_cached_tokens // block_size :]
+            if req_id in self.leaf_req_map[last_node]:  # delete old leaf record, update later
+                self.leaf_req_map[last_node].remove(req_id)
 
             with self.request_release_lock:
                 current_time = time.time()
@@ -480,7 +492,8 @@ class PrefixCacheManager:
             )
             self.req_leaf_map[req_id] = leaf_node
             self.leaf_req_map[leaf_node].add(req_id)
-            self.cache_info[req_id] = (leaf_node, input_ids)
+            self.cache_info[req_id] = (leaf_node, can_cache_computed_tokens)
+            task.cached_block_num = can_cache_computed_tokens // block_size
         except Exception as e:
             logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
             raise e
@@ -508,7 +521,11 @@ class PrefixCacheManager:
         hit_info["gpu_cache_blocks"] = 0
         hit_info["cpu_cache_blocks"] = 0
         self.metrics.req_count += 1
-        input_ids = task.prompt_token_ids
+        if isinstance(task.prompt_token_ids, np.ndarray):
+            prompt_token_ids = task.prompt_token_ids.tolist()
+        else:
+            prompt_token_ids = task.prompt_token_ids
+        input_ids = prompt_token_ids + task.output_token_ids
         req_id = task.request_id
         logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
         input_token_num = len(input_ids)
@@ -546,9 +563,6 @@ class PrefixCacheManager:
                     "request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache"
                 )
 
-            # record request cache info
-            self.cache_info[req_id] = (match_block_node, input_ids)
-
            # 3. update metrics
            matched_token_num = gpu_match_token_num + cpu_match_token_num
            common_block_ids = match_gpu_block_ids + gpu_recv_block_ids
@@ -571,6 +585,9 @@ class PrefixCacheManager:
            # set leaf node temporarily, then update it in update_cache_blocks
            self.req_leaf_map[req_id] = match_block_node
            self.leaf_req_map[match_block_node].add(req_id)
+            # record request cache info
+            self.cache_info[req_id] = (match_block_node, matched_token_num)
+            task.cached_block_num = matched_token_num // block_size
            return common_block_ids, matched_token_num, hit_info
        except Exception as e:
            logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
            raise e
@@ -687,6 +704,11 @@ class PrefixCacheManager:
        """
        return self.executor_pool.submit(self.release_block_ids, task)
 
+    def free_block_ids(self, need_block_num):
+        self.free_block_ids_async(need_block_num)
+        while (self.gpu_free_task_future is not None) and (not self.gpu_free_task_future.done()):
+            time.sleep(0.001)
+
    def release_block_ids(self, task):
        """
        release block ids
@@ -1108,15 +1130,6 @@ class PrefixCacheManager:
                node.req_id_set.add(req_id)
                node = node.parent
 
-    def decrease_request_share_count(self, req_id):
-        """
-        Decrease node shared count
-        """
-        node, input_ids = self.cache_info[req_id]
-        while node != self.radix_tree_root:
-            node.decrement_shared_count()
-            node = node.parent
-
    def build_path(
        self,
        req_id,
diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py
index 929e093c4..ac1b92c9b 100644
--- a/fastdeploy/engine/common_engine.py
+++ b/fastdeploy/engine/common_engine.py
@@ -525,10 +525,13 @@ class EngineService:
                int(self.resource_manager.available_batch()),
                self.cfg.max_prefill_batch,
            )
+            if self.cfg.model_config.enable_mm:
+                available_blocks = self.resource_manager.available_block_num()
+            else:
+                available_blocks = self.cfg.cache_config.max_block_num_per_seq
 
-            self.resource_manager.check_and_free_block_tables()
            tasks = self.scheduler.get_requests(
-                available_blocks=self.resource_manager.available_block_num(),
+                available_blocks=available_blocks,
                block_size=self.cfg.cache_config.block_size,
                reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
                max_num_batched_tokens=self.cfg.max_model_len,
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 339f18f32..a6c7f355d 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -98,7 +98,6 @@ class ResourceManagerV1(ResourceManager):
        return len(request.block_tables) * self.config.cache_config.block_size
 
    def get_new_block_nums(self, request: Request, num_new_tokens: int):
-        self.check_and_free_block_tables()
        block_num = (
            request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
        ) // self.config.cache_config.block_size - len(request.block_tables)
@@ -137,7 +136,7 @@ class ResourceManagerV1(ResourceManager):
            preempted_req.status = RequestStatus.PREEMPTED
            preempted_req.num_computed_tokens = 0
            self._free_blocks(preempted_req)
-            preempted_req.prefill_block_num = None
+            preempted_req.cached_block_num = 0
            self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
            preempted_reqs.append(preempted_req)
            scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
@@ -300,14 +299,6 @@ class ResourceManagerV1(ResourceManager):
            if request.num_computed_tokens >= request.need_prefill_tokens:  # to be decoding
                if request.num_total_tokens > request.need_prefill_tokens:  # has generated tokens
                    request.num_computed_tokens = request.num_total_tokens - 1
-                else:  # prefill finished
-                    if (
-                        self.config.cache_config.enable_prefix_caching
-                        and request.get("prefill_block_num", None) is None
-                    ):
-                        # update prefill cache blocks for prefix caching
-                        request.prefill_block_num = len(request.block_tables)
-                        self.cache_manager.update_cache_blocks(request, self.config.cache_config.block_size)
                if (
                    self.allocated_slots(request) - request.num_total_tokens
                    <= self.config.cache_config.prealloc_dec_block_slot_num_threshold
@@ -357,6 +348,10 @@ class ResourceManagerV1(ResourceManager):
                scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                token_budget -= num_new_tokens
                request.num_computed_tokens += num_new_tokens
+                if self.config.cache_config.enable_prefix_caching:
+                    self.cache_manager.update_cache_blocks(
+                        request, self.config.cache_config.block_size, request.num_computed_tokens
+                    )
            req_index += 1
        # schedule the WAITING requests.
        if not preempted_reqs:
@@ -371,6 +366,15 @@ class ResourceManagerV1(ResourceManager):
                if request.status == RequestStatus.WAITING:
                    # Enable prefix caching
                    if self.config.cache_config.enable_prefix_caching:
+                        if (
+                            self.config.cache_config.enable_hierarchical_cache
+                            and self.cache_manager.num_cpu_blocks > 0
+                        ):
+                            if not self.cache_manager.can_allocate_gpu_blocks(
+                                (request.need_prefill_tokens + self.config.cache_config.block_size - 1)
+                                // self.config.cache_config.block_size
+                            ):  # prevent block allocation for hierarchical-cache matching from causing a deadlock
+                                break
                        success = self.get_prefix_cached_blocks(request)
                        if not success:
                            self._free_blocks(request)
@@ -389,6 +393,10 @@ class ResourceManagerV1(ResourceManager):
                            request.schedule_start_time = time.time()
                        token_budget -= num_new_tokens
                        request.num_computed_tokens += num_new_tokens
+                        if self.config.cache_config.enable_prefix_caching:
+                            self.cache_manager.update_cache_blocks(
+                                request, self.config.cache_config.block_size, request.num_computed_tokens
+                            )
                        request.status = RequestStatus.RUNNING
                        main_process_metrics.num_requests_waiting.dec(1)
                        main_process_metrics.num_requests_running.inc(1)
@@ -406,6 +414,15 @@ class ResourceManagerV1(ResourceManager):
                        request.num_total_tokens
                    )  # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct
                    if self.config.cache_config.enable_prefix_caching:
+                        if (
+                            self.config.cache_config.enable_hierarchical_cache
+                            and self.cache_manager.num_cpu_blocks > 0
+                        ):
+                            if not self.cache_manager.can_allocate_gpu_blocks(
+                                (request.need_prefill_tokens + self.config.cache_config.block_size - 1)
+                                // self.config.cache_config.block_size
+                            ):  # prevent block allocation for hierarchical-cache matching from causing a deadlock
+                                break
                        success = self.get_prefix_cached_blocks(request)
                        if not success:
                            self._free_blocks(request)
@@ -421,6 +438,10 @@ class ResourceManagerV1(ResourceManager):
                        scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                        token_budget -= num_new_tokens
                        request.num_computed_tokens += num_new_tokens
+                        if self.config.cache_config.enable_prefix_caching:
+                            self.cache_manager.update_cache_blocks(
+                                request, self.config.cache_config.block_size, request.num_computed_tokens
+                            )
                        request.status = RequestStatus.RUNNING
                        main_process_metrics.num_requests_waiting.dec(1)
                        main_process_metrics.num_requests_running.inc(1)
@@ -516,7 +537,7 @@ class ResourceManagerV1(ResourceManager):
 
        matched_block_num = len(common_block_ids)
        no_cache_block_num = self.cache_manager.get_required_block_num(
-            request.prompt_token_ids_len - matched_token_num,
+            request.need_prefill_tokens - matched_token_num,
            self.config.cache_config.block_size,
        )
 
@@ -532,7 +553,7 @@ class ResourceManagerV1(ResourceManager):
            main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num)
            main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)
 
-        if matched_token_num == request.prompt_token_ids_len:
+        if matched_token_num == request.need_prefill_tokens:
            request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
            request.skip_allocate = True
        else:
@@ -550,16 +571,8 @@ class ResourceManagerV1(ResourceManager):
 
    def _free_blocks(self, request: Request):
        if self.config.cache_config.enable_prefix_caching:
-            # TODO(chengyanfu): support cache output blocks for prefix caching
-            if request.get("prefill_block_num", None) is None:
-                leaf_node = self.cache_manager.req_leaf_map[request.request_id]
-                self.cache_manager.decrease_request_share_count(request.request_id)
-                self.cache_manager.free_nodes_directly(leaf_node)
-                self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cache_info[0] :])
-
-            else:
-                self.cache_manager.release_block_ids_async(request)
-                self.cache_manager.recycle_gpu_blocks(request.block_tables[request.prefill_block_num :])
+            self.cache_manager.release_block_ids(request)
+            self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cached_block_num :])
        else:
            self.cache_manager.recycle_gpu_blocks(request.block_tables)
        request.block_tables = []
diff --git a/tests/v1/test_prefix_cache.py b/tests/v1/test_prefix_cache.py
new file mode 100644
index 000000000..b2ded9018
--- /dev/null
+++ b/tests/v1/test_prefix_cache.py
@@ -0,0 +1,71 @@
+from dataclasses import asdict
+from types import SimpleNamespace
+
+from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
+from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig
+from fastdeploy.engine.args_utils import EngineArgs
+from fastdeploy.engine.request import Request
+
+
+def test_normal_case():
+    max_num_seqs = 3
+    block_size = 64
+    engine_args = EngineArgs(max_num_seqs=max_num_seqs, num_gpu_blocks_override=100, max_num_batched_tokens=3200)
+    args = asdict(engine_args)
+    cache_cfg = CacheConfig(args)
+    model_cfg = SimpleNamespace(enable_mm=False)
+    speculative_cfg = SimpleNamespace(method=None)
+    model_cfg.print = print
+    cache_cfg.bytes_per_layer_per_block = 1
+    parallel_cfg = ParallelConfig(args)
+    graph_opt_cfg = engine_args.create_graph_optimization_config()
+    fd_config = FDConfig(
+        model_config=model_cfg,
+        cache_config=cache_cfg,
+        parallel_config=parallel_cfg,
+        graph_opt_config=graph_opt_cfg,
+        speculative_config=speculative_cfg,
+        max_num_batched_tokens=engine_args.max_num_batched_tokens,
+    )
+    cache_manager = PrefixCacheManager(config=fd_config, tensor_parallel_size=8, splitwise_role="mixed")
+    req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
+    req2 = Request.from_dict(
+        {"request_id": "req2", "prompt_token_ids": [1] * 1600 + [2] * 1600, "prompt_token_ids_len": 3200}
+    )
+    req3 = Request.from_dict(
+        {"request_id": "req3", "prompt_token_ids": [1] * 1600 + [3] * 1600, "prompt_token_ids_len": 3200}
+    )
+    (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req1, block_size)
+    assert len(common_block_ids) == 0
+    assert matched_token_num == 0
+    assert len(cache_manager.gpu_free_block_list) == 100
+    req1.block_tables.extend(common_block_ids)
+    # allocate for req1 inputs
+    num_new_block = 50
+    req1.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
+    req1.num_computed_tokens += 50 * block_size
+    cache_manager.update_cache_blocks(req1, block_size, req1.num_computed_tokens)
+    assert len(cache_manager.gpu_free_block_list) == 50
+    # allocate for req2 inputs
+    (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req2, block_size)
+    assert len(common_block_ids) == 25
+    assert matched_token_num == 25 * block_size
+    req2.num_cached_tokens = matched_token_num
+    req2.num_computed_tokens = 25 * block_size
+    num_new_block = 25
+    req2.block_tables.extend(common_block_ids)
+    req2.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
+    cache_manager.update_cache_blocks(req2, block_size, req2.num_computed_tokens)
+    # allocate for req3 input
+    (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req3, block_size)
+    assert len(common_block_ids) == 25
+    assert matched_token_num == 25 * block_size
+    req3.num_cached_tokens = matched_token_num
+    req3.num_computed_tokens = 25 * block_size
+    assert len(cache_manager.gpu_free_block_list) == 25
+    req3.block_tables.extend(common_block_ids)
+    num_new_block = 25
+    assert cache_manager.can_allocate_gpu_blocks(num_new_block)
+    req3.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
+    cache_manager.update_cache_blocks(req3, block_size, req3.num_computed_tokens)
+    assert len(cache_manager.gpu_free_block_list) == 0
diff --git a/tests/v1/test_schedule_output.py b/tests/v1/test_schedule_output.py
new file mode 100644
index 000000000..ffe9432c3
--- /dev/null
+++ b/tests/v1/test_schedule_output.py
@@ -0,0 +1,128 @@
+from dataclasses import asdict
+from types import SimpleNamespace
+
+from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig
+from fastdeploy.engine.args_utils import EngineArgs
+from fastdeploy.engine.request import Request
+from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
+
+
+def test_normal_schedule():
+    max_num_seqs = 3
+    engine_args = EngineArgs(max_num_seqs=max_num_seqs, num_gpu_blocks_override=160, max_num_batched_tokens=3200)
+    args = asdict(engine_args)
+    cache_cfg = CacheConfig(args)
+    model_cfg = SimpleNamespace(enable_mm=False)
+    speculative_cfg = SimpleNamespace(method=None)
+    model_cfg.print = print
+    cache_cfg.bytes_per_layer_per_block = 1
+    parallel_cfg = ParallelConfig(args)
+    graph_opt_cfg = engine_args.create_graph_optimization_config()
+    fd_config = FDConfig(
+        model_config=model_cfg,
+        cache_config=cache_cfg,
+        parallel_config=parallel_cfg,
+        speculative_config=speculative_cfg,
+        graph_opt_config=graph_opt_cfg,
+        max_num_batched_tokens=engine_args.max_num_batched_tokens,
+    )
+    resource_manager_v1 = ResourceManagerV1(
+        max_num_seqs=max_num_seqs, config=fd_config, tensor_parallel_size=8, splitwise_role="mixed"
+    )
+    req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3199, "prompt_token_ids_len": 3199})
+    req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [1] * 3201, "prompt_token_ids_len": 3201})
+    req3 = Request.from_dict({"request_id": "req3", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
+    resource_manager_v1.add_request(req1)
+    resource_manager_v1.add_request(req2)
+    resource_manager_v1.add_request(req3)
+    # step 1
+    assert len(resource_manager_v1.waiting) == 3
+    scheduler_reqs = resource_manager_v1.schedule()
+    assert len(scheduler_reqs) == 2
+    assert scheduler_reqs[0].request_id == "req1"
+    assert scheduler_reqs[1].request_id == "req2"
+    assert scheduler_reqs[0].prefill_start_index == 0
+    assert scheduler_reqs[1].prefill_start_index == 0
+    assert scheduler_reqs[0].prefill_end_index == 3199
+    assert scheduler_reqs[1].prefill_end_index == 1
+    assert len(resource_manager_v1.running) == 2
+    assert len(resource_manager_v1.waiting) == 1
+    # step 2
+    scheduler_reqs = resource_manager_v1.schedule()
+    assert len(scheduler_reqs) == 2
+    assert scheduler_reqs[0].request_id == "req1"
+    assert len(scheduler_reqs[0].block_tables) == 52
+    assert scheduler_reqs[1].request_id == "req2"
+    assert scheduler_reqs[1].prefill_start_index == 1
+    assert scheduler_reqs[1].prefill_end_index == 3200
+    assert len(resource_manager_v1.running) == 2
+    assert len(resource_manager_v1.waiting) == 1
+    # step 3
+    scheduler_reqs = resource_manager_v1.schedule()
+    assert len(scheduler_reqs) == 2
+    assert scheduler_reqs[0].request_id == "req2"
+    assert scheduler_reqs[0].prefill_start_index == 3200
+    assert scheduler_reqs[0].prefill_end_index == 3201
+    assert scheduler_reqs[1].request_id == "req3"
+    assert scheduler_reqs[1].prefill_start_index == 0
+    assert scheduler_reqs[1].prefill_end_index == 3199
+    assert len(resource_manager_v1.running) == 3
+    assert len(resource_manager_v1.waiting) == 0
+
+
+def test_preempted_request():
+    max_num_seqs = 2
+    engine_args = EngineArgs(max_num_seqs=max_num_seqs, num_gpu_blocks_override=52, max_num_batched_tokens=3200)
+    args = asdict(engine_args)
+    cache_cfg = CacheConfig(args)
+    model_cfg = SimpleNamespace(enable_mm=False)
+    speculative_cfg = SimpleNamespace(method=None)
+    model_cfg.print = print
+    cache_cfg.bytes_per_layer_per_block = 1
+    parallel_cfg = ParallelConfig(args)
+    graph_opt_cfg = engine_args.create_graph_optimization_config()
+    fd_config = FDConfig(
+        model_config=model_cfg,
+        cache_config=cache_cfg,
+        parallel_config=parallel_cfg,
+        graph_opt_config=graph_opt_cfg,
+        speculative_config=speculative_cfg,
+        max_num_batched_tokens=engine_args.max_num_batched_tokens,
+    )
+    resource_manager_v1 = ResourceManagerV1(
+        max_num_seqs=max_num_seqs, config=fd_config, tensor_parallel_size=8, splitwise_role="mixed"
+    )
+    req1 = Request.from_dict({"request_id": "req1", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
+    req2 = Request.from_dict({"request_id": "req2", "prompt_token_ids": [1] * 3200, "prompt_token_ids_len": 3200})
+    resource_manager_v1.add_request(req1)
+    resource_manager_v1.add_request(req2)
+    # step 1
+    assert len(resource_manager_v1.waiting) == 2
+    scheduler_reqs = resource_manager_v1.schedule()
+    assert len(scheduler_reqs) == 1
+    assert scheduler_reqs[0].request_id == "req1"
+    assert scheduler_reqs[0].prefill_start_index == 0
+    assert scheduler_reqs[0].prefill_end_index == 3200
+    assert len(resource_manager_v1.running) == 1
+    assert len(resource_manager_v1.waiting) == 1
+    # step 2
+    scheduler_reqs = resource_manager_v1.schedule()
+    assert len(scheduler_reqs) == 1
+    assert scheduler_reqs[0].request_id == "req1"
+    assert len(scheduler_reqs[0].block_tables) == 52
+    # step 3
+    req1.output_token_ids.extend([1] * 128)
+    scheduler_reqs = resource_manager_v1.schedule()
+    assert len(scheduler_reqs) == 1
+    assert scheduler_reqs[0].request_id == "req1"
+    assert len(resource_manager_v1.running) == 0
+    # to be added into waiting queue
+    assert len(resource_manager_v1.waiting) == 1
+    # mock token_processor to add into waiting
+    resource_manager_v1.waiting.appendleft(req1)
+    # step 4
+    scheduler_reqs = resource_manager_v1.schedule()
+    assert len(scheduler_reqs) == 1
+    assert scheduler_reqs[0].request_id == "req1"
+    assert len(resource_manager_v1.running) == 1
+    assert len(resource_manager_v1.waiting) == 1