From 109d48e456d2f22b2b9804defa6cb0958fd872de Mon Sep 17 00:00:00 2001
From: kevin
Date: Wed, 19 Nov 2025 22:23:36 +0800
Subject: [PATCH] [Feature] support async download features (#5003)

* support async download features

* add test case

* update code
---
 fastdeploy/config.py                          |   2 +
 fastdeploy/engine/args_utils.py               |  11 ++
 fastdeploy/engine/common_engine.py            |  77 +++------
 fastdeploy/engine/request.py                  |   4 +
 .../engine/sched/resource_manager_v1.py       | 111 +++++++++++-
 .../inter_communicator/engine_worker_queue.py |  49 ++++--
 fastdeploy/utils.py                           |  31 ++++
 tests/inter_communicator/test_e2w_queue.py    |  47 ++++-
 tests/v1/test_resource_manager_v1.py          | 162 ++++++++++++++++++
 tests/v1/test_schedule_output.py              |  14 +-
 10 files changed, 433 insertions(+), 75 deletions(-)
 create mode 100644 tests/v1/test_resource_manager_v1.py

diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index afe37f076..83af4ebdd 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -550,6 +550,8 @@ class ParallelConfig:
         self.use_internode_ll_two_stage: bool = False
         # disable sequence parallel moe
         self.disable_sequence_parallel_moe: bool = False
+        # enable async download features
+        self.enable_async_download_features: bool = False
         self.pod_ip: str = None
         # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 5af1ddfd2..23812e966 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -467,6 +467,11 @@ class EngineArgs:
     Url for router server, such as `0.0.0.0:30000`.
     """

+    enable_async_download_features: bool = False
+    """
+    Flag to enable async download features. Default is False (disabled).
+    """
+
     def __post_init__(self):
         """
         Post-initialization processing to set default tokenizer if not provided.
@@ -849,6 +854,12 @@ class EngineArgs:
             default=EngineArgs.enable_expert_parallel,
             help="Enable expert parallelism.",
         )
+        parallel_group.add_argument(
+            "--enable-async-download-features",
+            action="store_true",
+            default=EngineArgs.enable_async_download_features,
+            help="Enable async download features.",
+        )

         # Load group
         load_group = parser.add_argument_group("Load Configuration")
diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py
index d4c86c98f..9f0db935c 100644
--- a/fastdeploy/engine/common_engine.py
+++ b/fastdeploy/engine/common_engine.py
@@ -51,14 +51,7 @@ from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
 from fastdeploy.trace.constants import LoggingEventName
 from fastdeploy.trace.trace_logger import print as trace_print
-from fastdeploy.utils import (
-    EngineError,
-    check_download_links,
-    envs,
-    get_logger,
-    init_bos_client,
-    llm_logger,
-)
+from fastdeploy.utils import EngineError, envs, get_logger, llm_logger

 try:
     TokenProcessor = load_token_processor_plugins()
@@ -808,7 +801,7 @@ class EngineService:
                 else:
                     raise
             # 2. Schedule requests
-            tasks = self.resource_manager.schedule()
+            tasks, error_tasks = self.resource_manager.schedule()

             # 3. Send to engine
             if tasks:
@@ -833,7 +826,16 @@ class EngineService:
                     trace_print(LoggingEventName.REQUEST_SCHEDULE_END, task.request_id, getattr(task, "user", ""))
                     trace_print(LoggingEventName.INFERENCE_START, task.request_id, getattr(task, "user", ""))
                 self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
-            else:
+
+            # 4. Send error responses for failed requests
+            if error_tasks:
+                for request_id, failed in error_tasks:
+                    if failed is None:
+                        llm_logger.warning(f"Request {request_id} has no error, skipping error response.")
+                        continue
+                    self._send_error_response(request_id, failed)
+
+            if not tasks and not error_tasks:
                 time.sleep(0.005)

         except RuntimeError as e:
@@ -909,24 +911,6 @@ class EngineService:
                     self.llm_logger.error(f"Receive request error: {err_msg}")
                     results.append((request.request_id, err_msg))

-                if self._has_features_info(request) and err_msg is None:
-                    if self.bos_client is None:
-                        self.bos_client = init_bos_client()
-
-                    download_urls = []
-                    inputs = request.multimodal_inputs
-                    if inputs.get("video_feature_urls") is not None:
-                        download_urls.extend(inputs.get("video_feature_urls"))
-                    if inputs.get("image_feature_urls") is not None:
-                        download_urls.extend(inputs.get("image_feature_urls"))
-                    if inputs.get("audio_feature_urls") is not None:
-                        download_urls.extend(inputs.get("audio_feature_urls"))
-
-                    err_msg = check_download_links(self.bos_client, download_urls)
-                    if err_msg:
-                        llm_logger.error(f"Receive request {request.request_id} download error: {err_msg}")
-                        results.append((request.request_id, err_msg))
-
                 if err_msg is None:
                     insert_task.append(request)

@@ -948,21 +932,27 @@ class EngineService:
                     main_process_metrics.num_requests_waiting.inc(1)
                     continue

-                error_result = RequestOutput(
-                    request_id=request_id,
-                    finished=True,
-                    error_code=500,
-                    error_msg=failed,
-                )
-                # Since the request is not in scheduler
-                # Send result by zmq directly
-                self.send_response_server.send_response(request_id, [error_result])
+                self._send_error_response(request_id, failed)

         except Exception as e:
             self.llm_logger.error(
                 f"Error happened while receiving new request from zmq, details={e}, "
                 f"traceback={traceback.format_exc()}"
             )

+    def _send_error_response(self, request_id, error_msg, error_code: int = 500):
+        llm_logger.error(
+            f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
+        )
+        error_result = RequestOutput(
+            request_id=request_id,
+            finished=True,
+            error_code=error_code,
+            error_msg=error_msg,
+        )
+        # Since the request is not in the scheduler,
+        # send the result via zmq directly
+        self.send_response_server.send_response(request_id, [error_result])
+
     def _decode_token(self, token_ids, req_id, is_end):
         delta_text = ""
         if envs.FD_ENABLE_RETURN_TEXT:
@@ -977,19 +967,6 @@ class EngineService:
             del self.data_processor.decode_status[req_id]
         return delta_text, token_ids

-    def _has_features_info(self, task):
-        inputs = task.multimodal_inputs
-        if inputs is None or len(inputs) == 0:
-            return False
-
-        if (
-            (inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0)
-            or (inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0)
-            or (inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0)
-        ):
-            return True
-        return False
-
     def _zmq_send_generated_tokens(self):
         """
         Receive output for zmq
diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py
index 70d82d2e3..f71eca7db 100644
--- a/fastdeploy/engine/request.py
+++ b/fastdeploy/engine/request.py
@@ -173,6 +173,10 @@ class Request:
         # dp
         self.dp_rank = dp_rank

+        self.async_process_futures = []
+        self.error_message = None
+        self.error_code = None
+
     @classmethod
     def from_dict(cls, d: dict):
         data_processor_logger.debug(f"{d}")
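The hunks above change the scheduling loop's contract: `schedule()` now returns a `(tasks, error_tasks)` pair, and failures are pushed straight back to clients via `_send_error_response`. A minimal, illustrative sketch of that contract follows; `handle_schedule_result` is a hypothetical stand-in, not FastDeploy code, and only the tuple shape and the "failed is None means skip" rule mirror the patch.

    # sketch_schedule_contract.py -- hypothetical helper, illustrative only
    from typing import List, Optional, Tuple

    def handle_schedule_result(
        tasks: List[object],
        error_tasks: List[Tuple[str, Optional[str]]],
    ) -> List[Tuple[str, str]]:
        """Collect the (request_id, error_msg) pairs that need an error response."""
        responses = []
        for request_id, failed in error_tasks:
            if failed is None:
                # mirrors the warning branch above: nothing to report, skip
                continue
            responses.append((request_id, failed))
        return responses

    if __name__ == "__main__":
        out = handle_schedule_result([], [("req-1", "download failed"), ("req-2", None)])
        assert out == [("req-1", "download failed")]
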
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index e7da5422f..b74c772d3 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -44,7 +44,7 @@ from fastdeploy.inter_communicator import IPCSignal
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.multimodal.hasher import MultimodalHasher
 from fastdeploy.platforms import current_platform
-from fastdeploy.utils import llm_logger
+from fastdeploy.utils import download_from_bos, init_bos_client, llm_logger


 @dataclass
@@ -195,6 +195,9 @@ class ResourceManagerV1(ResourceManager):
             max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024)
             self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes)

+        self.bos_client = None
+        self.async_preprocess_pool = ThreadPoolExecutor(max_workers=4)
+
     def allocated_slots(self, request: Request):
         return len(request.block_tables) * self.config.cache_config.block_size

@@ -500,6 +503,7 @@ class ResourceManagerV1(ResourceManager):
         with self.lock:
             scheduled_reqs: list[Request] = []
             preempted_reqs: list[Request] = []
+            error_reqs: list[tuple[str, str]] = []
             token_budget = self.config.scheduler_config.max_num_batched_tokens

             # First, schedule the RUNNING requests.
@@ -629,6 +633,7 @@ class ResourceManagerV1(ResourceManager):
                     req_index += 1
             # schedule the WAITING requests.
             if not preempted_reqs:
+                skip_requests: list[Request] = []
                 while self.waiting and token_budget > 0:
                     if len(self.running) == self.max_num_seqs:
                         break
@@ -639,6 +644,17 @@ class ResourceManagerV1(ResourceManager):
                     ):
                         break
                     if request.status == RequestStatus.WAITING:
+                        result = self._waiting_async_process(request)
+                        if result is None:
+                            error_reqs.append((request.request_id, request.error_message))
+                            self.waiting.popleft()
+                            continue
+                        elif result is True:
+                            # skip the current request and try the next one
+                            skip_requests.append(request)
+                            self.waiting.popleft()
+                            continue
+
                         self._update_mm_hashes(request)
                         # Enable prefix caching
                         if self.config.cache_config.enable_prefix_caching:
@@ -725,12 +741,102 @@ class ResourceManagerV1(ResourceManager):
                     else:
                         llm_logger.error("Unknown request status type")

+            for req in skip_requests:
+                # move skipped waiting requests to the end of the deque
+                self.waiting.append(req)
+
         if scheduled_reqs:
             llm_logger.debug(f"scheduled_reqs: {scheduled_reqs}")
         self.update_metrics()
-        return scheduled_reqs
+        return scheduled_reqs, error_reqs
+
+    def _waiting_async_process(self, request: Request):
+        """
+        Check if async preprocessing is complete for a request.
+        Args:
+            request: The request to check
+        Returns:
+            None: If an error occurred during preprocessing
+            True: If preprocessing is still in progress (request should be skipped)
+            False: If preprocessing is complete (request can be scheduled)
+        """
+        for future in request.async_process_futures:
+            if future.done():
+                if request.get("error_message") is not None:
+                    return None
+            else:
+                return True
+        request.async_process_futures = []
+        return False
+
+    def _apply_async_preprocess(self, request: Request) -> None:
+        request.async_process_futures.append(self.async_preprocess_pool.submit(self._download_features, request))
+
+    def _has_features_info(self, task):
+        inputs = task.multimodal_inputs
+        if inputs is None or len(inputs) == 0:
+            return False
+
+        if (
+            (inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0)
+            or (inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0)
+            or (inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0)
+        ):
+            return True
+        return False
+
+    def _download_features(self, request: Request) -> None:
+        """
+        Download multimodal features from BOS.
+        Note:
+            1. This function adds the downloaded features to request.multimodal_inputs.
+            2. It may update request.error_message and request.error_code on failure.
+        Args:
+            request (Request): request object
+        """
+
+        def download_bos_features(bos_client, features_urls):
+            result_list = []
+            for status, feature in download_from_bos(bos_client, features_urls):
+                if status:
+                    llm_logger.info(f"request {request.request_id} async download feature: {feature.shape}")
+                    result_list.append(feature)
+                else:
+                    error_msg = f"request {request.request_id} download features error: {feature}"
+                    llm_logger.error(error_msg)
+                    return error_msg
+            return result_list
+
+        if not self.config.parallel_config.enable_async_download_features or not self._has_features_info(request):
+            return None
+
+        if self.bos_client is None:
+            self.bos_client = init_bos_client()
+
+        inputs = request.multimodal_inputs
+        if inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0:
+            result = download_bos_features(self.bos_client, inputs["video_feature_urls"])
+            if isinstance(result, str):  # download error
+                request.error_message = result
+                request.error_code = 530
+                return None
+            inputs["video_features"] = result
+        if inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0:
+            result = download_bos_features(self.bos_client, inputs["image_feature_urls"])
+            if isinstance(result, str):  # download error
+                request.error_message = result
+                request.error_code = 530
+                return None
+            inputs["image_features"] = result
+        if inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0:
+            result = download_bos_features(self.bos_client, inputs["audio_feature_urls"])
+            if isinstance(result, str):  # download error
+                request.error_message = result
+                request.error_code = 530
+                return None
+            inputs["audio_features"] = result

     def get_available_position(self) -> int:
         position = 0
@@ -788,6 +894,7 @@ class ResourceManagerV1(ResourceManager):

     def add_request(self, request: Request) -> None:
         with self.lock:
+            self._apply_async_preprocess(request)
             self.waiting.append(request)
             self.requests[request.request_id] = request
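`_waiting_async_process` gives the scheduler a three-state view of a request's background downloads: `None` (preprocessing failed), `True` (still running, requeue the request), `False` (ready to schedule). The following is a minimal, self-contained sketch of that polling pattern under hypothetical names (`FakeRequest`, `poll`); only the three-state convention mirrors the patch.

    # poll_sketch.py -- hypothetical stand-ins, illustrative only
    import time
    from concurrent.futures import ThreadPoolExecutor

    class FakeRequest:
        def __init__(self):
            self.async_process_futures = []
            self.error_message = None

    def poll(request):
        for future in request.async_process_futures:
            if future.done():
                if request.error_message is not None:
                    return None       # preprocessing failed: surface the error
            else:
                return True           # still downloading: skip and retry later
        request.async_process_futures = []
        return False                  # all done, no error: ready to schedule

    if __name__ == "__main__":
        pool = ThreadPoolExecutor(max_workers=1)
        req = FakeRequest()
        req.async_process_futures.append(pool.submit(time.sleep, 0.2))
        assert poll(req) is True      # download still in flight
        time.sleep(0.3)
        assert poll(req) is False     # finished cleanly
        pool.shutdown()
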
diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py
index f91638bb5..1171e6fb6 100644
--- a/fastdeploy/inter_communicator/engine_worker_queue.py
+++ b/fastdeploy/inter_communicator/engine_worker_queue.py
@@ -499,7 +499,13 @@ class EngineWorkerQueue:
                     "attention_mask_offset",
                 ]

-                llm_logger.debug(f"Converting multimodal inputs to tensor...{tensor_keys}")
+                list_keys = [
+                    "image_features",
+                    "video_features",
+                    "audio_features",
+                ]
+
+                llm_logger.debug(f"Converting multimodal inputs to tensor...{tensor_keys + list_keys}")

                 for key in tensor_keys:
                     value = multimodal_inputs.get(key)
@@ -507,6 +513,13 @@ class EngineWorkerQueue:
                         continue
                     if not isinstance(value, paddle.Tensor):
                         multimodal_inputs[key] = paddle.to_tensor(value)
+
+                for key in list_keys:
+                    value = multimodal_inputs.get(key)
+                    if value is None:
+                        continue
+                    if isinstance(value, list):
+                        multimodal_inputs[key] = [paddle.to_tensor(v) for v in value]
             except Exception as e:
                 llm_logger.warning(f"Tensor conversion failed: {type(e).__name__}: {e}")

@@ -518,16 +531,30 @@ class EngineWorkerQueue:
         Args:
            tasks: List of tasks containing multimodal inputs.
        """
+        if (not envs.FD_ENABLE_MAX_PREFILL) and (not envs.FD_ENABLE_E2W_TENSOR_CONVERT):
+            return
+
        try:
-            if envs.FD_ENABLE_MAX_PREFILL:
-                for batch_tasks, _ in tasks:
-                    for task in batch_tasks:
-                        if not hasattr(task, "multimodal_inputs"):
-                            continue
-                        images = task.multimodal_inputs.get("images", None)
-                        if isinstance(images, paddle.Tensor):
-                            llm_logger.debug(f"Convert image to numpy, shape: {images.shape}")
-                            task.multimodal_inputs["images"] = images.numpy()
+            batch_tasks, _ = tasks
+            for task in batch_tasks:
+                if not hasattr(task, "multimodal_inputs"):
+                    continue
+                images = task.multimodal_inputs.get("images", None)
+                if isinstance(images, paddle.Tensor):
+                    llm_logger.debug(f"Convert image to numpy, shape: {images.shape}")
+                    task.multimodal_inputs["images"] = images.numpy()
+
+                list_keys = [
+                    "image_features",
+                    "video_features",
+                    "audio_features",
+                ]
+                for key in list_keys:
+                    value = task.multimodal_inputs.get(key, None)
+                    if value is None:
+                        continue
+                    if isinstance(value, list):
+                        task.multimodal_inputs[key] = [v.numpy() for v in value]
        except Exception as e:
            llm_logger.warning(f"Failed to convert to numpy: {e}")

@@ -565,7 +592,7 @@ class EngineWorkerQueue:
                 tasks.extend(self.tasks)

                 # Convert multimodal inputs to numpy
-                # EngineWorkerQueue.to_numpy(tasks)
+                EngineWorkerQueue.to_numpy(tasks)

                 self.client_read_flag[self.client_id] = 1
                 all_client_read: bool = np.sum(self.client_read_flag) == self.num_client
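The queue now converts list-valued feature entries element-wise: `to_tensor` maps each ndarray in `image_features`/`video_features`/`audio_features` to a `paddle.Tensor`, and `to_numpy` maps them back before the tasks cross the process boundary. A small round-trip sketch of that convention, assuming paddle is available; `round_trip` is a hypothetical helper, not FastDeploy API.

    # feature_round_trip.py -- illustrative sketch of the list-valued conversion
    import numpy as np
    import paddle

    def round_trip(features):
        tensors = [paddle.to_tensor(v) for v in features]  # engine -> worker (to_tensor)
        return [t.numpy() for t in tensors]                # worker -> engine (to_numpy)

    feats = [np.random.randn(2, 8).astype("float32") for _ in range(2)]
    back = round_trip(feats)
    assert all(np.allclose(a, b) for a, b in zip(feats, back))
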
+ """ + if not isinstance(bos_links, list): + bos_links = [bos_links] + + for link in bos_links: + try: + if link.startswith("bos://"): + link = link.replace("bos://", "") + + bucket_name = "/".join(link.split("/")[1:-1]) + object_key = link.split("/")[-1] + response = bos_client.get_object_as_string(bucket_name, object_key) + yield True, pickle.loads(response) + except Exception as e: + yield False, f"link {link} download error: {str(e)}" + break + + llm_logger = get_logger("fastdeploy", "fastdeploy.log") data_processor_logger = get_logger("data_processor", "data_processor.log") scheduler_logger = get_logger("scheduler", "scheduler.log") diff --git a/tests/inter_communicator/test_e2w_queue.py b/tests/inter_communicator/test_e2w_queue.py index 7e497e331..f24d95286 100644 --- a/tests/inter_communicator/test_e2w_queue.py +++ b/tests/inter_communicator/test_e2w_queue.py @@ -20,6 +20,7 @@ import numpy as np import paddle from fastdeploy import envs +from fastdeploy.engine.request import Request from fastdeploy.inter_communicator.engine_worker_queue import EngineWorkerQueue @@ -80,7 +81,7 @@ class TestEngineWorkerQueue(unittest.TestCase): # 构造 paddle.Tensor 输入 tensor_images = paddle.randn([2, 3, 224, 224]) task = DummyTask(tensor_images) - tasks = [([task], 1)] + tasks = ([task], 1) EngineWorkerQueue.to_numpy(tasks) @@ -94,7 +95,7 @@ class TestEngineWorkerQueue(unittest.TestCase): tensor_images = paddle.randn([2, 3, 224, 224]) # 创建模拟任务 task = DummyTask(tensor_images) - tasks = [([task], 1)] + tasks = ([task], 1) # 调用转换方法(预期不会转换) EngineWorkerQueue.to_numpy(tasks) @@ -107,7 +108,7 @@ class TestEngineWorkerQueue(unittest.TestCase): pass task = NoMMTask() - tasks = [([task], 1)] + tasks = ([task], 1) # 不应抛异常 try: @@ -119,7 +120,7 @@ class TestEngineWorkerQueue(unittest.TestCase): envs.FD_ENABLE_MAX_PREFILL = 1 np_images = np.random.randn(2, 3, 224, 224) task = DummyTask(np_images) - tasks = [([task], 1)] + tasks = ([task], 1) EngineWorkerQueue.to_numpy(tasks) @@ -135,13 +136,49 @@ class TestEngineWorkerQueue(unittest.TestCase): raise RuntimeError("mock error") bad_task = DummyTask(images=BadTensor()) - bad_tasks = [([bad_task], 1)] + bad_tasks = ([bad_task], 1) try: EngineWorkerQueue.to_numpy(bad_tasks) except Exception as e: self.fail(f"Exception should be handled internally, but got: {e}") + def test_features_info_to_tensor(self): + envs.FD_ENABLE_MAX_PREFILL = 1 + np_feature = paddle.randn([2, 3, 224, 224]).numpy() + multimodal_inputs = { + "image_features": [np_feature, np_feature], + } + req_dict = { + "request_id": "req1", + "multimodal_inputs": multimodal_inputs, + } + task = Request.from_dict(req_dict) + EngineWorkerQueue.to_tensor(([task], 1)) + + # 验证已转换为tensor + self.assertEqual(len(task.multimodal_inputs["image_features"]), 2) + self.assertIsInstance(task.multimodal_inputs["image_features"][0], paddle.Tensor) + self.assertIsInstance(task.multimodal_inputs["image_features"][1], paddle.Tensor) + + def test_features_info_to_numpy(self): + envs.FD_ENABLE_MAX_PREFILL = 1 + tensor_feature = paddle.randn([2, 3, 224, 224]) + multimodal_inputs = { + "video_features": [tensor_feature, tensor_feature], + } + req_dict = { + "request_id": "req1", + "multimodal_inputs": multimodal_inputs, + } + task = Request.from_dict(req_dict) + EngineWorkerQueue.to_numpy(([task], 1)) + + # 验证已转换为ndarray + self.assertEqual(len(task.multimodal_inputs["video_features"]), 2) + self.assertIsInstance(task.multimodal_inputs["video_features"][0], np.ndarray) + 
diff --git a/tests/inter_communicator/test_e2w_queue.py b/tests/inter_communicator/test_e2w_queue.py
index 7e497e331..f24d95286 100644
--- a/tests/inter_communicator/test_e2w_queue.py
+++ b/tests/inter_communicator/test_e2w_queue.py
@@ -20,6 +20,7 @@ import numpy as np
 import paddle

 from fastdeploy import envs
+from fastdeploy.engine.request import Request
 from fastdeploy.inter_communicator.engine_worker_queue import EngineWorkerQueue


@@ -80,7 +81,7 @@ class TestEngineWorkerQueue(unittest.TestCase):
         # Build a paddle.Tensor input
         tensor_images = paddle.randn([2, 3, 224, 224])
         task = DummyTask(tensor_images)
-        tasks = [([task], 1)]
+        tasks = ([task], 1)

         EngineWorkerQueue.to_numpy(tasks)

@@ -94,7 +95,7 @@ class TestEngineWorkerQueue(unittest.TestCase):
         tensor_images = paddle.randn([2, 3, 224, 224])
         # Create a mock task
         task = DummyTask(tensor_images)
-        tasks = [([task], 1)]
+        tasks = ([task], 1)

         # Call the conversion method (no conversion expected)
         EngineWorkerQueue.to_numpy(tasks)
@@ -107,7 +108,7 @@ class TestEngineWorkerQueue(unittest.TestCase):
             pass

         task = NoMMTask()
-        tasks = [([task], 1)]
+        tasks = ([task], 1)

         # Should not raise an exception
         try:
@@ -119,7 +120,7 @@ class TestEngineWorkerQueue(unittest.TestCase):
         envs.FD_ENABLE_MAX_PREFILL = 1
         np_images = np.random.randn(2, 3, 224, 224)
         task = DummyTask(np_images)
-        tasks = [([task], 1)]
+        tasks = ([task], 1)

         EngineWorkerQueue.to_numpy(tasks)

@@ -135,13 +136,49 @@ class TestEngineWorkerQueue(unittest.TestCase):
                 raise RuntimeError("mock error")

         bad_task = DummyTask(images=BadTensor())
-        bad_tasks = [([bad_task], 1)]
+        bad_tasks = ([bad_task], 1)

         try:
             EngineWorkerQueue.to_numpy(bad_tasks)
         except Exception as e:
             self.fail(f"Exception should be handled internally, but got: {e}")

+    def test_features_info_to_tensor(self):
+        envs.FD_ENABLE_MAX_PREFILL = 1
+        np_feature = paddle.randn([2, 3, 224, 224]).numpy()
+        multimodal_inputs = {
+            "image_features": [np_feature, np_feature],
+        }
+        req_dict = {
+            "request_id": "req1",
+            "multimodal_inputs": multimodal_inputs,
+        }
+        task = Request.from_dict(req_dict)
+        EngineWorkerQueue.to_tensor(([task], 1))
+
+        # Verify the features were converted to tensors
+        self.assertEqual(len(task.multimodal_inputs["image_features"]), 2)
+        self.assertIsInstance(task.multimodal_inputs["image_features"][0], paddle.Tensor)
+        self.assertIsInstance(task.multimodal_inputs["image_features"][1], paddle.Tensor)
+
+    def test_features_info_to_numpy(self):
+        envs.FD_ENABLE_MAX_PREFILL = 1
+        tensor_feature = paddle.randn([2, 3, 224, 224])
+        multimodal_inputs = {
+            "video_features": [tensor_feature, tensor_feature],
+        }
+        req_dict = {
+            "request_id": "req1",
+            "multimodal_inputs": multimodal_inputs,
+        }
+        task = Request.from_dict(req_dict)
+        EngineWorkerQueue.to_numpy(([task], 1))
+
+        # Verify the features were converted to ndarrays
+        self.assertEqual(len(task.multimodal_inputs["video_features"]), 2)
+        self.assertIsInstance(task.multimodal_inputs["video_features"][0], np.ndarray)
+        self.assertIsInstance(task.multimodal_inputs["video_features"][1], np.ndarray)
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/v1/test_resource_manager_v1.py b/tests/v1/test_resource_manager_v1.py
new file mode 100644
index 000000000..4534fb60d
--- /dev/null
+++ b/tests/v1/test_resource_manager_v1.py
@@ -0,0 +1,162 @@
+import concurrent.futures
+import pickle
+import unittest
+from dataclasses import asdict
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+
+from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig, SchedulerConfig
+from fastdeploy.engine.args_utils import EngineArgs
+from fastdeploy.engine.request import Request
+from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
+
+
+class TestResourceManagerV1(unittest.TestCase):
+    def setUp(self):
+        max_num_seqs = 2
+        engine_args = EngineArgs(
+            max_num_seqs=max_num_seqs,
+            num_gpu_blocks_override=102,
+            max_num_batched_tokens=3200,
+            enable_async_download_features=True,
+        )
+        args = asdict(engine_args)
+
+        cache_cfg = CacheConfig(args)
+        model_cfg = SimpleNamespace(enable_mm=True)  # Enable multimodal for feature testing
+        speculative_cfg = SimpleNamespace(method=None)
+        model_cfg.print = print
+        model_cfg.max_model_len = 5120
+        cache_cfg.bytes_per_layer_per_block = 1
+        parallel_cfg = ParallelConfig(args)
+        scheduler_cfg = SchedulerConfig(args)
+        graph_opt_cfg = engine_args.create_graph_optimization_config()
+
+        fd_config = FDConfig(
+            model_config=model_cfg,
+            cache_config=cache_cfg,
+            parallel_config=parallel_cfg,
+            graph_opt_config=graph_opt_cfg,
+            speculative_config=speculative_cfg,
+            scheduler_config=scheduler_cfg,
+        )
+        self.manager = ResourceManagerV1(
+            max_num_seqs=max_num_seqs, config=fd_config, tensor_parallel_size=8, splitwise_role="mixed"
+        )
+        req_dict = {
+            "request_id": "test_request",
+            "multimodal_inputs": {},
+        }
+        self.request = Request.from_dict(req_dict)
+        self.request.async_process_futures = []
+        self.request.multimodal_inputs = {}
+
+    def test_waiting_async_process_no_futures(self):
+        """Test when there are no async process futures"""
+        result = self.manager._waiting_async_process(self.request)
+        self.assertFalse(result)
+
+    def test_waiting_async_process_future_done_no_error(self):
+        """Test when future is done with no error"""
+        future = concurrent.futures.Future()
+        future.set_result(True)
+        self.request.async_process_futures = [future]
+
+        result = self.manager._waiting_async_process(self.request)
+        self.assertFalse(result)
+        self.assertEqual(len(self.request.async_process_futures), 0)
+
+    def test_waiting_async_process_future_done_with_error(self):
+        """Test when future is done with error"""
+        future = concurrent.futures.Future()
+        future.set_result(True)
+        self.request.async_process_futures = [future]
+        self.request.error_message = "Download failed"
+
+        result = self.manager._waiting_async_process(self.request)
+        self.assertIsNone(result)
+
+    def test_waiting_async_process_future_not_done(self):
+        """Test when future is not done"""
+        future = concurrent.futures.Future()
+        self.request.async_process_futures = [future]
+
+        result = self.manager._waiting_async_process(self.request)
+        self.assertTrue(result)
+        self.assertEqual(len(self.request.async_process_futures), 1)
+
+    def test_apply_async_preprocess(self):
+        """Test applying async preprocess"""
+        with patch.object(self.manager.async_preprocess_pool, "submit") as mock_submit:
+            mock_submit.return_value = "mock_future"
+
+            self.manager._apply_async_preprocess(self.request)
+
+            mock_submit.assert_called_once_with(self.manager._download_features, self.request)
+            self.assertEqual(len(self.request.async_process_futures), 1)
+            self.assertEqual(self.request.async_process_futures[0], "mock_future")
+
+    @patch("fastdeploy.engine.sched.resource_manager_v1.init_bos_client")
+    @patch("fastdeploy.engine.sched.resource_manager_v1.download_from_bos")
+    def test_download_features_no_features(self, mock_download, mock_init):
+        """Test when no features to download"""
+        self.request.multimodal_inputs = {}
+        result = self.manager._download_features(self.request)
+        self.assertIsNone(result)
+        mock_download.assert_not_called()
+        mock_init.assert_not_called()
+
+    def test_download_features_video_success(self):
+        """Test successful video feature download"""
+        mock_client = MagicMock()
+        mock_client.get_object_as_string.return_value = pickle.dumps(np.array([[1, 2, 3]], dtype=np.float32))
+
+        self.request.multimodal_inputs = {"video_feature_urls": ["bos://bucket-name/path/to/object1"]}
+
+        self.manager.bos_client = mock_client
+        result = self.manager._download_features(self.request)
+        self.assertIsNone(result)
+        self.assertIn("video_features", self.request.multimodal_inputs)
+        self.assertIsInstance(self.request.multimodal_inputs["video_features"][0], np.ndarray)
+
+    def test_download_features_image_error(self):
+        """Test image feature download with error"""
+        mock_client = MagicMock()
+        mock_client.get_object_as_string.side_effect = Exception("network error")
+
+        self.request.multimodal_inputs = {"image_feature_urls": ["bos://bucket-name/path/to/object1"]}
+
+        self.manager.bos_client = mock_client
+        result = self.manager._download_features(self.request)
+        self.assertIsNone(result)
+        self.assertEqual(
+            self.request.error_message,
+            "request test_request download features error: link bucket-name/path/to/object1 download error: network error",
+        )
+        self.assertEqual(self.request.error_code, 530)
+
+    def test_download_features_audio_mixed(self):
+        """Test mixed success/error in audio feature download"""
+        mock_client = MagicMock()
+        mock_client.get_object_as_string.side_effect = [
+            pickle.dumps(np.array([[1, 2, 3]], dtype=np.float32)),
+            Exception("timeout"),
+        ]
+
+        self.request.multimodal_inputs = {
+            "audio_feature_urls": ["bos://bucket-name/path/to/object1", "bos://bucket-name/path/to/object2"]
+        }
+
+        self.manager.bos_client = mock_client
+        result = self.manager._download_features(self.request)
+        self.assertIsNone(result)
+        self.assertEqual(
+            self.request.error_message,
+            "request test_request download features error: link bucket-name/path/to/object2 download error: timeout",
+        )
+        self.assertEqual(self.request.error_code, 530)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/v1/test_schedule_output.py b/tests/v1/test_schedule_output.py
index d6002e73b..1fccb8979 100644
--- a/tests/v1/test_schedule_output.py
+++ b/tests/v1/test_schedule_output.py
@@ -55,7 +55,7 @@ def test_normal_schedule():
     resource_manager_v1.add_request(req3)
     # step 1
     assert len(resource_manager_v1.waiting) == 3
-    scheduler_reqs = resource_manager_v1.schedule()
+    scheduler_reqs, _ = resource_manager_v1.schedule()
     assert len(scheduler_reqs) == 2
     assert scheduler_reqs[0].request_id == "req1"
     assert scheduler_reqs[1].request_id == "req2"
@@ -66,7 +66,7 @@ def test_normal_schedule():
     assert len(resource_manager_v1.running) == 2
     assert len(resource_manager_v1.waiting) == 1
     # step 2
-    scheduler_reqs = resource_manager_v1.schedule()
+    scheduler_reqs, _ = resource_manager_v1.schedule()
     assert len(scheduler_reqs) == 2
     assert scheduler_reqs[0].request_id == "req1"
     assert len(scheduler_reqs[0].block_tables) == 52
@@ -76,7 +76,7 @@ def test_normal_schedule():
     assert len(resource_manager_v1.running) == 2
     assert len(resource_manager_v1.waiting) == 1
     # step 3
-    scheduler_reqs = resource_manager_v1.schedule()
+    scheduler_reqs, _ = resource_manager_v1.schedule()
     assert len(scheduler_reqs) == 2
     assert scheduler_reqs[0].request_id == "req2"
     assert scheduler_reqs[0].prefill_start_index == 3200
@@ -118,7 +118,7 @@ def test_preempted_request():
     resource_manager_v1.add_request(req2)
     # step 1
     assert len(resource_manager_v1.waiting) == 2
-    scheduler_reqs = resource_manager_v1.schedule()
+    scheduler_reqs, _ = resource_manager_v1.schedule()
     assert len(scheduler_reqs) == 1
     assert scheduler_reqs[0].request_id == "req1"
     assert scheduler_reqs[0].prefill_start_index == 0
@@ -126,13 +126,13 @@ def test_preempted_request():
     assert len(resource_manager_v1.running) == 1
     assert len(resource_manager_v1.waiting) == 1
     # step 2
-    scheduler_reqs = resource_manager_v1.schedule()
+    scheduler_reqs, _ = resource_manager_v1.schedule()
     assert len(scheduler_reqs) == 2
     assert scheduler_reqs[0].request_id == "req1"
     assert len(scheduler_reqs[0].block_tables) == 52
     # step 3
     req1.output_token_ids.extend([1] * 128)
-    scheduler_reqs = resource_manager_v1.schedule()
+    scheduler_reqs, _ = resource_manager_v1.schedule()
     assert len(scheduler_reqs) == 2
     assert scheduler_reqs[0].request_id == "req2"
     assert len(resource_manager_v1.running) == 1
@@ -142,7 +142,7 @@ def test_preempted_request():
     # mock token_processor to add into waiting
     resource_manager_v1.waiting.appendleft(req2)
     # step 4
-    scheduler_reqs = resource_manager_v1.schedule()
+    scheduler_reqs, _ = resource_manager_v1.schedule()
     assert len(scheduler_reqs) == 0
     assert len(resource_manager_v1.running) == 1
     assert len(resource_manager_v1.waiting) == 1
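A hedged enablement sketch, assuming this patch is applied: the constructor arguments below mirror the test setUp and are illustrative, not canonical configuration.

    # enable_async_download.py -- illustrative only
    from fastdeploy.engine.args_utils import EngineArgs

    engine_args = EngineArgs(
        max_num_seqs=2,
        max_num_batched_tokens=3200,
        enable_async_download_features=True,  # new flag introduced by this patch
    )
    # Command-line equivalent (added to the parallel argument group):
    #   --enable-async-download-features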