diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index f8cf662ef..1eaf53549 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -500,8 +500,6 @@ class EngineArgs:
             self.tokenizer = self.model
         if self.splitwise_role == "decode":
             self.enable_prefix_caching = False
-        if self.speculative_config is not None:
-            self.enable_prefix_caching = False
         if not current_platform.is_cuda() and not current_platform.is_xpu() and not current_platform.is_intel_hpu():
             self.enable_prefix_caching = False
         # if self.dynamic_load_weight:
diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py
index 5e152f746..350eb08cd 100644
--- a/fastdeploy/engine/common_engine.py
+++ b/fastdeploy/engine/common_engine.py
@@ -722,6 +722,10 @@ class EngineService:
             )
 
         if self.cfg.scheduler_config.splitwise_role != "mixed":
+            if self.cfg.scheduler_config.splitwise_role == "prefill":
+                for task in tasks:
+                    # start async preprocess
+                    self.resource_manager.apply_async_preprocess(task)
             need_delete_tasks = []
             if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
                 for task in tasks:
@@ -782,18 +786,45 @@ class EngineService:
                 need_check_req_ids = [task.request_id for task in tasks]
                 self.split_connector.send_cache_info_to_messager(tasks, 0)
                 # ensure cache tasks has sent to cache_messager
+                finished_ids, delete_tasks_list = [], []
                 while need_check_req_ids:
-                    req_ids = self.engine_worker_queue.get_finished_add_cache_task_req()
-                    if req_ids:
-                        self.llm_logger.debug(
-                            f"P has successfully sent cache infos to cache messager for requests: {req_ids}"
-                        )
-                        for req_id in req_ids:
-                            assert req_id in need_check_req_ids
-                            need_check_req_ids.remove(req_id)
+                    finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req())
+                    if finished_ids:
+                        self.llm_logger.debug(
+                            f"P has successfully sent cache infos to cache messager for requests: {finished_ids}"
+                        )
+                        for task in tasks:
+                            # Skip requests that were already resolved in a previous pass.
+                            if task.request_id not in need_check_req_ids:
+                                continue
+                            result = self.resource_manager.waiting_async_process(task)
+                            if result is None:
+                                # Async preprocess failed: report the error once and drop the
+                                # request from the wait lists so the loop can terminate.
+                                self.scheduler.put_results(
+                                    [
+                                        RequestOutput(
+                                            request_id=task.request_id,
+                                            finished=True,
+                                            error_code=task.error_code,
+                                            error_msg=task.error_message,
+                                        )
+                                    ]
+                                )
+                                delete_tasks_list.append(task)
+                                need_check_req_ids.remove(task.request_id)
+                                if task.request_id in finished_ids:
+                                    finished_ids.remove(task.request_id)
+                            elif result is False and task.request_id in finished_ids:
+                                need_check_req_ids.remove(task.request_id)
+                                finished_ids.remove(task.request_id)
                     else:
                         time.sleep(0.001)
+                for tmp_task in delete_tasks_list:
+                    tasks.remove(tmp_task)
+                    # release resource in P
+                    self.resource_manager.pre_recycle_resource(tmp_task.request_id)
 
         # Fetch requests and add them to the scheduling queue
         if tasks:
             for task in tasks:
diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py
index 34935a51b..9f281c3e6 100644
--- a/fastdeploy/engine/request.py
+++ b/fastdeploy/engine/request.py
@@ -255,6 +255,22 @@ class Request:
         """
         return self.prompt_token_ids_len + len(self.output_token_ids)
 
+    def __getstate__(self):
+        """
+        Custom getstate method for pickle support.
+        Handles unpicklable attributes by filtering them from __dict__.
+        """
+        # Create a filtered dictionary without problematic attributes
+        filtered_dict = {}
+        for key, value in self.__dict__.items():
+            # Skip attributes that are known to contain unpicklable objects
+            if key == "async_process_futures":
+                filtered_dict[key] = []
+            else:
+                filtered_dict[key] = value
+
+        return filtered_dict
+
     def __eq__(self, other):
         """
         EQ operator.
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 2494881da..440acb810 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -661,7 +661,7 @@ class ResourceManagerV1(ResourceManager):
                 ) or (paddle.is_compiled_with_xpu() and self.exist_prefill(scheduled_reqs)):
                     break
                 if request.status == RequestStatus.WAITING:
