[Feature] support async download features (#5003)

* support async download features

* add test case

* update code
Authored by kevin on 2025-11-19 22:23:36 +08:00, committed by GitHub
parent bde97e09f7, commit 109d48e456
10 changed files with 433 additions and 75 deletions

View File

@@ -550,6 +550,8 @@ class ParallelConfig:
self.use_internode_ll_two_stage: bool = False
# disable sequence parallel moe
self.disable_sequence_parallel_moe: bool = False
# enable async download features
self.enable_async_download_features: bool = False
self.pod_ip: str = None
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
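For orientation, a minimal sketch (helper name hypothetical) of how downstream code can gate on the new flag, mirroring the check added in ResourceManagerV1._download_features further below:

def async_download_enabled(config) -> bool:
    # Sketch only: the flag defaults to False, so async feature
    # download stays off unless explicitly enabled.
    return bool(config.parallel_config.enable_async_download_features)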

View File

@@ -467,6 +467,11 @@ class EngineArgs:
URL for the router server, such as `0.0.0.0:30000`.
"""
enable_async_download_features: bool = False
"""
Flag to enable async download features. Default is False (disabled).
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -849,6 +854,12 @@ class EngineArgs:
default=EngineArgs.enable_expert_parallel,
help="Enable expert parallelism.",
)
parallel_group.add_argument(
"--enable-async-download-features",
action="store_true",
default=EngineArgs.enable_async_download_features,
help="Enable async download features.",
)
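As a hedged usage sketch (the model path is a placeholder), the new flag can be set either programmatically through EngineArgs or via the equivalent CLI switch:

from fastdeploy.engine.args_utils import EngineArgs

# Sketch only; the model path below is illustrative.
args = EngineArgs(
    model="./path/to/model",              # placeholder
    enable_async_download_features=True,  # same effect as --enable-async-download-features
)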
# Load group
load_group = parser.add_argument_group("Load Configuration")

View File

@@ -51,14 +51,7 @@ from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.trace.constants import LoggingEventName
from fastdeploy.trace.trace_logger import print as trace_print
from fastdeploy.utils import (
EngineError,
check_download_links,
envs,
get_logger,
init_bos_client,
llm_logger,
)
from fastdeploy.utils import EngineError, envs, get_logger, llm_logger
try:
TokenProcessor = load_token_processor_plugins()
@@ -808,7 +801,7 @@ class EngineService:
else:
raise
# 2. Schedule requests
tasks = self.resource_manager.schedule()
tasks, error_tasks = self.resource_manager.schedule()
# 3. Send to engine
if tasks:
@@ -833,7 +826,16 @@ class EngineService:
trace_print(LoggingEventName.REQUEST_SCHEDULE_END, task.request_id, getattr(task, "user", ""))
trace_print(LoggingEventName.INFERENCE_START, task.request_id, getattr(task, "user", ""))
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
else:
# 4. Respond to error tasks
if error_tasks:
for request_id, failed in error_tasks:
if failed is None:
llm_logger.warning(f"Request {request_id} has no error, skip sending error response.")
continue
self._send_error_response(request_id, failed)
if not tasks and not error_tasks:
time.sleep(0.005)
except RuntimeError as e:
@@ -909,24 +911,6 @@ class EngineService:
self.llm_logger.error(f"Receive request error: {err_msg}")
results.append((request.request_id, err_msg))
if self._has_features_info(request) and err_msg is None:
if self.bos_client is None:
self.bos_client = init_bos_client()
download_urls = []
inputs = request.multimodal_inputs
if inputs.get("video_feature_urls") is not None:
download_urls.extend(inputs.get("video_feature_urls"))
if inputs.get("image_feature_urls") is not None:
download_urls.extend(inputs.get("image_feature_urls"))
if inputs.get("audio_feature_urls") is not None:
download_urls.extend(inputs.get("audio_feature_urls"))
err_msg = check_download_links(self.bos_client, download_urls)
if err_msg:
llm_logger.error(f"Receive request {request.request_id} download error: {err_msg}")
results.append((request.request_id, err_msg))
if err_msg is None:
insert_task.append(request)
@@ -948,21 +932,27 @@ class EngineService:
main_process_metrics.num_requests_waiting.inc(1)
continue
error_result = RequestOutput(
request_id=request_id,
finished=True,
error_code=500,
error_msg=failed,
)
# Since the request is not in scheduler
# Send result by zmq directly
self.send_response_server.send_response(request_id, [error_result])
self._send_error_response(request_id, failed)
except Exception as e:
self.llm_logger.error(
f"Error happened while receiving new request from zmq, details={e}, "
f"traceback={traceback.format_exc()}"
)
def _send_error_response(self, request_id, error_msg, error_code: int = 500):
llm_logger.error(
f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
)
error_result = RequestOutput(
request_id=request_id,
finished=True,
error_code=error_code,
error_msg=error_msg,
)
# Since the request is not in the scheduler,
# send the result via ZMQ directly.
self.send_response_server.send_response(request_id, [error_result])
def _decode_token(self, token_ids, req_id, is_end):
delta_text = ""
if envs.FD_ENABLE_RETURN_TEXT:
@@ -977,19 +967,6 @@ class EngineService:
del self.data_processor.decode_status[req_id]
return delta_text, token_ids
def _has_features_info(self, task):
inputs = task.multimodal_inputs
if inputs is None or len(inputs) == 0:
return False
if (
(inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0)
or (inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0)
or (inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0)
):
return True
return False
def _zmq_send_generated_tokens(self):
"""
Receive output via ZMQ

View File

@@ -173,6 +173,10 @@ class Request:
# dp
self.dp_rank = dp_rank
self.async_process_futures = []
self.error_message = None
self.error_code = None
@classmethod
def from_dict(cls, d: dict):
data_processor_logger.debug(f"{d}")

View File

@@ -44,7 +44,7 @@ from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.multimodal.hasher import MultimodalHasher
from fastdeploy.platforms import current_platform
from fastdeploy.utils import llm_logger
from fastdeploy.utils import download_from_bos, init_bos_client, llm_logger
@dataclass
@@ -195,6 +195,9 @@ class ResourceManagerV1(ResourceManager):
max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024)
self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes)
self.bos_client = None
self.async_preprocess_pool = ThreadPoolExecutor(max_workers=4)
def allocated_slots(self, request: Request):
return len(request.block_tables) * self.config.cache_config.block_size
@@ -500,6 +503,7 @@ class ResourceManagerV1(ResourceManager):
with self.lock:
scheduled_reqs: list[Request] = []
preempted_reqs: list[Request] = []
error_reqs: list[tuple[str, str]] = []
token_budget = self.config.scheduler_config.max_num_batched_tokens
# First, schedule the RUNNING requests.
@@ -629,6 +633,7 @@ class ResourceManagerV1(ResourceManager):
req_index += 1
# schedule the WAITING requests.
if not preempted_reqs:
skip_requests: list[Request] = []
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_seqs:
break
@@ -639,6 +644,17 @@ class ResourceManagerV1(ResourceManager):
):
break
if request.status == RequestStatus.WAITING:
result = self._waiting_async_process(request)
if result is None:
error_reqs.append((request.request_id, request.error_message))
self.waiting.popleft()
continue
elif result is True:
# skip current request, try next request
skip_requests.append(request)
self.waiting.popleft()
continue
self._update_mm_hashes(request)
# Enable prefix caching
if self.config.cache_config.enable_prefix_caching:
@@ -725,12 +741,102 @@ class ResourceManagerV1(ResourceManager):
else:
llm_logger.error("Unknown request status type")
for req in skip_requests:
# Move skipped waiting requests to the end of the deque
self.waiting.append(req)
if scheduled_reqs:
llm_logger.debug(f"schedued_reqs: {scheduled_reqs}")
self.update_metrics()
return scheduled_reqs
return scheduled_reqs, error_reqs
def _waiting_async_process(self, request: Request) -> "bool | None":
"""
Check if async preprocessing is complete for a request.
Args:
request: The request to check
Returns:
None: If an error occurred during preprocessing
True: If preprocessing is still in progress (request should be skipped)
False: If preprocessing is complete (request can be scheduled)
"""
for future in request.async_process_futures:
if future.done():
if request.error_message is not None:
return None
else:
return True
request.async_process_futures = []
return False
def _apply_async_preprocess(self, request: Request) -> None:
request.async_process_futures.append(self.async_preprocess_pool.submit(self._download_features, request))
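A self-contained sketch of the submit-then-poll pattern these two methods implement (names are illustrative, not FastDeploy API):

from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=4)

def fake_download(store: dict) -> None:
    store["features"] = [0.1, 0.2]  # stand-in for the BOS download

store: dict = {}
futures = [pool.submit(fake_download, store)]  # what _apply_async_preprocess does
# Scheduler side (what _waiting_async_process does): a pending future means
# skip the request; all done with no error means drain futures and schedule it.
if all(f.done() for f in futures):
    futures.clear()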
def _has_features_info(self, task):
inputs = task.multimodal_inputs
if inputs is None or len(inputs) == 0:
return False
if (
(inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0)
or (inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0)
or (inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0)
):
return True
return False
def _download_features(self, request: Request) -> None:
"""
download multimodal features from bos
Note:
1. this function will be add features for request.multimodal_inputs
2. this function maybe update request.error_message and request.error_code
Args:
request (Request): request object
"""
def download_bos_features(bos_client, features_urls):
result_list = []
for status, feature in download_from_bos(bos_client, features_urls):
if status:
llm_logger.info(f"request {request.request_id} async download feature: {feature.shape}")
result_list.append(feature)
else:
error_msg = f"request {request.request_id} download features error: {feature}"
llm_logger.error(error_msg)
return error_msg
return result_list
if not self.config.parallel_config.enable_async_download_features or not self._has_features_info(request):
return None
if self.bos_client is None:
self.bos_client = init_bos_client()
inputs = request.multimodal_inputs
if inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0:
result = download_bos_features(self.bos_client, inputs["video_feature_urls"])
if isinstance(result, str): # download error
request.error_message = result
request.error_code = 530
return None
inputs["video_features"] = result
if inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0:
result = download_bos_features(self.bos_client, inputs["image_feature_urls"])
if isinstance(result, str): # download error
request.error_message = result
request.error_code = 530
return None
inputs["image_features"] = result
if inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0:
result = download_bos_features(self.bos_client, inputs["audio_feature_urls"])
if isinstance(result, str): # download error
request.error_message = result
request.error_code = 530
return None
inputs["audio_features"] = result
def get_available_position(self) -> int:
position = 0
@@ -788,6 +894,7 @@ class ResourceManagerV1(ResourceManager):
def add_request(self, request: Request) -> None:
with self.lock:
self._apply_async_preprocess(request)
self.waiting.append(request)
self.requests[request.request_id] = request

View File

@@ -499,7 +499,13 @@ class EngineWorkerQueue:
"attention_mask_offset",
]
llm_logger.debug(f"Converting multimodal inputs to tensor...{tensor_keys}")
list_keys = [
"image_features",
"video_features",
"audio_features",
]
llm_logger.debug(f"Converting multimodal inputs to tensor...{tensor_keys + list_keys}")
for key in tensor_keys:
value = multimodal_inputs.get(key)
@@ -507,6 +513,13 @@ class EngineWorkerQueue:
continue
if not isinstance(value, paddle.Tensor):
multimodal_inputs[key] = paddle.to_tensor(value)
for key in list_keys:
value = multimodal_inputs.get(key)
if value is None:
continue
if isinstance(value, list):
multimodal_inputs[key] = [paddle.to_tensor(v) for v in value]
except Exception as e:
llm_logger.warning(f"Tensor conversion failed: {type(e).__name__}: {e}")
@@ -518,16 +531,30 @@ class EngineWorkerQueue:
Args:
tasks: List of tasks containing multimodal inputs.
"""
if (not envs.FD_ENABLE_MAX_PREFILL) and (not envs.FD_ENABLE_E2W_TENSOR_CONVERT):
return
try:
if envs.FD_ENABLE_MAX_PREFILL:
for batch_tasks, _ in tasks:
for task in batch_tasks:
if not hasattr(task, "multimodal_inputs"):
continue
images = task.multimodal_inputs.get("images", None)
if isinstance(images, paddle.Tensor):
llm_logger.debug(f"Convert image to numpy, shape: {images.shape}")
task.multimodal_inputs["images"] = images.numpy()
batch_tasks, _ = tasks
for task in batch_tasks:
if not hasattr(task, "multimodal_inputs"):
continue
images = task.multimodal_inputs.get("images", None)
if isinstance(images, paddle.Tensor):
llm_logger.debug(f"Convert image to numpy, shape: {images.shape}")
task.multimodal_inputs["images"] = images.numpy()
list_keys = [
"image_features",
"video_features",
"audio_features",
]
for key in list_keys:
value = task.multimodal_inputs.get(key, None)
if value is None:
continue
if isinstance(value, list):
task.multimodal_inputs[key] = [v.numpy() for v in value]
except Exception as e:
llm_logger.warning(f"Failed to convert to numpy: {e}")
@@ -565,7 +592,7 @@ class EngineWorkerQueue:
tasks.extend(self.tasks)
# Convert multimodal inputs to numpy
# EngineWorkerQueue.to_numpy(tasks)
EngineWorkerQueue.to_numpy(tasks)
self.client_read_flag[self.client_id] = 1
all_client_read: bool = np.sum(self.client_read_flag) == self.num_client

View File

@@ -21,6 +21,7 @@ import importlib
import json
import logging
import os
import pickle
import random
import re
import socket
@@ -975,6 +976,36 @@ def init_bos_client():
return BosClient(cfg)
def download_from_bos(bos_client, bos_links):
"""
Download pickled objects from Baidu Object Storage (BOS).
Args:
bos_client: BOS client instance
bos_links: Single link or list of BOS links in format "bos://bucket-name/path/to/object"
Yields:
tuple: (success: bool, data: np.ndarray | error_msg: str)
- On success: (True, deserialized_data)
- On failure: (False, error_message) and stops processing remaining links
Security Note:
Uses pickle deserialization. Only use with trusted data sources.
"""
if not isinstance(bos_links, list):
bos_links = [bos_links]
for link in bos_links:
try:
if link.startswith("bos://"):
link = link.replace("bos://", "")
bucket_name = "/".join(link.split("/")[1:-1])
object_key = link.split("/")[-1]
response = bos_client.get_object_as_string(bucket_name, object_key)
yield True, pickle.loads(response)
except Exception as e:
yield False, f"link {link} download error: {str(e)}"
break
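A hedged consumer sketch for this generator (the BOS link below is made up):

from fastdeploy.utils import download_from_bos, init_bos_client

client = init_bos_client()
features = []
for ok, payload in download_from_bos(client, ["bos://my-bucket/features/part0.pkl"]):
    if ok:
        features.append(payload)  # deserialized object, e.g. an np.ndarray
    else:
        print(payload)  # error string; the generator stops after yielding this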
llm_logger = get_logger("fastdeploy", "fastdeploy.log")
data_processor_logger = get_logger("data_processor", "data_processor.log")
scheduler_logger = get_logger("scheduler", "scheduler.log")

View File

@@ -20,6 +20,7 @@ import numpy as np
import paddle
from fastdeploy import envs
from fastdeploy.engine.request import Request
from fastdeploy.inter_communicator.engine_worker_queue import EngineWorkerQueue
@@ -80,7 +81,7 @@ class TestEngineWorkerQueue(unittest.TestCase):
# Construct paddle.Tensor input
tensor_images = paddle.randn([2, 3, 224, 224])
task = DummyTask(tensor_images)
tasks = [([task], 1)]
tasks = ([task], 1)
EngineWorkerQueue.to_numpy(tasks)
@@ -94,7 +95,7 @@ class TestEngineWorkerQueue(unittest.TestCase):
tensor_images = paddle.randn([2, 3, 224, 224])
# Create a mock task
task = DummyTask(tensor_images)
tasks = [([task], 1)]
tasks = ([task], 1)
# Call the conversion method (no conversion expected)
EngineWorkerQueue.to_numpy(tasks)
@@ -107,7 +108,7 @@ class TestEngineWorkerQueue(unittest.TestCase):
pass
task = NoMMTask()
tasks = [([task], 1)]
tasks = ([task], 1)
# Should not raise an exception
try:
@@ -119,7 +120,7 @@ class TestEngineWorkerQueue(unittest.TestCase):
envs.FD_ENABLE_MAX_PREFILL = 1
np_images = np.random.randn(2, 3, 224, 224)
task = DummyTask(np_images)
tasks = [([task], 1)]
tasks = ([task], 1)
EngineWorkerQueue.to_numpy(tasks)
@@ -135,13 +136,49 @@ class TestEngineWorkerQueue(unittest.TestCase):
raise RuntimeError("mock error")
bad_task = DummyTask(images=BadTensor())
bad_tasks = [([bad_task], 1)]
bad_tasks = ([bad_task], 1)
try:
EngineWorkerQueue.to_numpy(bad_tasks)
except Exception as e:
self.fail(f"Exception should be handled internally, but got: {e}")
def test_features_info_to_tensor(self):
envs.FD_ENABLE_MAX_PREFILL = 1
np_feature = paddle.randn([2, 3, 224, 224]).numpy()
multimodal_inputs = {
"image_features": [np_feature, np_feature],
}
req_dict = {
"request_id": "req1",
"multimodal_inputs": multimodal_inputs,
}
task = Request.from_dict(req_dict)
EngineWorkerQueue.to_tensor(([task], 1))
# Verify the features were converted to tensors
self.assertEqual(len(task.multimodal_inputs["image_features"]), 2)
self.assertIsInstance(task.multimodal_inputs["image_features"][0], paddle.Tensor)
self.assertIsInstance(task.multimodal_inputs["image_features"][1], paddle.Tensor)
def test_features_info_to_numpy(self):
envs.FD_ENABLE_MAX_PREFILL = 1
tensor_feature = paddle.randn([2, 3, 224, 224])
multimodal_inputs = {
"video_features": [tensor_feature, tensor_feature],
}
req_dict = {
"request_id": "req1",
"multimodal_inputs": multimodal_inputs,
}
task = Request.from_dict(req_dict)
EngineWorkerQueue.to_numpy(([task], 1))
# Verify the features were converted to ndarrays
self.assertEqual(len(task.multimodal_inputs["video_features"]), 2)
self.assertIsInstance(task.multimodal_inputs["video_features"][0], np.ndarray)
self.assertIsInstance(task.multimodal_inputs["video_features"][1], np.ndarray)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,162 @@
import concurrent.futures
import pickle
import unittest
from dataclasses import asdict
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import numpy as np
from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig, SchedulerConfig
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.request import Request
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
class TestResourceManagerV1(unittest.TestCase):
def setUp(self):
max_num_seqs = 2
engine_args = EngineArgs(
max_num_seqs=max_num_seqs,
num_gpu_blocks_override=102,
max_num_batched_tokens=3200,
enable_async_download_features=True,
)
args = asdict(engine_args)
cache_cfg = CacheConfig(args)
model_cfg = SimpleNamespace(enable_mm=True) # Enable multimodal for feature testing
speculative_cfg = SimpleNamespace(method=None)
model_cfg.print = print
model_cfg.max_model_len = 5120
cache_cfg.bytes_per_layer_per_block = 1
parallel_cfg = ParallelConfig(args)
scheduler_cfg = SchedulerConfig(args)
graph_opt_cfg = engine_args.create_graph_optimization_config()
fd_config = FDConfig(
model_config=model_cfg,
cache_config=cache_cfg,
parallel_config=parallel_cfg,
graph_opt_config=graph_opt_cfg,
speculative_config=speculative_cfg,
scheduler_config=scheduler_cfg,
)
self.manager = ResourceManagerV1(
max_num_seqs=max_num_seqs, config=fd_config, tensor_parallel_size=8, splitwise_role="mixed"
)
req_dict = {
"request_id": "test_request",
"multimodal_inputs": {},
}
self.request = Request.from_dict(req_dict)
self.request.async_process_futures = []
self.request.multimodal_inputs = {}
def test_waiting_async_process_no_futures(self):
"""Test when there are no async process futures"""
result = self.manager._waiting_async_process(self.request)
self.assertFalse(result)
def test_waiting_async_process_future_done_no_error(self):
"""Test when future is done with no error"""
future = concurrent.futures.Future()
future.set_result(True)
self.request.async_process_futures = [future]
result = self.manager._waiting_async_process(self.request)
self.assertFalse(result)
self.assertEqual(len(self.request.async_process_futures), 0)
def test_waiting_async_process_future_done_with_error(self):
"""Test when future is done with error"""
future = concurrent.futures.Future()
future.set_result(True)
self.request.async_process_futures = [future]
self.request.error_message = "Download failed"
result = self.manager._waiting_async_process(self.request)
self.assertIsNone(result)
def test_waiting_async_process_future_not_done(self):
"""Test when future is not done"""
future = concurrent.futures.Future()
self.request.async_process_futures = [future]
result = self.manager._waiting_async_process(self.request)
self.assertTrue(result)
self.assertEqual(len(self.request.async_process_futures), 1)
def test_apply_async_preprocess(self):
"""Test applying async preprocess"""
with patch.object(self.manager.async_preprocess_pool, "submit") as mock_submit:
mock_submit.return_value = "mock_future"
self.manager._apply_async_preprocess(self.request)
mock_submit.assert_called_once_with(self.manager._download_features, self.request)
self.assertEqual(len(self.request.async_process_futures), 1)
self.assertEqual(self.request.async_process_futures[0], "mock_future")
@patch("fastdeploy.utils.init_bos_client")
@patch("fastdeploy.utils.download_from_bos")
def test_download_features_no_features(self, mock_download, mock_init):
"""Test when no features to download"""
self.request.multimodal_inputs = {}
result = self.manager._download_features(self.request)
self.assertIsNone(result)
mock_download.assert_not_called()
mock_init.assert_not_called()
def test_download_features_video_success(self):
"""Test successful video feature download"""
mock_client = MagicMock()
mock_client.get_object_as_string.return_value = pickle.dumps(np.array([[1, 2, 3]], dtype=np.float32))
self.request.multimodal_inputs = {"video_feature_urls": ["bos://bucket-name/path/to/object1"]}
self.manager.bos_client = mock_client
result = self.manager._download_features(self.request)
self.assertIsNone(result)
self.assertIn("video_features", self.request.multimodal_inputs)
self.assertIsInstance(self.request.multimodal_inputs["video_features"][0], np.ndarray)
def test_download_features_image_error(self):
"""Test image feature download with error"""
mock_client = MagicMock()
mock_client.get_object_as_string.side_effect = Exception("network error")
self.request.multimodal_inputs = {"image_feature_urls": ["bos://bucket-name/path/to/object1"]}
self.manager.bos_client = mock_client
result = self.manager._download_features(self.request)
self.assertIsNone(result)
self.assertEqual(
self.request.error_message,
"request test_request download features error: link bucket-name/path/to/object1 download error: network error",
)
self.assertEqual(self.request.error_code, 530)
def test_download_features_audio_mixed(self):
"""Test mixed success/error in audio feature download"""
mock_client = MagicMock()
mock_client.get_object_as_string.side_effect = [
pickle.dumps(np.array([[1, 2, 3]], dtype=np.float32)),
Exception("timeout"),
]
self.request.multimodal_inputs = {
"audio_feature_urls": ["bos://bucket-name/path/to/object1", "bos://bucket-name/path/to/object2"]
}
self.manager.bos_client = mock_client
result = self.manager._download_features(self.request)
self.assertIsNone(result)
self.assertEqual(
self.request.error_message,
"request test_request download features error: link bucket-name/path/to/object2 download error: timeout",
)
self.assertEqual(self.request.error_code, 530)
if __name__ == "__main__":
unittest.main()

View File

@@ -55,7 +55,7 @@ def test_normal_schedule():
resource_manager_v1.add_request(req3)
# step 1
assert len(resource_manager_v1.waiting) == 3
scheduler_reqs = resource_manager_v1.schedule()
scheduler_reqs, _ = resource_manager_v1.schedule()
assert len(scheduler_reqs) == 2
assert scheduler_reqs[0].request_id == "req1"
assert scheduler_reqs[1].request_id == "req2"
@@ -66,7 +66,7 @@ def test_normal_schedule():
assert len(resource_manager_v1.running) == 2
assert len(resource_manager_v1.waiting) == 1
# step 2
scheduler_reqs = resource_manager_v1.schedule()
scheduler_reqs, _ = resource_manager_v1.schedule()
assert len(scheduler_reqs) == 2
assert scheduler_reqs[0].request_id == "req1"
assert len(scheduler_reqs[0].block_tables) == 52
@@ -76,7 +76,7 @@ def test_normal_schedule():
assert len(resource_manager_v1.running) == 2
assert len(resource_manager_v1.waiting) == 1
# step 3
scheduler_reqs = resource_manager_v1.schedule()
scheduler_reqs, _ = resource_manager_v1.schedule()
assert len(scheduler_reqs) == 2
assert scheduler_reqs[0].request_id == "req2"
assert scheduler_reqs[0].prefill_start_index == 3200
@@ -118,7 +118,7 @@ def test_preempted_request():
resource_manager_v1.add_request(req2)
# step 1
assert len(resource_manager_v1.waiting) == 2
scheduler_reqs = resource_manager_v1.schedule()
scheduler_reqs, _ = resource_manager_v1.schedule()
assert len(scheduler_reqs) == 1
assert scheduler_reqs[0].request_id == "req1"
assert scheduler_reqs[0].prefill_start_index == 0
@@ -126,13 +126,13 @@ def test_preempted_request():
assert len(resource_manager_v1.running) == 1
assert len(resource_manager_v1.waiting) == 1
# step 2
scheduler_reqs = resource_manager_v1.schedule()
scheduler_reqs, _ = resource_manager_v1.schedule()
assert len(scheduler_reqs) == 2
assert scheduler_reqs[0].request_id == "req1"
assert len(scheduler_reqs[0].block_tables) == 52
# step 3
req1.output_token_ids.extend([1] * 128)
scheduler_reqs = resource_manager_v1.schedule()
scheduler_reqs, _ = resource_manager_v1.schedule()
assert len(scheduler_reqs) == 2
assert scheduler_reqs[0].request_id == "req2"
assert len(resource_manager_v1.running) == 1
@@ -142,7 +142,7 @@ def test_preempted_request():
# mock token_processor to add into waiting
resource_manager_v1.waiting.appendleft(req2)
# step 4
scheduler_reqs = resource_manager_v1.schedule()
scheduler_reqs, _ = resource_manager_v1.schedule()
assert len(scheduler_reqs) == 0
assert len(resource_manager_v1.running) == 1
assert len(resource_manager_v1.waiting) == 1