fix async download bug

This commit is contained in:
kevin
2025-12-03 11:37:31 +08:00
parent dfeabee123
commit f87e1900ec
4 changed files with 40 additions and 17 deletions

View File

@@ -496,8 +496,6 @@ class EngineArgs:
self.tokenizer = self.model
if self.splitwise_role == "decode":
self.enable_prefix_caching = False
if self.speculative_config is not None:
self.enable_prefix_caching = False
if not current_platform.is_cuda() and not current_platform.is_xpu() and not current_platform.is_intel_hpu():
self.enable_prefix_caching = False
# if self.dynamic_load_weight:

View File

@@ -718,6 +718,10 @@ class EngineService:
self.llm_logger.debug(f"get tasks from {type(self.scheduler)}: {tasks}")
if self.cfg.scheduler_config.splitwise_role != "mixed":
if self.cfg.scheduler_config.splitwise_role == "prefill":
for task in tasks:
# start async preprocess
self.resource_manager.apply_async_preprocess(task)
need_delete_tasks = []
if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
for task in tasks:
@@ -770,15 +774,36 @@ class EngineService:
self.split_connector.send_cache_info_to_messager(tasks, 0)
# ensure cache tasks have been sent to cache_messager
need_check_req_ids = [task.request_id for task in tasks]
finished_ids, delete_tasks_list = [], []
while need_check_req_ids:
req_ids = self.engine_worker_queue.get_finished_add_cache_task_req()
self.llm_logger.info(f"get_finished_add_cache_task_req: {req_ids}")
if req_ids:
for req_id in req_ids:
assert req_id in need_check_req_ids
need_check_req_ids.remove(req_id)
finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req())
self.llm_logger.info(f"get_finished_add_cache_task_req: {finished_ids}")
if finished_ids:
for task in tasks:
result = self.resource_manager.waiting_async_process(task)
if result is None:
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=task.error_code,
error_msg=task.error_message,
)
]
)
delete_tasks_list.append(task)
elif result is False:
if task.request_id in finished_ids:
need_check_req_ids.remove(task.request_id)
finished_ids.remove(task.request_id)
else:
time.sleep(0.001)
for tmp_task in delete_tasks_list:
tasks.remove(tmp_task)
# release resource in P
self.resource_manager.pre_recycle_resource(tmp_task.request_id)
# Fetch requests and add them to the scheduling queue
if tasks:
for task in tasks:

View File

@@ -652,7 +652,7 @@ class ResourceManagerV1(ResourceManager):
):
break
if request.status == RequestStatus.WAITING:
result = self._waiting_async_process(request)
result = self.waiting_async_process(request)
if result is None:
error_reqs.append((request.request_id, request.error_message))
self.waiting.popleft()
@@ -759,7 +759,7 @@ class ResourceManagerV1(ResourceManager):
return scheduled_reqs, error_reqs
def _waiting_async_process(self, request: Request) -> None:
def waiting_async_process(self, request: Request) -> None:
"""
Check if async preprocessing is complete for a request.
Args:
@@ -778,7 +778,7 @@ class ResourceManagerV1(ResourceManager):
request.async_process_futures = []
return False
def _apply_async_preprocess(self, request: Request) -> None:
def apply_async_preprocess(self, request: Request) -> None:
request.async_process_futures.append(self.async_preprocess_pool.submit(self._download_features, request))
def _has_features_info(self, task):
@@ -908,7 +908,7 @@ class ResourceManagerV1(ResourceManager):
def add_request(self, request: Request) -> None:
with self.lock:
self._apply_async_preprocess(request)
self.apply_async_preprocess(request)
self.waiting.append(request)
self.requests[request.request_id] = request

View File

@@ -54,7 +54,7 @@ class TestResourceManagerV1(unittest.TestCase):
def test_waiting_async_process_no_futures(self):
"""Test when there are no async process futures"""
result = self.manager._waiting_async_process(self.request)
result = self.manager.waiting_async_process(self.request)
self.assertFalse(result)
def test_waiting_async_process_future_done_no_error(self):
@@ -63,7 +63,7 @@ class TestResourceManagerV1(unittest.TestCase):
future.set_result(True)
self.request.async_process_futures = [future]
result = self.manager._waiting_async_process(self.request)
result = self.manager.waiting_async_process(self.request)
self.assertFalse(result)
self.assertEqual(len(self.request.async_process_futures), 0)
@@ -74,7 +74,7 @@ class TestResourceManagerV1(unittest.TestCase):
self.request.async_process_futures = [future]
self.request.error_message = "Download failed"
result = self.manager._waiting_async_process(self.request)
result = self.manager.waiting_async_process(self.request)
self.assertIsNone(result)
def test_waiting_async_process_future_not_done(self):
@@ -82,7 +82,7 @@ class TestResourceManagerV1(unittest.TestCase):
future = concurrent.futures.Future()
self.request.async_process_futures = [future]
result = self.manager._waiting_async_process(self.request)
result = self.manager.waiting_async_process(self.request)
self.assertTrue(result)
self.assertEqual(len(self.request.async_process_futures), 1)
@@ -90,7 +90,7 @@ class TestResourceManagerV1(unittest.TestCase):
"""Test applying async preprocess"""
with patch.object(self.manager.async_preprocess_pool, "submit") as mock_submit:
mock_submit.return_value = "mock_future"
self.manager._apply_async_preprocess(self.request)
self.manager.apply_async_preprocess(self.request)
mock_submit.assert_called_once_with(self.manager._download_features, self.request)
self.assertEqual(len(self.request.async_process_futures), 1)