-                    result = self._waiting_async_process(request)
+                    result = self.waiting_async_process(request)
                     if result is None:
                         error_reqs.append((request.request_id, request.error_message))
                         self.waiting.popleft()
@@ -768,7 +768,7 @@ class ResourceManagerV1(ResourceManager):
 
         return scheduled_reqs, error_reqs
 
-    def _waiting_async_process(self, request: Request) -> None:
+    def waiting_async_process(self, request: Request) -> None:
         """
         Check if async preprocessing is complete for a request.
         Args:
@@ -787,7 +787,7 @@ class ResourceManagerV1(ResourceManager):
             request.async_process_futures = []
         return False
 
-    def _apply_async_preprocess(self, request: Request) -> None:
+    def apply_async_preprocess(self, request: Request) -> None:
         request.async_process_futures.append(self.async_preprocess_pool.submit(self._download_features, request))
 
     def _has_features_info(self, task):
@@ -917,7 +917,7 @@ class ResourceManagerV1(ResourceManager):
 
     def add_request(self, request: Request) -> None:
         with self.lock:
-            self._apply_async_preprocess(request)
+            self.apply_async_preprocess(request)
             self.waiting.append(request)
             self.requests[request.request_id] = request
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index 7bf76ed92..1ecda17e2 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -203,9 +203,7 @@ class MTPProposer(Proposer):
             if kv_cache_quant_type == "block_wise_fp8":
                 kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]]
             local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        if not profile and (
-            self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
-        ):
+        if not profile and self.scheduler_config.splitwise_role != "mixed":
             cache_kvs_list = []
             for i in range(
                 self.num_main_model_layers,
diff --git a/tests/v1/test_resource_manager_v1.py b/tests/v1/test_resource_manager_v1.py
index 6d2ae88e9..038a18b40 100644
--- a/tests/v1/test_resource_manager_v1.py
+++ b/tests/v1/test_resource_manager_v1.py
@@ -54,7 +54,7 @@ class TestResourceManagerV1(unittest.TestCase):
 
     def test_waiting_async_process_no_futures(self):
         """Test when there are no async process futures"""
-        result = self.manager._waiting_async_process(self.request)
+        result = self.manager.waiting_async_process(self.request)
         self.assertFalse(result)
 
     def test_waiting_async_process_future_done_no_error(self):
@@ -63,7 +63,7 @@ class TestResourceManagerV1(unittest.TestCase):
         future.set_result(True)
         self.request.async_process_futures = [future]
 
-        result = self.manager._waiting_async_process(self.request)
+        result = self.manager.waiting_async_process(self.request)
         self.assertFalse(result)
         self.assertEqual(len(self.request.async_process_futures), 0)
 
@@ -74,7 +74,7 @@ class TestResourceManagerV1(unittest.TestCase):
         self.request.async_process_futures = [future]
         self.request.error_message = "Download failed"
 
-        result = self.manager._waiting_async_process(self.request)
+        result = self.manager.waiting_async_process(self.request)
         self.assertIsNone(result)
 
     def test_waiting_async_process_future_not_done(self):
@@ -82,7 +82,7 @@ class TestResourceManagerV1(unittest.TestCase):
         future = concurrent.futures.Future()
         self.request.async_process_futures = [future]
 
-        result = self.manager._waiting_async_process(self.request)
+        result = self.manager.waiting_async_process(self.request)
         self.assertTrue(result)
         self.assertEqual(len(self.request.async_process_futures), 1)
 
@@ -90,7 +90,7 @@ class TestResourceManagerV1(unittest.TestCase):
         """Test applying async preprocess"""
         with patch.object(self.manager.async_preprocess_pool, "submit") as mock_submit:
             mock_submit.return_value = "mock_future"
-            self.manager._apply_async_preprocess(self.request)
+            self.manager.apply_async_preprocess(self.request)
             mock_submit.assert_called_once_with(self.manager._download_features, self.request)
             self.assertEqual(len(self.request.async_process_futures), 1)