[Cherry-Pick][BugFix] Fix async download(#5349) (#5347)

* fix mm to_dict bug

* pd support async download

* update code

* update test case

* update log

* Revert "update log"

This reverts commit 6e883150cd.

* update code

* fix mtp bug
This commit is contained in:
kevin
2025-12-05 18:59:36 +08:00
committed by GitHub
parent cae2c1ccf5
commit 9b5b08cb72
6 changed files with 57 additions and 20 deletions

View File

@@ -486,8 +486,6 @@ class EngineArgs:
self.tokenizer = self.model
if self.splitwise_role == "decode":
self.enable_prefix_caching = False
if self.speculative_config is not None:
self.enable_prefix_caching = False
if not current_platform.is_cuda() and not current_platform.is_xpu() and not current_platform.is_intel_hpu():
self.enable_prefix_caching = False
# if self.dynamic_load_weight:

View File

@@ -718,6 +718,10 @@ class EngineService:
self.llm_logger.debug(f"get tasks from {type(self.scheduler)}: {tasks}")
if self.cfg.scheduler_config.splitwise_role != "mixed":
if self.cfg.scheduler_config.splitwise_role == "prefill":
for task in tasks:
# start async preprocess
self.resource_manager.apply_async_preprocess(task)
need_delete_tasks = []
if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
for task in tasks:
@@ -770,15 +774,36 @@ class EngineService:
self.split_connector.send_cache_info_to_messager(tasks, 0)
# ensure cache tasks has sent to cache_messager
need_check_req_ids = [task.request_id for task in tasks]
finished_ids, delete_tasks_list = [], []
while need_check_req_ids:
req_ids = self.engine_worker_queue.get_finished_add_cache_task_req()
self.llm_logger.info(f"get_finished_add_cache_task_req: {req_ids}")
if req_ids:
for req_id in req_ids:
assert req_id in need_check_req_ids
need_check_req_ids.remove(req_id)
finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req())
self.llm_logger.info(f"get_finished_add_cache_task_req: {finished_ids}")
if finished_ids:
for task in tasks:
result = self.resource_manager.waiting_async_process(task)
if result is None:
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=task.error_code,
error_msg=task.error_message,
)
]
)
delete_tasks_list.append(task)
elif result is False:
if task.request_id in finished_ids:
need_check_req_ids.remove(task.request_id)
finished_ids.remove(task.request_id)
else:
time.sleep(0.001)
for tmp_task in delete_tasks_list:
tasks.remove(tmp_task)
# release resource in P
self.resource_manager.pre_recycle_resource(tmp_task.request_id)
# Fetch requests and add them to the scheduling queue
if tasks:
for task in tasks:

View File

@@ -183,6 +183,22 @@ class Request:
self.error_message = None
self.error_code = None
def __getstate__(self):
"""
Custom getstate method for pickle support.
Handles unpicklable attributes by filtering them from __dict__.
"""
# Create a filtered dictionary without problematic attributes
filtered_dict = {}
for key, value in self.__dict__.items():
# Skip attributes that are known to contain unpicklable objects
if key == "async_process_futures":
filtered_dict[key] = []
else:
filtered_dict[key] = value
return filtered_dict
@classmethod
def from_dict(cls, d: dict):
data_processor_logger.debug(f"{d}")

View File

@@ -653,7 +653,7 @@ class ResourceManagerV1(ResourceManager):
):
break
if request.status == RequestStatus.WAITING:
result = self._waiting_async_process(request)
result = self.waiting_async_process(request)
if result is None:
error_reqs.append((request.request_id, request.error_message))
self.waiting.popleft()
@@ -761,7 +761,7 @@ class ResourceManagerV1(ResourceManager):
return scheduled_reqs, error_reqs
def _waiting_async_process(self, request: Request) -> None:
def waiting_async_process(self, request: Request) -> None:
"""
Check if async preprocessing is complete for a request.
Args:
@@ -780,7 +780,7 @@ class ResourceManagerV1(ResourceManager):
request.async_process_futures = []
return False
def _apply_async_preprocess(self, request: Request) -> None:
def apply_async_preprocess(self, request: Request) -> None:
request.async_process_futures.append(self.async_preprocess_pool.submit(self._download_features, request))
def _has_features_info(self, task):
@@ -903,7 +903,7 @@ class ResourceManagerV1(ResourceManager):
def add_request(self, request: Request) -> None:
with self.lock:
self._apply_async_preprocess(request)
self.apply_async_preprocess(request)
self.waiting.append(request)
self.requests[request.request_id] = request

View File

@@ -176,9 +176,7 @@ class MTPProposer(Proposer):
if kv_cache_quant_type == "block_wise_fp8":
kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]]
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
if not profile and (
self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
):
if not profile and self.scheduler_config.splitwise_role != "mixed":
cache_kvs_list = []
for i in range(
self.num_main_model_layers,

View File

@@ -54,7 +54,7 @@ class TestResourceManagerV1(unittest.TestCase):
def test_waiting_async_process_no_futures(self):
"""Test when there are no async process futures"""
result = self.manager._waiting_async_process(self.request)
result = self.manager.waiting_async_process(self.request)
self.assertFalse(result)
def test_waiting_async_process_future_done_no_error(self):
@@ -63,7 +63,7 @@ class TestResourceManagerV1(unittest.TestCase):
future.set_result(True)
self.request.async_process_futures = [future]
result = self.manager._waiting_async_process(self.request)
result = self.manager.waiting_async_process(self.request)
self.assertFalse(result)
self.assertEqual(len(self.request.async_process_futures), 0)
@@ -74,7 +74,7 @@ class TestResourceManagerV1(unittest.TestCase):
self.request.async_process_futures = [future]
self.request.error_message = "Download failed"
result = self.manager._waiting_async_process(self.request)
result = self.manager.waiting_async_process(self.request)
self.assertIsNone(result)
def test_waiting_async_process_future_not_done(self):
@@ -82,7 +82,7 @@ class TestResourceManagerV1(unittest.TestCase):
future = concurrent.futures.Future()
self.request.async_process_futures = [future]
result = self.manager._waiting_async_process(self.request)
result = self.manager.waiting_async_process(self.request)
self.assertTrue(result)
self.assertEqual(len(self.request.async_process_futures), 1)
@@ -90,7 +90,7 @@ class TestResourceManagerV1(unittest.TestCase):
"""Test applying async preprocess"""
with patch.object(self.manager.async_preprocess_pool, "submit") as mock_submit:
mock_submit.return_value = "mock_future"
self.manager._apply_async_preprocess(self.request)
self.manager.apply_async_preprocess(self.request)
mock_submit.assert_called_once_with(self.manager._download_features, self.request)
self.assertEqual(len(self.request.async_process_futures), 1)