Files
FastDeploy/tests/scheduler/test_splitwise_scheduler.py
xunyoyo 9e8c46c526
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
[CI] 【Hackathon 9th Sprint No.34】NO.34 功能模块单测补充 (#5057)
* Add unit tests for SplitWiseScheduler module

* Add info and ping to fake redis client for tests

* Document fake redis metadata methods in tests

* Enhance splitwise scheduler tests

* Clean up test_splitwise_scheduler.py

Removed copyright notice and documentation comments.

* Simplify splitwise scheduler test stubs

* Refine splitwise scheduler tests

* Handle empty result keys with restored sleep

---------

Co-authored-by: CSWYF3634076 <wangyafeng@baidu.com>
2025-12-15 20:29:25 +08:00

1307 lines
48 KiB
Python

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import argparse
import importlib
import pickle
import random
import sys
import time
import types
import unittest
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
# ``parents[2]`` of this file: the directory that is expected to contain the
# ``fastdeploy`` package sources used by the stubs below.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Put it at the front of the import path once, without creating duplicates.
if sys.path.count(str(PROJECT_ROOT)) == 0:
    sys.path.insert(0, str(PROJECT_ROOT))
def _install_stub_modules() -> None:
    """Register in-memory replacements for the scheduler's external imports.

    Entries placed in ``sys.modules``:

    * ``redis`` (only when nothing is registered under that name yet), exposing
      a dictionary-backed ``Redis`` class;
    * ``fastdeploy.engine.request`` with minimal ``CompletionOutput``,
      ``RequestMetrics``, ``Request`` and ``RequestOutput`` classes;
    * bare ``fastdeploy`` and ``fastdeploy.scheduler`` packages whose
      ``__path__`` points below ``PROJECT_ROOT``;
    * ``fastdeploy.utils`` and ``fastdeploy.utils.scheduler_logger`` whose
      ``info``/``error``/``debug``/``warning`` callables do nothing.

    Only the first call does any work: a marker attribute stored on this
    function makes every later call return immediately.
    """
    if getattr(_install_stub_modules, "_installed", False):
        return

    # ------------------------------------------------------------ redis fakes
    class _FakePipeline:
        """Queues ``lpush``/``expire`` calls and replays them on ``execute``."""

        def __init__(self, client: "_FakeRedis") -> None:
            self._client = client
            # Each entry is (client method name, positional arguments).
            self._queued: list[tuple[str, tuple[Any, ...]]] = []

        def __enter__(self) -> "_FakePipeline":
            return self

        def __exit__(self, exc_type, exc, tb) -> None:  # type: ignore[override]
            # Nothing to release; returning None lets exceptions propagate.
            return None

        def multi(self) -> "_FakePipeline":
            return self

        def lpush(self, key: str, *values: Any) -> "_FakePipeline":
            self._queued.append(("lpush", (key, *values)))
            return self

        def expire(self, key: str, ttl: int) -> "_FakePipeline":
            self._queued.append(("expire", (key, ttl)))
            return self

        def execute(self) -> None:
            # Replay in submission order against the owning client, then reset.
            for method_name, call_args in self._queued:
                getattr(self._client, method_name)(*call_args)
            self._queued.clear()

    class _FakeRedis:
        """Dictionary-backed subset of the Redis client API."""

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            # Connection arguments are accepted and ignored.
            self.storage: dict[str, list[Any]] = {}
            self.hashes: dict[str, dict[Any, Any]] = {}
            self.expirations: dict[str, int] = {}

        # ------------------------------------------------------------- lists
        def lpush(self, key: str, *values: Any) -> None:
            if values:
                bucket = self.storage.setdefault(key, [])
                # Pushing one value at a time onto the head leaves the last
                # argument first, i.e. the arguments in reverse order.
                bucket[:0] = reversed(values)

        def rpop(self, key: str, count: Optional[int] = None) -> Optional[list[Any]]:
            bucket = self.storage.get(key)
            if not bucket:
                return None
            take = 1 if count is None else min(count, len(bucket))
            return [bucket.pop() for _ in range(take)]

        def brpop(self, keys: Iterable[str], timeout: int = 0):  # type: ignore[override]
            # Never blocks: the first non-empty list wins, otherwise None.
            for name in keys:
                pending = self.storage.get(name)
                if not pending:
                    continue
                return (name, pending.pop())
            return None

        # ------------------------------------------------------------ hashes
        def hset(self, key: str, field: str, value: Any) -> None:
            table = self.hashes.setdefault(key, {})
            table[field] = value

        def hgetall(self, key: str) -> dict[Any, Any]:
            # A shallow copy, so callers cannot mutate the stored hash.
            return dict(self.hashes.get(key, {}))

        def hdel(self, key: str, field: str) -> None:
            table = self.hashes.get(key)
            if table is not None:
                table.pop(field, None)

        # -------------------------------------------------------------- misc
        def expire(self, key: str, ttl: int) -> None:
            # The TTL is only recorded; nothing is ever evicted.
            self.expirations[key] = ttl

        def pipeline(self) -> _FakePipeline:
            return _FakePipeline(self)

        def info(self) -> dict[str, str]:
            # Metadata required by InferScheduler.check_redis_version
            return {"redis_version": "6.2.0"}

        def ping(self) -> bool:
            # Health check used by InferScheduler.start
            return True

    fake_redis_module = types.ModuleType("redis")
    fake_redis_module.Redis = _FakeRedis  # type: ignore[attr-defined]
    # setdefault: an already-registered ``redis`` module is left in place.
    sys.modules.setdefault("redis", fake_redis_module)

    # ---------------------------------------- fastdeploy.engine.request stub
    request_mod = types.ModuleType("fastdeploy.engine.request")

    @dataclass
    class CompletionOutput:
        index: int
        send_idx: int
        token_ids: List[int]
        finished: bool = False

        def to_dict(self) -> Dict[str, Any]:
            fields: Dict[str, Any] = {"index": self.index, "send_idx": self.send_idx}
            fields["token_ids"] = list(self.token_ids)
            fields["finished"] = self.finished
            return fields

        @classmethod
        def from_dict(cls, data: Dict[str, Any]) -> "CompletionOutput":
            return cls(
                data.get("index", 0),
                data.get("send_idx", 0),
                list(data.get("token_ids", [])),
                data.get("finished", False),
            )

    @dataclass
    class RequestMetrics:
        arrival_time: float
        inference_start_time: Optional[float] = None

        def to_dict(self) -> Dict[str, Any]:
            fields: Dict[str, Any] = {"arrival_time": self.arrival_time}
            fields["inference_start_time"] = self.inference_start_time
            return fields

        @classmethod
        def from_dict(cls, data: Dict[str, Any]) -> "RequestMetrics":
            fallback_arrival = time.time()
            return cls(
                data.get("arrival_time", fallback_arrival),
                data.get("inference_start_time"),
            )

    class Request:
        """Minimal request record with dict (de)serialisation."""

        def __init__(
            self,
            request_id: str,
            prompt: Optional[str] = None,
            prompt_token_ids: Optional[List[int]] = None,
            prompt_token_ids_len: int = 0,
            arrival_time: Optional[float] = None,
            disaggregate_info: Optional[Dict[str, Any]] = None,
        ) -> None:
            if arrival_time is None:
                arrival_time = time.time()
            self.request_id = request_id
            self.prompt = prompt if prompt else ""
            self.prompt_token_ids = prompt_token_ids if prompt_token_ids else []
            self.prompt_token_ids_len = prompt_token_ids_len
            self.arrival_time = arrival_time
            self.metrics = RequestMetrics(arrival_time=arrival_time)
            self.disaggregate_info = disaggregate_info

        def to_dict(self) -> Dict[str, Any]:
            fields: Dict[str, Any] = {
                "request_id": self.request_id,
                "prompt": self.prompt,
                "prompt_token_ids": list(self.prompt_token_ids),
                "prompt_token_ids_len": self.prompt_token_ids_len,
                "arrival_time": self.arrival_time,
            }
            fields["metrics"] = self.metrics.to_dict()
            fields["disaggregate_info"] = self.disaggregate_info
            return fields

        @classmethod
        def from_dict(cls, data: Dict[str, Any]) -> "Request":
            req = cls(
                request_id=data["request_id"],
                prompt=data.get("prompt"),
                prompt_token_ids=data.get("prompt_token_ids"),
                prompt_token_ids_len=data.get("prompt_token_ids_len", 0),
                arrival_time=data.get("arrival_time", time.time()),
                disaggregate_info=data.get("disaggregate_info"),
            )
            stored_metrics = data.get("metrics")
            if not stored_metrics:
                # No usable metrics payload: rebuild them from the arrival time.
                req.refresh_metrics()
            else:
                req.metrics = RequestMetrics.from_dict(stored_metrics)
            return req

        def refresh_metrics(self) -> None:
            self.metrics = RequestMetrics.from_dict({"arrival_time": self.arrival_time})

    class RequestOutput:
        """Minimal result record with dict (de)serialisation."""

        def __init__(
            self,
            request_id: str,
            prompt: str,
            prompt_token_ids: List[int],
            outputs: CompletionOutput,
            metrics: RequestMetrics,
            finished: bool = False,
            error_code: int = 200,
            error_msg: Optional[str] = None,
        ) -> None:
            self.request_id, self.prompt = request_id, prompt
            self.prompt_token_ids = prompt_token_ids
            self.outputs, self.metrics = outputs, metrics
            self.finished = finished
            self.error_code, self.error_msg = error_code, error_msg

        def to_dict(self) -> Dict[str, Any]:
            fields: Dict[str, Any] = {
                "request_id": self.request_id,
                "prompt": self.prompt,
                "prompt_token_ids": list(self.prompt_token_ids),
            }
            fields["outputs"] = self.outputs.to_dict()
            fields["metrics"] = self.metrics.to_dict()
            fields["finished"] = self.finished
            fields["error_code"] = self.error_code
            fields["error_msg"] = self.error_msg
            return fields

        @classmethod
        def from_dict(cls, data: Dict[str, Any]) -> "RequestOutput":
            return cls(
                request_id=data["request_id"],
                prompt=data.get("prompt", ""),
                prompt_token_ids=list(data.get("prompt_token_ids", [])),
                outputs=CompletionOutput.from_dict(data.get("outputs", {})),
                metrics=RequestMetrics.from_dict(data.get("metrics", {})),
                finished=data.get("finished", False),
                error_code=data.get("error_code", 200),
                error_msg=data.get("error_msg"),
            )

    # Publish each stub class on the module under its own class name.
    for stub_cls in (CompletionOutput, RequestMetrics, Request, RequestOutput):
        setattr(request_mod, stub_cls.__name__, stub_cls)
    sys.modules["fastdeploy.engine.request"] = request_mod

    # ------------------------------------------------------ package skeletons
    # Empty packages whose __path__ points below PROJECT_ROOT, so submodules
    # that are not stubbed here are looked up in those directories.
    fd_pkg = types.ModuleType("fastdeploy")
    fd_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy")]
    sys.modules["fastdeploy"] = fd_pkg

    scheduler_pkg = types.ModuleType("fastdeploy.scheduler")
    scheduler_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy" / "scheduler")]
    sys.modules["fastdeploy.scheduler"] = scheduler_pkg

    # ----------------------------------------------------------- silent logger
    logger_mod = types.ModuleType("fastdeploy.utils.scheduler_logger")

    def _discard(*_args: Any, **_kwargs: Any) -> None:
        return None

    logger_mod.info = _discard  # type: ignore[attr-defined]
    logger_mod.error = _discard  # type: ignore[attr-defined]
    logger_mod.debug = _discard  # type: ignore[attr-defined]
    logger_mod.warning = _discard  # type: ignore[attr-defined]
    sys.modules["fastdeploy.utils.scheduler_logger"] = logger_mod

    utils_mod = types.ModuleType("fastdeploy.utils")
    utils_mod.scheduler_logger = logger_mod  # type: ignore[attr-defined]
    sys.modules["fastdeploy.utils"] = utils_mod

    _install_stub_modules._installed = True
def _import_splitwise_scheduler():
    """Install the stub modules, then import and return the scheduler module."""
    _install_stub_modules()
    module_name = "fastdeploy.scheduler.splitwise_scheduler"
    return importlib.import_module(module_name)
class _PatchedThread:
    """Thread stand-in: remembers its target and whether ``start`` was called.

    The target is stored but never invoked, and every other constructor
    argument is ignored.
    """

    def __init__(self, *args: Any, target=None, **kwargs: Any) -> None:  # type: ignore[override]
        self._target, self.started = target, False

    def start(self) -> None:
        # Flip the flag only; no thread is created.
        self.started = True
class _Writer:
    """Result-writer stand-in that records each ``put`` as a ``(key, items)`` tuple."""

    def __init__(self) -> None:
        self.items: list[tuple[str, list[bytes]]] = []

    def put(self, key: str, items: list[bytes]) -> None:
        # The payload list is stored as passed in, not copied.
        record = (key, items)
        self.items.append(record)
class SplitWiseSchedulerTestCase(unittest.TestCase):
    """Shared fixture: loads the stubbed scheduler module and disables threads.

    ``setUp`` replaces ``threading.Thread`` (as reached through the scheduler
    module) with ``_PatchedThread``; ``tearDown`` puts the original back.
    """

    def setUp(self) -> None:
        module = _import_splitwise_scheduler()
        self.module = module
        self._orig_thread = module.threading.Thread
        module.threading.Thread = _PatchedThread  # type: ignore[assignment]

    def tearDown(self) -> None:
        threading_mod = self.module.threading
        threading_mod.Thread = self._orig_thread  # type: ignore[assignment]
class SplitWiseSchedulerConfigTest(SplitWiseSchedulerTestCase):
    """Exercises ``SplitWiseSchedulerConfig`` construction and its helper methods."""

    def _chunked_config(self, *, partial: int, long_partial: int, model_len: int) -> Any:
        """Build a config with chunked prefill enabled and the given limits."""
        return self.module.SplitWiseSchedulerConfig(
            enable_chunked_prefill=True,
            max_num_partial_prefills=partial,
            max_long_partial_prefills=long_partial,
            max_model_len=model_len,
        )

    def test_threshold_defaults_to_model_ratio(self) -> None:
        # No explicit long_prefill_token_threshold is passed: with a 1000-token
        # model the config is expected to report 40 and an expire period of 3.0.
        cfg = self._chunked_config(partial=5, long_partial=3, model_len=1000)
        self.assertEqual(cfg.long_prefill_token_threshold, 40)
        self.assertEqual(cfg.expire_period, 3.0)

    def test_check_and_print_cover_logging(self) -> None:
        # Smoke test: both helpers only have to complete without raising.
        cfg = self._chunked_config(partial=1, long_partial=1, model_len=50)
        cfg.check()
        cfg.print()
class NodeInfoTest(SplitWiseSchedulerTestCase):
    """Covers ``NodeInfo`` round-tripping, request bookkeeping and ordering."""

    def _prefill_node(self, nodeid: str, load: int) -> Any:
        """Create a prefill node on host ``h`` that advertises only ``ipc``."""
        return self.module.NodeInfo(nodeid, "prefill", "h", {"transfer_protocol": ["ipc"]}, load=load)

    def test_serialization_and_expiration(self) -> None:
        source = self.module.NodeInfo(
            nodeid="node-1",
            role="prefill",
            host="localhost",
            disaggregated={"transfer_protocol": ["ipc", "rdma"]},
            load=2,
        )
        restored = self.module.NodeInfo.load_from("node-1", source.serialize())

        # A just-restored node must not count as expired; after its timestamp
        # is moved 20 seconds into the past it must.
        self.assertFalse(restored.expired(10))
        restored.ts -= 20
        self.assertTrue(restored.expired(1))

        # Track a request, then age its stored timestamp so expire_reqs drops it.
        restored.add_req("req-1", 4)
        self.assertIn("req-1", restored.reqs)
        restored.update_req_timestamp(["req-1"])
        before = restored.reqs["req-1"][1]
        restored.reqs["req-1"][1] -= 1000
        restored.expire_reqs(ttl=1)
        self.assertNotIn("req-1", restored.reqs)

        # Finishing a request removes it from the bookkeeping as well.
        restored.add_req("req-2", 2)
        restored.finish_req("req-2")
        self.assertNotIn("req-2", restored.reqs)
        self.assertNotEqual(before, restored.ts)

    def test_comparisons(self) -> None:
        # Nodes are expected to order by load, and repr to show "<id>(<load>)".
        low = self._prefill_node("a", 1)
        high = self._prefill_node("b", 5)
        self.assertTrue(low < high)
        self.assertIn("a(1)", repr(low))
class ResultReaderTest(SplitWiseSchedulerTestCase):
    """Drives ``ResultReader`` against the in-memory Redis stub."""

    def _client_and_reader(self, *, batch: int = 10, ttl: int = 30, group: str = "") -> Any:
        """Return a fresh stub client and a reader (``idx=0``) built on it."""
        client = sys.modules["redis"].Redis()
        reader = self.module.ResultReader(client, idx=0, batch=batch, ttl=ttl, group=group)
        return client, reader

    def _output(
        self,
        request_id: str,
        send_idx: int,
        token_ids: List[int],
        metrics: Any,
        finished: bool,
        *,
        prompt: str = "",
        prompt_token_ids: Iterable[int] = (),
    ) -> Any:
        """Build a ``RequestOutput`` holding one ``CompletionOutput`` at index 0."""
        return self.module.RequestOutput(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=list(prompt_token_ids),
            outputs=self.module.CompletionOutput(index=0, send_idx=send_idx, token_ids=token_ids),
            metrics=metrics,
            finished=finished,
        )

    def test_read_handles_group_tokens_with_buffer_and_outputs(self) -> None:
        _, reader = self._client_and_reader(group="grp")
        req = self.module.Request("req-buffer", prompt_token_ids_len=2)
        reader.add_req(req)
        metrics = self.module.RequestMetrics(arrival_time=time.time())
        rid = req.request_id
        head = self._output(rid, 0, [1], metrics, False)
        buffered = self._output(rid, 3, [5], metrics, False)
        trailing = self._output(rid, 4, [6], metrics, True)

        # One chunk is parked in the out-of-order buffer, two more are queued.
        with reader.lock:
            reader.out_buffer[rid] = [buffered]
        reader.data.appendleft(head)
        reader.data.appendleft(trailing)
        outputs = reader.read()
        self.assertIn(rid, outputs)

        # "req-new" was never registered through add_req. Per the original
        # author's note this drives the group_tokens path that has no
        # pre-existing output bucket for the request.
        another = self._output("req-new", 1, [9], metrics, True)
        reader.data.appendleft(another)
        outputs = reader.read()
        self.assertEqual(outputs["req-new"][0].outputs.token_ids, [9])

    def test_read_groups_partial_outputs(self) -> None:
        _, reader = self._client_and_reader(group="group-a")
        req = self.module.Request("req-A", prompt_token_ids_len=3)
        reader.add_req(req)
        metrics = self.module.RequestMetrics(arrival_time=time.time())
        first = self._output("req-A", 0, [1, 2], metrics, False)
        follow = self._output("req-A", 1, [3], metrics, True)
        reader.data.appendleft(follow)
        reader.data.appendleft(first)

        # Both chunks must come back grouped under the one request id.
        outputs = reader.read()
        self.assertIn("req-A", outputs)
        self.assertEqual(len(outputs["req-A"]), 2)

    def test_sync_results_converts_payloads(self) -> None:
        client, reader = self._client_and_reader()
        metrics = self.module.RequestMetrics(arrival_time=time.time())
        ro = self._output("req-B", 0, [4], metrics, True, prompt="p", prompt_token_ids=[1])
        payload = self.module.orjson.dumps(ro.to_dict())
        client.storage.setdefault("req-key", []).append(payload)

        # One serialised payload under the key -> one entry pulled into data.
        total = reader.sync_results(["req-key"])
        self.assertEqual(total, 1)
        self.assertTrue(reader.data)

    def test_read_uses_out_buffer(self) -> None:
        _, reader = self._client_and_reader(group="grp")
        req = self.module.Request("req-out", prompt_token_ids_len=2)
        reader.add_req(req)
        metrics = self.module.RequestMetrics(arrival_time=time.time())
        head = self._output("req-out", 0, [1], metrics, False)
        tail = self._output("req-out", 2, [2, 3], metrics, True)
        with reader.lock:
            reader.out_buffer[req.request_id] = [tail]
        reader.data.appendleft(head)

        # The buffered tail must be returned together with the queued head.
        outputs = reader.read()
        self.assertEqual(len(outputs["req-out"]), 2)

    def test_sync_results_with_group_override(self) -> None:
        client, reader = self._client_and_reader(group="grp")
        metrics = self.module.RequestMetrics(arrival_time=time.time())
        ro = self._output("req-group", 0, [7], metrics, True)
        payload = self.module.orjson.dumps(ro.to_dict())
        client.storage.setdefault("grp", []).append(payload)

        # The payload is stored under the group name, not under the key that
        # is handed to sync_results; it must still be picked up.
        total = reader.sync_results(["unused"])
        self.assertEqual(total, 1)
        self.assertEqual(reader.data[-1].request_id, "req-group")

    def test_run_emits_expired_placeholder(self) -> None:
        _, reader = self._client_and_reader(ttl=1)
        reader.reqs["old"] = {"arrival_time": time.time() - 5}
        reader.reqs["active"] = {"arrival_time": time.time()}
        rpop_calls = 0

        def _rpop(key: str, batch: int):
            # First call yields nothing; the second breaks out of run().
            nonlocal rpop_calls
            rpop_calls += 1
            if rpop_calls > 1:
                raise SystemExit()
            return []

        reader.client.rpop = _rpop  # type: ignore[assignment]
        with self.assertRaises(SystemExit):
            reader.run()
        self.assertNotIn("old", reader.reqs)
        self.assertTrue(reader.data)
        self.assertGreaterEqual(rpop_calls, 1)

    def test_run_handles_empty_keys_and_exceptions(self) -> None:
        _, reader = self._client_and_reader(batch=5, ttl=10)
        reader.reqs.clear()

        def _abort_sleep(_t):
            raise SystemExit()

        # With no tracked requests, run() is expected to reach time.sleep;
        # the patched sleep turns that into a SystemExit that ends the loop.
        original_sleep = self.module.time.sleep
        try:
            self.module.time.sleep = _abort_sleep
            with self.assertRaises(SystemExit):
                reader.run()
        finally:
            self.module.time.sleep = original_sleep

        # Now cover the exception logging path inside run()
        reader.reqs["rid"] = {"arrival_time": time.time()}
        attempts = 0

        def _rpop(_key: str, _batch: int):
            # Fail with an ordinary error once, then stop the loop.
            nonlocal attempts
            attempts += 1
            if attempts == 1:
                raise ValueError("boom")
            raise SystemExit()

        reader.client.rpop = _rpop  # type: ignore[assignment]
        with self.assertRaises(SystemExit):
            reader.run()
        self.assertGreaterEqual(attempts, 2)
class APISchedulerTest(SplitWiseSchedulerTestCase):
    """Exercises ``APIScheduler`` against the in-memory Redis stub.

    Covers request placement on mixed and disaggregated nodes, cluster
    synchronisation and clean-up, node selection, and the public
    start/put/get entry points.
    """

    def _make_config(self) -> Any:
        """Return a chunked-prefill config with ``max_model_len=200``."""
        return self.module.SplitWiseSchedulerConfig(
            enable_chunked_prefill=True,
            max_num_partial_prefills=5,
            max_long_partial_prefills=3,
            max_model_len=200,
        )

    def test_schedule_mixed_node_uses_single_queue(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        req = self.module.Request("req-1", prompt_token_ids_len=10)
        mixed = self.module.NodeInfo("mixed", "mixed", "host-a", {"transfer_protocol": ["ipc"]}, load=1)
        # Force node selection so the test does not depend on select_pd.
        scheduler.select_pd = lambda *args, **kwargs: mixed  # type: ignore[assignment]
        scheduler.schedule(req, [mixed], [], [], group="g0")

        # The request must land, pickled, on the mixed node's own queue with
        # the group recorded and no disaggregation metadata attached.
        key = f"ReqQ_{mixed.nodeid}"
        self.assertIn(key, scheduler.client.storage)
        stored = scheduler.client.storage[key][0]
        decoded = pickle.loads(stored)  # payload was produced by the scheduler under test
        self.assertEqual(decoded["group"], "g0")
        self.assertIsNone(decoded["disaggregate_info"])

    def test_schedule_disaggregated_nodes_fill_metadata(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        req = self.module.Request("req-meta", prompt_token_ids_len=10)
        disagg = {
            "host_ip": "1.1.1.1",
            "transfer_protocol": ["ipc"],
            "connector_port": 10,
            "device_ids": [0],
            "rdma_ports": [100],
            "tp_size": 2,
        }
        pre = self.module.NodeInfo("p", "prefill", "host-a", disagg, load=1)
        dec = self.module.NodeInfo(
            "d",
            "decode",
            "host-b",
            {
                "host_ip": "1.1.1.1",
                "transfer_protocol": ["ipc"],
                "connector_port": 11,
                "device_ids": [1],
                "rdma_ports": [101],
                "tp_size": 2,
            },
            load=2,
        )
        scheduler.schedule(req, [pre], [dec], [], group="g1")

        # Both nodes report the same host_ip and only "ipc": the request must
        # be annotated with "ipc" and queued for both the prefill and decode node.
        self.assertIsNotNone(req.disaggregate_info)
        self.assertEqual(req.disaggregate_info["transfer_protocol"], "ipc")
        self.assertIn(f"ReqQ_{pre.nodeid}", scheduler.client.storage)
        self.assertIn(f"ReqQ_{dec.nodeid}", scheduler.client.storage)

    def test_loop_schedule_consumes_queue_and_uses_reader(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        scheduler.reqs_queue.append(self.module.Request("req-loop", prompt_token_ids_len=1))
        scheduler.readers = [types.SimpleNamespace(add_req=lambda _req: None, group="grp")]
        # One idle prefill node, one idle decode node, no mixed nodes.
        scheduler.sync_cluster = lambda: (  # type: ignore[assignment]
            [types.SimpleNamespace(load=0, role="prefill", disaggregated={})],
            [types.SimpleNamespace(load=0, role="decode", disaggregated={})],
            [],
        )
        # schedule() raising SystemExit is what ends the otherwise endless loop.
        scheduler.schedule = lambda *_args, **_kwargs: (_ for _ in ()).throw(SystemExit())  # type: ignore[assignment]
        with self.assertRaises(SystemExit):
            scheduler.loop_schedule()

    def test_sync_cluster_partitions_and_filters_nodes(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        now = time.time()
        cluster_key = scheduler.cluster_key

        # Four registrations: one 10 seconds old, one valid prefill node, one
        # valid decode node and one carrying an unknown role.
        expired_payload = self.module.orjson.dumps(
            {"ts": now - 10, "role": "prefill", "load": 1, "host": "h", "disaggregated": {}}
        )
        scheduler.client.hset(cluster_key, b"expired", expired_payload)
        valid_prefill = self.module.NodeInfo("p1", "prefill", "h1", {"transfer_protocol": ["ipc"]}, load=1)
        scheduler.client.hset(cluster_key, b"p1", valid_prefill.serialize())
        valid_decode = self.module.NodeInfo("d1", "decode", "h2", {"transfer_protocol": ["ipc"]}, load=2)
        scheduler.client.hset(cluster_key, b"d1", valid_decode.serialize())
        invalid_payload = self.module.orjson.dumps(
            {"ts": now, "role": "unknown", "load": 0, "host": "h3", "disaggregated": {}}
        )
        scheduler.client.hset(cluster_key, b"bad", invalid_payload)

        # Only the two valid nodes may survive, one per role.
        pnodes, dnodes, mnodes = scheduler.sync_cluster()
        self.assertEqual(len(pnodes), 1)
        self.assertEqual(len(dnodes), 1)
        self.assertEqual(len(mnodes), 0)

    def test_loop_clear_expired_nodes_removes_entries(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        cluster_key = scheduler.cluster_key
        stale_payload = self.module.orjson.dumps(
            {
                "ts": time.time() - (scheduler.clear_expired_nodes_period + 5),
                "role": "prefill",
                "load": 0,
                "host": "h",
                "disaggregated": {"transfer_protocol": ["ipc"]},
            }
        )
        scheduler.client.hset(cluster_key, b"old", stale_payload)

        # sleep() raising SystemExit breaks out of the clean-up loop. The patch
        # is applied to the ``time`` object the scheduler module exposes
        # (presumably the shared stdlib module), so it is undone in ``finally``:
        # an unexpected exception must not leave a SystemExit-raising sleep
        # behind for the tests that run afterwards.
        original_sleep = self.module.time.sleep
        self.module.time.sleep = lambda _t: (_ for _ in ()).throw(SystemExit())
        try:
            with self.assertRaises(SystemExit):
                scheduler.loop_clear_expired_nodes()
        finally:
            self.module.time.sleep = original_sleep
        self.assertNotIn(b"old", scheduler.client.hgetall(cluster_key))

    def test_select_pd_paths(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        req = self.module.Request("req-sel", prompt_token_ids_len=50)
        nodes = [
            self.module.NodeInfo(str(i), "prefill", "h", {"transfer_protocol": ["ipc"]}, load=i) for i in range(3)
        ]
        # Fixed seed: keeps any random choice made by select_pd reproducible.
        random.seed(0)
        chosen = scheduler.select_pd(req, nodes, "prefill")
        self.assertIn(chosen, nodes)
        decode_nodes = [
            self.module.NodeInfo(str(i), "decode", "h", {"transfer_protocol": ["ipc"]}, load=i) for i in range(2)
        ]
        chosen_decode = scheduler.select_pd(req, decode_nodes, "decode")
        self.assertIn(chosen_decode, decode_nodes)

    def test_schedule_disaggregated_updates_protocol(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        req = self.module.Request("req-2", prompt_token_ids_len=10)
        prefill = self.module.NodeInfo(
            "prefill",
            "prefill",
            "host-a",
            {
                "transfer_protocol": ["ipc"],
                "host_ip": "1.1.1.1",
                "connector_port": 10,
                "device_ids": [0],
                "rdma_ports": [1],
                "tp_size": 1,
            },
            load=1,
        )
        decode = self.module.NodeInfo(
            "decode",
            "decode",
            "host-b",
            {
                "transfer_protocol": ["ipc", "rdma"],
                "host_ip": "2.2.2.2",
                "connector_port": 11,
                "device_ids": [1],
                "rdma_ports": [2],
                "tp_size": 1,
            },
            load=1,
        )

        def _select(req_obj, nodes, role):
            # Deterministic selection: always the first candidate.
            return nodes[0]

        scheduler.select_pd = _select  # type: ignore[assignment]
        scheduler.schedule(req, [prefill], [decode], [], group="")
        self.assertIn("ReqQ_prefill", scheduler.client.storage)
        self.assertIn("ReqQ_decode", scheduler.client.storage)
        # The nodes report different host_ip values; the queued request is
        # expected to carry "rdma" as its transfer protocol.
        decoded = pickle.loads(scheduler.client.storage["ReqQ_prefill"][0])
        self.assertEqual(decoded["disaggregate_info"]["transfer_protocol"], "rdma")

    def test_sync_cluster_filters_expired_nodes(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        fresh = self.module.NodeInfo("n1", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=1)
        scheduler.client.hset(scheduler.cluster_key, fresh.nodeid.encode(), fresh.serialize())
        # "n2" is registered with a timestamp just past the expire period.
        stale_payload = self.module.orjson.dumps(
            {
                "ts": time.time() - (config.expire_period + 1),
                "role": "prefill",
                "load": 1,
                "host": "h",
                "disaggregated": {"transfer_protocol": ["ipc"]},
            }
        )
        scheduler.client.hset(scheduler.cluster_key, b"n2", stale_payload)
        pnodes, _, _ = scheduler.sync_cluster()
        self.assertEqual([node.nodeid for node in pnodes], ["n1"])

    def test_start_put_and_get_results(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        scheduler.start()
        reqs = [self.module.Request(f"req-{i}", prompt_token_ids_len=1) for i in range(2)]
        result = scheduler.put_requests(reqs)
        self.assertEqual(len(result), 2)

        # With a single stub reader, get_results must return what it reads.
        fake_output = {"a": ["value"]}
        scheduler.readers = [types.SimpleNamespace(read=lambda: fake_output)]
        outputs = scheduler.get_results()
        self.assertEqual(outputs, fake_output)

    def test_select_pd_prefill_and_decode(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        req = self.module.Request("req-select", prompt_token_ids_len=50)
        prefill_nodes = [
            self.module.NodeInfo("a", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=5),
            self.module.NodeInfo("b", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=20),
        ]
        decode_nodes = [
            self.module.NodeInfo("c", "decode", "h", {"transfer_protocol": ["ipc"]}, load=1),
            self.module.NodeInfo("d", "decode", "h", {"transfer_protocol": ["ipc"]}, load=2),
        ]

        # Make random.choice deterministic (always the last candidate) and
        # restore it afterwards even if select_pd raises.
        original_choice = self.module.random.choice
        self.module.random.choice = lambda seq: seq[-1]  # type: ignore[assignment]
        try:
            picked_prefill = scheduler.select_pd(req, prefill_nodes, "prefill")
            picked_decode = scheduler.select_pd(req, decode_nodes, "decode")
        finally:
            self.module.random.choice = original_choice
        self.assertEqual(picked_prefill.nodeid, "b")
        self.assertEqual(picked_decode.nodeid, "d")

        # An unrecognised role must be rejected.
        with self.assertRaises(Exception):
            scheduler.select_pd(req, prefill_nodes, "unknown")
class InferSchedulerTest(SplitWiseSchedulerTestCase):
def _make_config(self, **overrides: Any) -> Any:
base = dict(
enable_chunked_prefill=True,
max_num_partial_prefills=3,
max_long_partial_prefills=1,
max_model_len=200,
)
base.update(overrides)
return self.module.SplitWiseSchedulerConfig(**base)
def test_get_requests_limits_partial_prefills(self) -> None:
config = self._make_config(long_prefill_token_threshold=5)
infer = self.module.InferScheduler(config)
infer.role = "prefill"
infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0)
long = self.module.Request("req-long", prompt_token_ids_len=10)
longer = self.module.Request("req-longer", prompt_token_ids_len=12)
infer.reqs_queue.extend([longer, long])
picked = infer.get_requests(
available_blocks=100,
block_size=4,
reserved_output_blocks=1,
max_num_batched_tokens=100,
batch=5,
)
self.assertEqual([req.request_id for req in picked], ["req-longer"])
self.assertEqual([req.request_id for req in infer.reqs_queue], ["req-long"])
def test_get_requests_non_chunked_uses_token_cap(self) -> None:
config = self._make_config(enable_chunked_prefill=False)
infer = self.module.InferScheduler(config)
infer.role = "prefill"
infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0)
infer.reqs_queue.extend(
[
self.module.Request("req-1", prompt_token_ids_len=10),
self.module.Request("req-2", prompt_token_ids_len=20),
]
)
picked = infer.get_requests(
available_blocks=100,
block_size=4,
reserved_output_blocks=1,
max_num_batched_tokens=15,
batch=5,
)
self.assertEqual([req.request_id for req in picked], ["req-1"])
self.assertEqual(len(infer.reqs_queue), 1)
def test_put_results_groups_by_writer_index(self) -> None:
config = self._make_config()
infer = self.module.InferScheduler(config)
infer.role = "prefill"
infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0)
infer.writers = [_Writer(), _Writer()]
infer.node.add_req("req#0#g", 1)
metrics = self.module.RequestMetrics(arrival_time=time.time())
result = self.module.RequestOutput(
request_id="req#0#g",
prompt="",
prompt_token_ids=[],
outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]),
metrics=metrics,
finished=True,
)
infer.put_results([result])
self.assertEqual(len(infer.writers[0].items), 1)
key, payloads = infer.writers[0].items[0]
self.assertEqual(key, "g")
decoded = self.module.orjson.loads(payloads[0])
self.assertFalse(decoded["finished"])
def test_put_results_handles_errors(self) -> None:
config = self._make_config()
infer = self.module.InferScheduler(config)
infer.role = "decode"
infer.node = self.module.NodeInfo("n", "decode", "h", {"transfer_protocol": ["ipc"]}, load=0)
infer.writers = [_Writer()]
infer.node.add_req("bad#0#", 1)
metrics = self.module.RequestMetrics(arrival_time=time.time())
result = self.module.RequestOutput(
request_id="bad#0#",
prompt="",
prompt_token_ids=[],
outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[1]),
metrics=metrics,
finished=True,
error_code=500,
)
infer.put_results([result])
self.assertFalse(infer.node.reqs)
def test_start_initializes_writers(self) -> None:
config = self._make_config()
infer = self.module.InferScheduler(config)
infer.start("prefill", "host", {"transfer_protocol": ["ipc"]})
self.assertEqual(len(infer.writers), config.writer_parallel)
def test_get_requests_skips_expired_entries(self) -> None:
    """Check that a queued request older than the TTL is never scheduled.

    The only queued request arrived ``ttl + 1`` seconds ago.  The test
    expects ``get_requests`` to return nothing and the node to no longer
    track the request id afterwards.
    """
    scheduler = self.module.InferScheduler(self._make_config())
    scheduler.role = "prefill"
    scheduler.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0)

    # Back-date the arrival so the request is one second past the TTL.
    stale_arrival = time.time() - (scheduler.ttl + 1)
    stale_request = self.module.Request("expired", prompt_token_ids_len=1, arrival_time=stale_arrival)
    scheduler.node.add_req("expired", 1)
    scheduler.reqs_queue.append(stale_request)

    scheduled = scheduler.get_requests(
        available_blocks=10,
        block_size=1,
        reserved_output_blocks=1,
        max_num_batched_tokens=10,
        batch=1,
    )

    self.assertEqual(scheduled, [])
    self.assertNotIn("expired", scheduler.node.reqs)
def test_check_redis_version_requires_supported_version(self) -> None:
    """Check that a reported Redis version of 5.0.0 fails the version check."""
    scheduler = self.module.InferScheduler(self._make_config())

    def _old_server_info():
        # Stand-in for the client's ``info`` call, reporting version 5.0.0.
        return {"redis_version": "5.0.0"}

    scheduler.client.info = _old_server_info  # type: ignore[assignment]

    with self.assertRaises(AssertionError):
        scheduler.check_redis_version()
class SplitWiseSchedulerFacadeTest(SplitWiseSchedulerTestCase):
    """Tests for the ``SplitWiseScheduler`` facade and the calls it forwards."""

    def _facade_config(self) -> Any:
        """Return the chunked-prefill configuration used by every test here."""
        return self.module.SplitWiseSchedulerConfig(
            enable_chunked_prefill=True,
            max_num_partial_prefills=1,
            max_long_partial_prefills=1,
            max_model_len=10,
        )

    def test_facade_delegates_to_components(self) -> None:
        """Check that each facade method forwards to its API/infer component.

        ``APIScheduler`` and ``InferScheduler`` are swapped for in-test
        fakes while the facade is built and exercised; the originals are
        put back even when an assertion fails.
        """
        module = self.module

        class _FakeAPI:
            """Stand-in API scheduler that records the requests it receives."""

            def __init__(self, _config: Any) -> None:
                self.started = False
                self.reqs: List[Any] = []

            def start(self) -> None:
                self.started = True

            def put_requests(self, reqs: List[Any]):
                self.reqs.extend(reqs)
                return [(req.request_id, None) for req in reqs]

            def get_results(self):
                return {"x": 1}

        class _FakeInfer:
            """Stand-in infer scheduler that returns fixed answers."""

            def __init__(self, _config: Any) -> None:
                self.started = False
                self.nodeid = None

            def start(self, role, host, disaggregated):
                self.started = True

            def get_requests(self, *args, **kwargs):
                return ["scheduled"]

            def put_results(self, results):
                return list(results)

        # Remember the real classes so they can be restored afterwards.
        saved = {name: getattr(module, name) for name in ("APIScheduler", "InferScheduler")}
        module.APIScheduler = _FakeAPI  # type: ignore[assignment]
        module.InferScheduler = _FakeInfer  # type: ignore[assignment]
        try:
            facade = module.SplitWiseScheduler(self._facade_config())

            facade.start("prefill", "host", {"tp": "ipc"})
            self.assertTrue(facade.scheduler.started)
            self.assertTrue(facade.infer.started)

            submitted = facade.put_requests([module.Request("req", prompt_token_ids_len=1)])
            self.assertEqual(submitted[0][0], "req")
            self.assertEqual(facade.get_results(), {"x": 1})

            self.assertEqual(facade.get_requests(10, 1, 1, 10, batch=1), ["scheduled"])
            self.assertEqual(facade.put_results([1, 2]), [1, 2])
        finally:
            for name, original in saved.items():
                setattr(module, name, original)

    def test_get_requests_with_insufficient_resources(self) -> None:
        """Check that the facade returns ``[]`` without asking the infer side.

        Two situations are covered: fewer available blocks than reserved
        output blocks, and a batch size of zero.  The stubbed infer
        component would return a sentinel list if it were consulted.
        """
        module = self.module
        facade = module.SplitWiseScheduler(self._facade_config())

        def _unexpected_get_requests(*args, **kwargs):
            return ["should not reach"]

        facade.infer = types.SimpleNamespace(get_requests=_unexpected_get_requests)
        facade.scheduler = types.SimpleNamespace()

        too_few_blocks = facade.get_requests(
            available_blocks=1, block_size=1, reserved_output_blocks=2, max_num_batched_tokens=10
        )
        self.assertEqual(too_few_blocks, [])

        empty_batch = facade.get_requests(
            available_blocks=10, block_size=1, reserved_output_blocks=2, max_num_batched_tokens=10, batch=0
        )
        self.assertEqual(empty_batch, [])

    def test_start_uses_real_components(self) -> None:
        """Check ``start`` argument forwarding and ``reset_nodeid``.

        The facade is built normally, then its ``infer`` and ``scheduler``
        attributes are replaced by recording stubs so the forwarded
        arguments can be inspected.
        """
        module = self.module
        facade = module.SplitWiseScheduler(self._facade_config())
        infer_calls = {}
        scheduler_calls = {}

        def _record_infer_start(role, host, disagg):
            return infer_calls.setdefault("called", (role, host, disagg))

        def _record_scheduler_start():
            return scheduler_calls.setdefault("called", True)

        facade.infer = types.SimpleNamespace(start=_record_infer_start)
        facade.scheduler = types.SimpleNamespace(start=_record_scheduler_start)

        facade.start("prefill", "host", {"mode": "ipc"})
        self.assertEqual(infer_calls["called"], ("prefill", "host", {"mode": "ipc"}))
        self.assertTrue(scheduler_calls["called"])

        facade.reset_nodeid("new-id")
        self.assertEqual(facade.scheduler.nodeid, "new-id")
class BackgroundWorkerTest(SplitWiseSchedulerTestCase):
    """Tests for ``ResultWriter`` and the ``InferScheduler`` worker loops.

    The loop methods under test are left by raising ``SystemExit`` from a
    stubbed collaborator, which each test then expects to propagate.
    """

    def _worker_config(self, max_model_len: int = 10) -> Any:
        """Return a chunked-prefill configuration with the given model length."""
        return self.module.SplitWiseSchedulerConfig(
            enable_chunked_prefill=True,
            max_num_partial_prefills=1,
            max_long_partial_prefills=1,
            max_model_len=max_model_len,
        )

    @staticmethod
    def _install_exiting_pipeline(client: Any) -> None:
        """Replace ``client.pipeline`` with a fake that exits on ``expire``.

        ``lpush`` is forwarded to ``client.lpush``; ``expire`` raises
        ``SystemExit`` so a caller looping over the pipeline is interrupted.
        """

        class _Pipeline:
            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                # Returning None lets any exception propagate.
                return None

            def multi(self):
                return self

            def lpush(self, key, *items):
                client.lpush(key, *items)
                return self

            def expire(self, key, ttl):
                raise SystemExit()

            def execute(self):
                return None

        def _pipeline():
            return _Pipeline()

        client.pipeline = _pipeline  # type: ignore[assignment]

    def test_result_writer_start_flags_thread(self) -> None:
        """Check that ``ResultWriter.start`` marks its thread as started."""
        redis_client = sys.modules["redis"].Redis()
        writer = self.module.ResultWriter(redis_client, idx=0, batch=2, ttl=5)

        writer.start()

        self.assertTrue(writer.thread.started)

    def test_result_writer_run_single_iteration(self) -> None:
        """Check that ``run`` reaches ``expire`` with one queued payload."""
        redis_client = sys.modules["redis"].Redis()
        writer = self.module.ResultWriter(redis_client, idx=0, batch=5, ttl=10)
        with writer.cond:
            writer.data.appendleft(("key", b"payload"))

        self._install_exiting_pipeline(redis_client)

        with self.assertRaises(SystemExit):
            writer.run()

    def test_result_writer_run_groups_batches(self) -> None:
        """Check that ``run`` reaches ``expire`` with payloads under two keys."""
        redis_client = sys.modules["redis"].Redis()
        writer = self.module.ResultWriter(redis_client, idx=0, batch=10, ttl=5)
        with writer.cond:
            for entry in (("k1", b"a"), ("k1", b"b"), ("k2", b"c")):
                writer.data.appendleft(entry)

        self._install_exiting_pipeline(redis_client)

        with self.assertRaises(SystemExit):
            writer.run()

    def test_infer_scheduler_routine_report(self) -> None:
        """Check that a failing ``hset`` in ``routine_report`` gets logged.

        ``hset`` raises ``ValueError`` and the module logger's ``error`` is
        replaced by a function raising ``SystemExit``, so seeing
        ``SystemExit`` shows the error logger was invoked.
        """
        scheduler = self.module.InferScheduler(self._worker_config())
        scheduler.node = self.module.NodeInfo("nid", "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0)

        def _fake_hset(*_args, **_kwargs):
            raise ValueError("fail")

        def _exit_on_log(*_args, **_kwargs):
            raise SystemExit()

        scheduler.client.hset = _fake_hset  # type: ignore[assignment]
        logger = self.module.logger
        saved_error = logger.error
        logger.error = _exit_on_log
        try:
            with self.assertRaises(SystemExit):
                scheduler.routine_report()
        finally:
            # Put the real logger method back for the remaining tests.
            self.module.logger.error = saved_error

    def test_infer_scheduler_loop_expire_reqs(self) -> None:
        """Check that ``loop_expire_reqs`` calls ``node.expire_reqs``."""
        scheduler = self.module.InferScheduler(self._worker_config())
        scheduler.node = self.module.NodeInfo("nid", "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0)

        def _raise_exit(ttl):
            raise SystemExit()

        scheduler.node.expire_reqs = _raise_exit  # type: ignore[assignment]

        with self.assertRaises(SystemExit):
            scheduler.loop_expire_reqs()

    def test_infer_scheduler_loop_get_reqs(self) -> None:
        """Check that ``loop_get_reqs`` polls ``rpop`` more than once.

        The first ``rpop`` call hands back one pickled request; the second
        raises ``SystemExit`` to leave the loop.
        """
        scheduler = self.module.InferScheduler(self._worker_config())
        scheduler.role = "prefill"
        scheduler.node = self.module.NodeInfo(
            scheduler.nodeid, "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0
        )
        scheduler.writers = [types.SimpleNamespace(put=lambda key, items: None)]

        request = self.module.Request("rq", prompt_token_ids_len=3)
        record = dict(request.to_dict(), group="")
        queue_key = f"ReqQ_{scheduler.nodeid}"
        scheduler.client.storage[queue_key] = [pickle.dumps(record, protocol=5)]

        already_served = False

        def _fake_rpop(k, batch):
            nonlocal already_served
            if already_served:
                raise SystemExit()
            already_served = True
            # Hand back a copy so the stored list itself is left untouched.
            return scheduler.client.storage[k][:]

        def _fake_brpop(*_args, **_kwargs):
            return None

        scheduler.client.rpop = _fake_rpop  # type: ignore[assignment]
        scheduler.client.brpop = _fake_brpop  # type: ignore[assignment]

        with self.assertRaises(SystemExit):
            scheduler.loop_get_reqs()

    def test_infer_scheduler_get_requests_limits(self) -> None:
        """Check two ``get_requests`` limits on a prefill scheduler.

        First, a 10-token request is not picked when only one block of size
        4 is available.  Second, with a request longer than
        ``long_prefill_token_threshold`` queued ahead of a 2-token request
        and ``batch=2``, exactly one request is expected back.
        """
        cfg = self._worker_config(max_model_len=50)
        scheduler = self.module.InferScheduler(cfg)
        scheduler.role = "prefill"
        scheduler.node = self.module.NodeInfo("nid", "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0)

        scheduler.reqs_queue.append(self.module.Request("heavy", prompt_token_ids_len=10))
        starved = scheduler.get_requests(
            available_blocks=1,
            block_size=4,
            reserved_output_blocks=1,
            max_num_batched_tokens=100,
            batch=1,
        )
        self.assertEqual(starved, [])

        scheduler.reqs_queue.clear()
        long_request = self.module.Request("long", prompt_token_ids_len=cfg.long_prefill_token_threshold + 10)
        scheduler.reqs_queue.append(long_request)
        scheduler.reqs_queue.append(self.module.Request("short", prompt_token_ids_len=2))
        chosen = scheduler.get_requests(
            available_blocks=100,
            block_size=4,
            reserved_output_blocks=1,
            max_num_batched_tokens=100,
            batch=2,
        )
        self.assertEqual(len(chosen), 1)
if __name__ == "__main__":
    # Strip the one custom flag and hand every other argument to unittest.
    cli = argparse.ArgumentParser(add_help=False)
    cli.add_argument("--print-coverage-command", action="store_true")
    options, passthrough = cli.parse_known_args()

    if options.print_coverage_command:
        # Show how to collect and report coverage; the tests still run below.
        for command in (
            "python -m coverage run -m unittest tests.scheduler.test_splitwise_scheduler",
            "python -m coverage report -m --include='fastdeploy/scheduler/splitwise_scheduler.py'",
        ):
            print(command)

    unittest.main(argv=[sys.argv[0], *passthrough])