From 9b5b08cb72b80871a11e7fbebf792062074f3a94 Mon Sep 17 00:00:00 2001 From: kevin Date: Fri, 5 Dec 2025 18:59:36 +0800 Subject: [PATCH] [Cherry-Pick][BugFix] Fix async download(#5349) (#5347) * fix mm to_dict bug * pd support async download * update code * update test case * update log * Revert "update log" This reverts commit 6e883150cd5730780d702a8982850bd9e6d57e93. * update code * fix mtp bug --- fastdeploy/engine/args_utils.py | 2 - fastdeploy/engine/common_engine.py | 37 ++++++++++++++++--- fastdeploy/engine/request.py | 16 ++++++++ .../engine/sched/resource_manager_v1.py | 8 ++-- fastdeploy/spec_decode/mtp.py | 4 +- tests/v1/test_resource_manager_v1.py | 10 ++--- 6 files changed, 57 insertions(+), 20 deletions(-) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 77def0feb..3b5ed41d9 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -486,8 +486,6 @@ class EngineArgs: self.tokenizer = self.model if self.splitwise_role == "decode": self.enable_prefix_caching = False - if self.speculative_config is not None: - self.enable_prefix_caching = False if not current_platform.is_cuda() and not current_platform.is_xpu() and not current_platform.is_intel_hpu(): self.enable_prefix_caching = False # if self.dynamic_load_weight: diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index b7a96489f..1bf27b5c0 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -718,6 +718,10 @@ class EngineService: self.llm_logger.debug(f"get tasks from {type(self.scheduler)}: {tasks}") if self.cfg.scheduler_config.splitwise_role != "mixed": + if self.cfg.scheduler_config.splitwise_role == "prefill": + for task in tasks: + # start async preprocess + self.resource_manager.apply_async_preprocess(task) need_delete_tasks = [] if envs.FD_OFFLINE_PERF_TEST_FOR_PD: for task in tasks: @@ -770,15 +774,36 @@ class EngineService: self.split_connector.send_cache_info_to_messager(tasks, 0) # ensure cache tasks has sent to cache_messager need_check_req_ids = [task.request_id for task in tasks] + finished_ids, delete_tasks_list = [], [] while need_check_req_ids: - req_ids = self.engine_worker_queue.get_finished_add_cache_task_req() - self.llm_logger.info(f"get_finished_add_cache_task_req: {req_ids}") - if req_ids: - for req_id in req_ids: - assert req_id in need_check_req_ids - need_check_req_ids.remove(req_id) + finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req()) + self.llm_logger.info(f"get_finished_add_cache_task_req: {finished_ids}") + if finished_ids: + for task in tasks: + result = self.resource_manager.waiting_async_process(task) + if result is None: + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=task.error_code, + error_msg=task.error_message, + ) + ] + ) + delete_tasks_list.append(task) + elif result is False: + if task.request_id in finished_ids: + need_check_req_ids.remove(task.request_id) + finished_ids.remove(task.request_id) else: time.sleep(0.001) + + for tmp_task in delete_tasks_list: + tasks.remove(tmp_task) + # release resource in P + self.resource_manager.pre_recycle_resource(tmp_task.request_id) # Fetch requests and add them to the scheduling queue if tasks: for task in tasks: diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 9b26f8f48..4e3d2b04e 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -183,6 +183,22 @@ class Request: self.error_message = None self.error_code = None + def __getstate__(self): + """ + Custom getstate method for pickle support. + Handles unpicklable attributes by filtering them from __dict__. + """ + # Create a filtered dictionary without problematic attributes + filtered_dict = {} + for key, value in self.__dict__.items(): + # Skip attributes that are known to contain unpicklable objects + if key == "async_process_futures": + filtered_dict[key] = [] + else: + filtered_dict[key] = value + + return filtered_dict + @classmethod def from_dict(cls, d: dict): data_processor_logger.debug(f"{d}") diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index fbe30d24b..12ae16e86 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -653,7 +653,7 @@ class ResourceManagerV1(ResourceManager): ): break if request.status == RequestStatus.WAITING: - result = self._waiting_async_process(request) + result = self.waiting_async_process(request) if result is None: error_reqs.append((request.request_id, request.error_message)) self.waiting.popleft() @@ -761,7 +761,7 @@ class ResourceManagerV1(ResourceManager): return scheduled_reqs, error_reqs - def _waiting_async_process(self, request: Request) -> None: + def waiting_async_process(self, request: Request) -> None: """ Check if async preprocessing is complete for a request. Args: @@ -780,7 +780,7 @@ class ResourceManagerV1(ResourceManager): request.async_process_futures = [] return False - def _apply_async_preprocess(self, request: Request) -> None: + def apply_async_preprocess(self, request: Request) -> None: request.async_process_futures.append(self.async_preprocess_pool.submit(self._download_features, request)) def _has_features_info(self, task): @@ -903,7 +903,7 @@ class ResourceManagerV1(ResourceManager): def add_request(self, request: Request) -> None: with self.lock: - self._apply_async_preprocess(request) + self.apply_async_preprocess(request) self.waiting.append(request) self.requests[request.request_id] = request diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 611c3ab5f..420f14721 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -176,9 +176,7 @@ class MTPProposer(Proposer): if kv_cache_quant_type == "block_wise_fp8": kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]] local_rank = self.local_rank % self.parallel_config.tensor_parallel_size - if not profile and ( - self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed" - ): + if not profile and self.scheduler_config.splitwise_role != "mixed": cache_kvs_list = [] for i in range( self.num_main_model_layers, diff --git a/tests/v1/test_resource_manager_v1.py b/tests/v1/test_resource_manager_v1.py index 6d2ae88e9..038a18b40 100644 --- a/tests/v1/test_resource_manager_v1.py +++ b/tests/v1/test_resource_manager_v1.py @@ -54,7 +54,7 @@ class TestResourceManagerV1(unittest.TestCase): def test_waiting_async_process_no_futures(self): """Test when there are no async process futures""" - result = self.manager._waiting_async_process(self.request) + result = self.manager.waiting_async_process(self.request) self.assertFalse(result) def test_waiting_async_process_future_done_no_error(self): @@ -63,7 +63,7 @@ class TestResourceManagerV1(unittest.TestCase): future.set_result(True) self.request.async_process_futures = [future] - result = self.manager._waiting_async_process(self.request) + result = self.manager.waiting_async_process(self.request) self.assertFalse(result) self.assertEqual(len(self.request.async_process_futures), 0) @@ -74,7 +74,7 @@ class TestResourceManagerV1(unittest.TestCase): self.request.async_process_futures = [future] self.request.error_message = "Download failed" - result = self.manager._waiting_async_process(self.request) + result = self.manager.waiting_async_process(self.request) self.assertIsNone(result) def test_waiting_async_process_future_not_done(self): @@ -82,7 +82,7 @@ class TestResourceManagerV1(unittest.TestCase): future = concurrent.futures.Future() self.request.async_process_futures = [future] - result = self.manager._waiting_async_process(self.request) + result = self.manager.waiting_async_process(self.request) self.assertTrue(result) self.assertEqual(len(self.request.async_process_futures), 1) @@ -90,7 +90,7 @@ class TestResourceManagerV1(unittest.TestCase): """Test applying async preprocess""" with patch.object(self.manager.async_preprocess_pool, "submit") as mock_submit: mock_submit.return_value = "mock_future" - self.manager._apply_async_preprocess(self.request) + self.manager.apply_async_preprocess(self.request) mock_submit.assert_called_once_with(self.manager._download_features, self.request) self.assertEqual(len(self.request.async_process_futures), 1)