Mirror of https://github.com/PaddlePaddle/FastDeploy.git · synced 2025-12-24 13:28:13 +08:00
* Add unit tests for SplitWiseScheduler module
* Add info and ping to fake redis client for tests
* Document fake redis metadata methods in tests
* Enhance splitwise scheduler tests
* Clean up test_splitwise_scheduler.py: removed copyright notice and documentation comments
* Simplify splitwise scheduler test stubs
* Refine splitwise scheduler tests
* Handle empty result keys with restored sleep

---------

Co-authored-by: CSWYF3634076 <wangyafeng@baidu.com>
1307 lines · 48 KiB · Python
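The test module's `__main__` block prints the coverage commands used to exercise `fastdeploy/scheduler/splitwise_scheduler.py`. As a minimal sketch, those same commands can be driven from Python; the `subprocess` wrapper below is illustrative only and assumes the repository root is the working directory and `coverage` is installed.

```python
# Hypothetical runner mirroring the commands printed by the test module's __main__ block.
import subprocess

subprocess.run(
    ["python", "-m", "coverage", "run", "-m", "unittest", "tests.scheduler.test_splitwise_scheduler"],
    check=True,
)
subprocess.run(
    ["python", "-m", "coverage", "report", "-m", "--include=fastdeploy/scheduler/splitwise_scheduler.py"],
    check=True,
)
```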
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

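# Unit tests for fastdeploy.scheduler.splitwise_scheduler. The external
# dependencies (redis, fastdeploy.engine.request, fastdeploy.utils) are replaced
# with lightweight in-process stubs installed by _install_stub_modules(), so the
# scheduler module can be imported and exercised without a running Redis server.
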
from __future__ import annotations

import argparse
import importlib
import pickle
import random
import sys
import time
import types
import unittest
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))


def _install_stub_modules() -> None:
    """Install lightweight stand-ins for the external dependencies."""

    if getattr(_install_stub_modules, "_installed", False):
        return

    # --------------------------------------------------------------- Redis stubs
    class _FakePipeline:
        def __init__(self, client: "_FakeRedis") -> None:
            self._client = client
            self._commands: list[tuple[str, tuple[Any, ...]]] = []

        def __enter__(self) -> "_FakePipeline":
            return self

        def __exit__(self, exc_type, exc, tb) -> None:  # type: ignore[override]
            return None

        def multi(self) -> "_FakePipeline":
            return self

        def lpush(self, key: str, *values: Any) -> "_FakePipeline":
            self._commands.append(("lpush", (key, values)))
            return self

        def expire(self, key: str, ttl: int) -> "_FakePipeline":
            self._commands.append(("expire", (key, ttl)))
            return self

        def execute(self) -> None:
            for name, params in self._commands:
                if name == "lpush":
                    key, values = params
                    self._client.lpush(key, *values)
                elif name == "expire":
                    key, ttl = params
                    self._client.expire(key, ttl)
            self._commands.clear()

    class _FakeRedis:
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            self.storage: dict[str, list[Any]] = {}
            self.hashes: dict[str, dict[Any, Any]] = {}
            self.expirations: dict[str, int] = {}

        # ------------------------------- list operations used by the scheduler
        def lpush(self, key: str, *values: Any) -> None:
            items = list(values)
            if not items:
                return
            bucket = self.storage.setdefault(key, [])
            for value in items:
                bucket.insert(0, value)

        def rpop(self, key: str, count: Optional[int] = None) -> Optional[list[Any]]:
            bucket = self.storage.get(key)
            if not bucket:
                return None
            if count is None:
                return [bucket.pop()]
            count = min(count, len(bucket))
            values = [bucket.pop() for _ in range(count)]
            return values

        def brpop(self, keys: Iterable[str], timeout: int = 0):  # type: ignore[override]
            for key in keys:
                bucket = self.storage.get(key)
                if bucket:
                    return (key, bucket.pop())
            return None

        # ------------------------------------------ hash operations for cluster
        def hset(self, key: str, field: str, value: Any) -> None:
            self.hashes.setdefault(key, {})[field] = value

        def hgetall(self, key: str) -> dict[Any, Any]:
            return {k: v for k, v in self.hashes.get(key, {}).items()}

        def hdel(self, key: str, field: str) -> None:
            if key in self.hashes:
                self.hashes[key].pop(field, None)

        # -------------------------------------------------------------- misc ops
        def expire(self, key: str, ttl: int) -> None:
            self.expirations[key] = ttl

        def pipeline(self) -> _FakePipeline:
            return _FakePipeline(self)

        # Metadata required by InferScheduler.check_redis_version
        def info(self) -> dict[str, str]:
            return {"redis_version": "6.2.0"}

        # Health check used by InferScheduler.start
        def ping(self) -> bool:
            return True

    redis_mod = types.ModuleType("redis")
    redis_mod.Redis = _FakeRedis  # type: ignore[attr-defined]
    sys.modules.setdefault("redis", redis_mod)

    # ------------------------------------------- fastdeploy.engine.request stub
    request_mod = types.ModuleType("fastdeploy.engine.request")

    @dataclass
    class CompletionOutput:
        index: int
        send_idx: int
        token_ids: List[int]
        finished: bool = False

        def to_dict(self) -> Dict[str, Any]:
            return {
                "index": self.index,
                "send_idx": self.send_idx,
                "token_ids": list(self.token_ids),
                "finished": self.finished,
            }

        @classmethod
        def from_dict(cls, data: Dict[str, Any]) -> "CompletionOutput":
            return cls(
                index=data.get("index", 0),
                send_idx=data.get("send_idx", 0),
                token_ids=list(data.get("token_ids", [])),
                finished=data.get("finished", False),
            )

    @dataclass
    class RequestMetrics:
        arrival_time: float
        inference_start_time: Optional[float] = None

        def to_dict(self) -> Dict[str, Any]:
            return {
                "arrival_time": self.arrival_time,
                "inference_start_time": self.inference_start_time,
            }

        @classmethod
        def from_dict(cls, data: Dict[str, Any]) -> "RequestMetrics":
            return cls(
                arrival_time=data.get("arrival_time", time.time()),
                inference_start_time=data.get("inference_start_time"),
            )

    class Request:
        def __init__(
            self,
            request_id: str,
            prompt: Optional[str] = None,
            prompt_token_ids: Optional[List[int]] = None,
            prompt_token_ids_len: int = 0,
            arrival_time: Optional[float] = None,
            disaggregate_info: Optional[Dict[str, Any]] = None,
        ) -> None:
            self.request_id = request_id
            self.prompt = prompt or ""
            self.prompt_token_ids = prompt_token_ids or []
            self.prompt_token_ids_len = prompt_token_ids_len
            self.arrival_time = arrival_time if arrival_time is not None else time.time()
            self.metrics = RequestMetrics(arrival_time=self.arrival_time)
            self.disaggregate_info = disaggregate_info

        def to_dict(self) -> Dict[str, Any]:
            return {
                "request_id": self.request_id,
                "prompt": self.prompt,
                "prompt_token_ids": list(self.prompt_token_ids),
                "prompt_token_ids_len": self.prompt_token_ids_len,
                "arrival_time": self.arrival_time,
                "metrics": self.metrics.to_dict(),
                "disaggregate_info": self.disaggregate_info,
            }

        @classmethod
        def from_dict(cls, data: Dict[str, Any]) -> "Request":
            req = cls(
                request_id=data["request_id"],
                prompt=data.get("prompt"),
                prompt_token_ids=data.get("prompt_token_ids"),
                prompt_token_ids_len=data.get("prompt_token_ids_len", 0),
                arrival_time=data.get("arrival_time", time.time()),
                disaggregate_info=data.get("disaggregate_info"),
            )
            metrics_dict = data.get("metrics")
            if metrics_dict:
                req.metrics = RequestMetrics.from_dict(metrics_dict)
            else:
                req.refresh_metrics()
            return req

        def refresh_metrics(self) -> None:
            self.metrics = RequestMetrics.from_dict({"arrival_time": self.arrival_time})

    class RequestOutput:
        def __init__(
            self,
            request_id: str,
            prompt: str,
            prompt_token_ids: List[int],
            outputs: CompletionOutput,
            metrics: RequestMetrics,
            finished: bool = False,
            error_code: int = 200,
            error_msg: Optional[str] = None,
        ) -> None:
            self.request_id = request_id
            self.prompt = prompt
            self.prompt_token_ids = prompt_token_ids
            self.outputs = outputs
            self.metrics = metrics
            self.finished = finished
            self.error_code = error_code
            self.error_msg = error_msg

        def to_dict(self) -> Dict[str, Any]:
            return {
                "request_id": self.request_id,
                "prompt": self.prompt,
                "prompt_token_ids": list(self.prompt_token_ids),
                "outputs": self.outputs.to_dict(),
                "metrics": self.metrics.to_dict(),
                "finished": self.finished,
                "error_code": self.error_code,
                "error_msg": self.error_msg,
            }

        @classmethod
        def from_dict(cls, data: Dict[str, Any]) -> "RequestOutput":
            return cls(
                request_id=data["request_id"],
                prompt=data.get("prompt", ""),
                prompt_token_ids=list(data.get("prompt_token_ids", [])),
                outputs=CompletionOutput.from_dict(data.get("outputs", {})),
                metrics=RequestMetrics.from_dict(data.get("metrics", {})),
                finished=data.get("finished", False),
                error_code=data.get("error_code", 200),
                error_msg=data.get("error_msg"),
            )

    request_mod.CompletionOutput = CompletionOutput  # type: ignore[attr-defined]
    request_mod.RequestMetrics = RequestMetrics  # type: ignore[attr-defined]
    request_mod.Request = Request  # type: ignore[attr-defined]
    request_mod.RequestOutput = RequestOutput  # type: ignore[attr-defined]
    sys.modules["fastdeploy.engine.request"] = request_mod

    fd_pkg = types.ModuleType("fastdeploy")
    fd_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy")]
    sys.modules["fastdeploy"] = fd_pkg

    scheduler_pkg = types.ModuleType("fastdeploy.scheduler")
    scheduler_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy" / "scheduler")]
    sys.modules["fastdeploy.scheduler"] = scheduler_pkg

    logger_mod = types.ModuleType("fastdeploy.utils.scheduler_logger")

    def _log(*_args: Any, **_kwargs: Any) -> None:
        return None

    for level in ("info", "error", "debug", "warning"):
        setattr(logger_mod, level, _log)  # type: ignore[attr-defined]
    sys.modules["fastdeploy.utils.scheduler_logger"] = logger_mod

    utils_mod = types.ModuleType("fastdeploy.utils")
    utils_mod.scheduler_logger = logger_mod  # type: ignore[attr-defined]
    sys.modules["fastdeploy.utils"] = utils_mod

    _install_stub_modules._installed = True


def _import_splitwise_scheduler():
    """Import the scheduler module with the stub environment."""

    _install_stub_modules()
    return importlib.import_module("fastdeploy.scheduler.splitwise_scheduler")


class _PatchedThread:
    def __init__(self, *args: Any, target=None, **kwargs: Any) -> None:  # type: ignore[override]
        self._target = target
        self.started = False

    def start(self) -> None:
        self.started = True


class _Writer:
    def __init__(self) -> None:
        self.items: list[tuple[str, list[bytes]]] = []

    def put(self, key: str, items: list[bytes]) -> None:
        self.items.append((key, items))


class SplitWiseSchedulerTestCase(unittest.TestCase):
    def setUp(self) -> None:
        self.module = _import_splitwise_scheduler()
        self._orig_thread = self.module.threading.Thread
        self.module.threading.Thread = _PatchedThread  # type: ignore[assignment]

    def tearDown(self) -> None:
        self.module.threading.Thread = self._orig_thread  # type: ignore[assignment]


class SplitWiseSchedulerConfigTest(SplitWiseSchedulerTestCase):
    def test_threshold_defaults_to_model_ratio(self) -> None:
        config = self.module.SplitWiseSchedulerConfig(
            enable_chunked_prefill=True,
            max_num_partial_prefills=5,
            max_long_partial_prefills=3,
            max_model_len=1000,
        )
        self.assertEqual(config.long_prefill_token_threshold, 40)
        self.assertEqual(config.expire_period, 3.0)

    def test_check_and_print_cover_logging(self) -> None:
        config = self.module.SplitWiseSchedulerConfig(
            enable_chunked_prefill=True,
            max_num_partial_prefills=1,
            max_long_partial_prefills=1,
            max_model_len=50,
        )
        config.check()
        config.print()


class NodeInfoTest(SplitWiseSchedulerTestCase):
    def test_serialization_and_expiration(self) -> None:
        node = self.module.NodeInfo(
            nodeid="node-1",
            role="prefill",
            host="localhost",
            disaggregated={"transfer_protocol": ["ipc", "rdma"]},
            load=2,
        )

        payload = node.serialize()
        loaded = self.module.NodeInfo.load_from("node-1", payload)
        self.assertFalse(loaded.expired(10))

        loaded.ts -= 20
        self.assertTrue(loaded.expired(1))

        loaded.add_req("req-1", 4)
        self.assertIn("req-1", loaded.reqs)

        loaded.update_req_timestamp(["req-1"])
        before = loaded.reqs["req-1"][1]
        loaded.reqs["req-1"][1] -= 1000
        loaded.expire_reqs(ttl=1)
        self.assertNotIn("req-1", loaded.reqs)

        loaded.add_req("req-2", 2)
        loaded.finish_req("req-2")
        self.assertNotIn("req-2", loaded.reqs)
        self.assertNotEqual(before, loaded.ts)

    def test_comparisons(self) -> None:
        low = self.module.NodeInfo("a", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=1)
        high = self.module.NodeInfo("b", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=5)
        self.assertTrue(low < high)
        self.assertIn("a(1)", repr(low))


class ResultReaderTest(SplitWiseSchedulerTestCase):
    def test_read_handles_group_tokens_with_buffer_and_outputs(self) -> None:
        client = sys.modules["redis"].Redis()
        reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp")

        req = self.module.Request("req-buffer", prompt_token_ids_len=2)
        reader.add_req(req)

        metrics = self.module.RequestMetrics(arrival_time=time.time())
        head = self.module.RequestOutput(
            request_id=req.request_id,
            prompt="",
            prompt_token_ids=[],
            outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]),
            metrics=metrics,
            finished=False,
        )
        buffered = self.module.RequestOutput(
            request_id=req.request_id,
            prompt="",
            prompt_token_ids=[],
            outputs=self.module.CompletionOutput(index=0, send_idx=3, token_ids=[5]),
            metrics=metrics,
            finished=False,
        )
        trailing = self.module.RequestOutput(
            request_id=req.request_id,
            prompt="",
            prompt_token_ids=[],
            outputs=self.module.CompletionOutput(index=0, send_idx=4, token_ids=[6]),
            metrics=metrics,
            finished=True,
        )

        with reader.lock:
            reader.out_buffer[req.request_id] = [buffered]

        reader.data.appendleft(head)
        reader.data.appendleft(trailing)
        outputs = reader.read()
        self.assertIn(req.request_id, outputs)

        # Triggers the path where group_tokens has no pre-existing output bucket
        # so the branch at lines 353-354 is exercised.
        another = self.module.RequestOutput(
            request_id="req-new",
            prompt="",
            prompt_token_ids=[],
            outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[9]),
            metrics=metrics,
            finished=True,
        )
        reader.data.appendleft(another)
        outputs = reader.read()
        self.assertEqual(outputs["req-new"][0].outputs.token_ids, [9])

    def test_read_groups_partial_outputs(self) -> None:
        client = sys.modules["redis"].Redis()
        reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="group-a")

        req = self.module.Request("req-A", prompt_token_ids_len=3)
        reader.add_req(req)

        metrics = self.module.RequestMetrics(arrival_time=time.time())
        first = self.module.RequestOutput(
            request_id="req-A",
            prompt="",
            prompt_token_ids=[],
            outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1, 2]),
            metrics=metrics,
            finished=False,
        )
        follow = self.module.RequestOutput(
            request_id="req-A",
            prompt="",
            prompt_token_ids=[],
            outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[3]),
            metrics=metrics,
            finished=True,
        )

        reader.data.appendleft(follow)
        reader.data.appendleft(first)

        outputs = reader.read()
        self.assertIn("req-A", outputs)
        self.assertEqual(len(outputs["req-A"]), 2)

    def test_sync_results_converts_payloads(self) -> None:
        client = sys.modules["redis"].Redis()
        reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="")

        metrics = self.module.RequestMetrics(arrival_time=time.time())
        ro = self.module.RequestOutput(
            request_id="req-B",
            prompt="p",
            prompt_token_ids=[1],
            outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[4]),
            metrics=metrics,
            finished=True,
        )

        payload = self.module.orjson.dumps(ro.to_dict())
        client.storage.setdefault("req-key", []).append(payload)

        total = reader.sync_results(["req-key"])
        self.assertEqual(total, 1)
        self.assertTrue(reader.data)

    def test_read_uses_out_buffer(self) -> None:
        client = sys.modules["redis"].Redis()
        reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp")

        req = self.module.Request("req-out", prompt_token_ids_len=2)
        reader.add_req(req)

        metrics = self.module.RequestMetrics(arrival_time=time.time())
        head = self.module.RequestOutput(
            request_id="req-out",
            prompt="",
            prompt_token_ids=[],
            outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]),
            metrics=metrics,
            finished=False,
        )
        tail = self.module.RequestOutput(
            request_id="req-out",
            prompt="",
            prompt_token_ids=[],
            outputs=self.module.CompletionOutput(index=0, send_idx=2, token_ids=[2, 3]),
            metrics=metrics,
            finished=True,
        )

        with reader.lock:
            reader.out_buffer[req.request_id] = [tail]
        reader.data.appendleft(head)

        outputs = reader.read()
        self.assertEqual(len(outputs["req-out"]), 2)

    def test_sync_results_with_group_override(self) -> None:
        client = sys.modules["redis"].Redis()
        reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp")

        metrics = self.module.RequestMetrics(arrival_time=time.time())
        ro = self.module.RequestOutput(
            request_id="req-group",
            prompt="",
            prompt_token_ids=[],
            outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[7]),
            metrics=metrics,
            finished=True,
        )
        payload = self.module.orjson.dumps(ro.to_dict())
        client.storage.setdefault("grp", []).append(payload)

        total = reader.sync_results(["unused"])
        self.assertEqual(total, 1)
        self.assertEqual(reader.data[-1].request_id, "req-group")

    def test_run_emits_expired_placeholder(self) -> None:
        client = sys.modules["redis"].Redis()
        reader = self.module.ResultReader(client, idx=0, batch=10, ttl=1, group="")
        reader.reqs["old"] = {"arrival_time": time.time() - 5}
        reader.reqs["active"] = {"arrival_time": time.time()}

        call_count = {"rpop": 0}

        def _rpop(key: str, batch: int):
            call_count["rpop"] += 1
            if call_count["rpop"] > 1:
                raise SystemExit()
            return []

        reader.client.rpop = _rpop  # type: ignore[assignment]

        with self.assertRaises(SystemExit):
            reader.run()

        self.assertNotIn("old", reader.reqs)
        self.assertTrue(reader.data)
        self.assertGreaterEqual(call_count["rpop"], 1)

    def test_run_handles_empty_keys_and_exceptions(self) -> None:
        client = sys.modules["redis"].Redis()
        reader = self.module.ResultReader(client, idx=0, batch=5, ttl=10, group="")
        reader.reqs.clear()

        original_sleep = self.module.time.sleep
        try:
            self.module.time.sleep = lambda _t: (_ for _ in ()).throw(SystemExit())
            with self.assertRaises(SystemExit):
                reader.run()
        finally:
            self.module.time.sleep = original_sleep

        # Now cover the exception logging path inside run()
        reader.reqs["rid"] = {"arrival_time": time.time()}
        calls = {"count": 0}

        def _rpop(_key: str, _batch: int):
            calls["count"] += 1
            if calls["count"] == 1:
                raise ValueError("boom")
            raise SystemExit()

        reader.client.rpop = _rpop  # type: ignore[assignment]
        with self.assertRaises(SystemExit):
            reader.run()

        self.assertGreaterEqual(calls["count"], 2)


class APISchedulerTest(SplitWiseSchedulerTestCase):
    def _make_config(self) -> Any:
        return self.module.SplitWiseSchedulerConfig(
            enable_chunked_prefill=True,
            max_num_partial_prefills=5,
            max_long_partial_prefills=3,
            max_model_len=200,
        )

    def test_schedule_mixed_node_uses_single_queue(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)

        req = self.module.Request("req-1", prompt_token_ids_len=10)
        mixed = self.module.NodeInfo("mixed", "mixed", "host-a", {"transfer_protocol": ["ipc"]}, load=1)
        scheduler.select_pd = lambda *args, **kwargs: mixed  # type: ignore[assignment]

        scheduler.schedule(req, [mixed], [], [], group="g0")
        key = f"ReqQ_{mixed.nodeid}"
        self.assertIn(key, scheduler.client.storage)
        stored = scheduler.client.storage[key][0]
        decoded = pickle.loads(stored)
        self.assertEqual(decoded["group"], "g0")
        self.assertIsNone(decoded["disaggregate_info"])

    def test_schedule_disaggregated_nodes_fill_metadata(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)

        req = self.module.Request("req-meta", prompt_token_ids_len=10)
        disagg = {
            "host_ip": "1.1.1.1",
            "transfer_protocol": ["ipc"],
            "connector_port": 10,
            "device_ids": [0],
            "rdma_ports": [100],
            "tp_size": 2,
        }
        pre = self.module.NodeInfo("p", "prefill", "host-a", disagg, load=1)
        dec = self.module.NodeInfo(
            "d",
            "decode",
            "host-b",
            {
                "host_ip": "1.1.1.1",
                "transfer_protocol": ["ipc"],
                "connector_port": 11,
                "device_ids": [1],
                "rdma_ports": [101],
                "tp_size": 2,
            },
            load=2,
        )

        scheduler.schedule(req, [pre], [dec], [], group="g1")
        self.assertIsNotNone(req.disaggregate_info)
        self.assertEqual(req.disaggregate_info["transfer_protocol"], "ipc")
        self.assertIn(f"ReqQ_{pre.nodeid}", scheduler.client.storage)
        self.assertIn(f"ReqQ_{dec.nodeid}", scheduler.client.storage)

    def test_loop_schedule_consumes_queue_and_uses_reader(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        scheduler.reqs_queue.append(self.module.Request("req-loop", prompt_token_ids_len=1))
        scheduler.readers = [types.SimpleNamespace(add_req=lambda _req: None, group="grp")]
        scheduler.sync_cluster = lambda: ([types.SimpleNamespace(load=0, role="prefill", disaggregated={})], [types.SimpleNamespace(load=0, role="decode", disaggregated={})], [])  # type: ignore[assignment]
        scheduler.schedule = lambda *_args, **_kwargs: (_ for _ in ()).throw(SystemExit())  # type: ignore[assignment]

        with self.assertRaises(SystemExit):
            scheduler.loop_schedule()

    def test_sync_cluster_partitions_and_filters_nodes(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        now = time.time()
        cluster_key = scheduler.cluster_key

        expired_payload = self.module.orjson.dumps(
            {"ts": now - 10, "role": "prefill", "load": 1, "host": "h", "disaggregated": {}}
        )
        scheduler.client.hset(cluster_key, b"expired", expired_payload)

        valid_prefill = self.module.NodeInfo("p1", "prefill", "h1", {"transfer_protocol": ["ipc"]}, load=1)
        scheduler.client.hset(cluster_key, b"p1", valid_prefill.serialize())

        valid_decode = self.module.NodeInfo("d1", "decode", "h2", {"transfer_protocol": ["ipc"]}, load=2)
        scheduler.client.hset(cluster_key, b"d1", valid_decode.serialize())

        invalid_payload = self.module.orjson.dumps(
            {"ts": now, "role": "unknown", "load": 0, "host": "h3", "disaggregated": {}}
        )
        scheduler.client.hset(cluster_key, b"bad", invalid_payload)

        pnodes, dnodes, mnodes = scheduler.sync_cluster()
        self.assertEqual(len(pnodes), 1)
        self.assertEqual(len(dnodes), 1)
        self.assertEqual(len(mnodes), 0)

    def test_loop_clear_expired_nodes_removes_entries(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        cluster_key = scheduler.cluster_key
        stale_payload = self.module.orjson.dumps(
            {
                "ts": time.time() - (scheduler.clear_expired_nodes_period + 5),
                "role": "prefill",
                "load": 0,
                "host": "h",
                "disaggregated": {"transfer_protocol": ["ipc"]},
            }
        )
        scheduler.client.hset(cluster_key, b"old", stale_payload)

        original_sleep = self.module.time.sleep
        self.module.time.sleep = lambda _t: (_ for _ in ()).throw(SystemExit())
        with self.assertRaises(SystemExit):
            scheduler.loop_clear_expired_nodes()
        self.module.time.sleep = original_sleep
        self.assertNotIn(b"old", scheduler.client.hgetall(cluster_key))

    def test_select_pd_paths(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        req = self.module.Request("req-sel", prompt_token_ids_len=50)
        nodes = [
            self.module.NodeInfo(str(i), "prefill", "h", {"transfer_protocol": ["ipc"]}, load=i) for i in range(3)
        ]
        random.seed(0)
        chosen = scheduler.select_pd(req, nodes, "prefill")
        self.assertIn(chosen, nodes)

        decode_nodes = [
            self.module.NodeInfo(str(i), "decode", "h", {"transfer_protocol": ["ipc"]}, load=i) for i in range(2)
        ]
        chosen_decode = scheduler.select_pd(req, decode_nodes, "decode")
        self.assertIn(chosen_decode, decode_nodes)

    def test_schedule_disaggregated_updates_protocol(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)

        req = self.module.Request("req-2", prompt_token_ids_len=10)
        prefill = self.module.NodeInfo(
            "prefill",
            "prefill",
            "host-a",
            {
                "transfer_protocol": ["ipc"],
                "host_ip": "1.1.1.1",
                "connector_port": 10,
                "device_ids": [0],
                "rdma_ports": [1],
                "tp_size": 1,
            },
            load=1,
        )
        decode = self.module.NodeInfo(
            "decode",
            "decode",
            "host-b",
            {
                "transfer_protocol": ["ipc", "rdma"],
                "host_ip": "2.2.2.2",
                "connector_port": 11,
                "device_ids": [1],
                "rdma_ports": [2],
                "tp_size": 1,
            },
            load=1,
        )

        def _select(req_obj, nodes, role):
            return nodes[0]

        scheduler.select_pd = _select  # type: ignore[assignment]

        scheduler.schedule(req, [prefill], [decode], [], group="")
        self.assertIn("ReqQ_prefill", scheduler.client.storage)
        self.assertIn("ReqQ_decode", scheduler.client.storage)

        decoded = pickle.loads(scheduler.client.storage["ReqQ_prefill"][0])
        self.assertEqual(decoded["disaggregate_info"]["transfer_protocol"], "rdma")

    def test_sync_cluster_filters_expired_nodes(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)

        fresh = self.module.NodeInfo("n1", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=1)
        scheduler.client.hset(scheduler.cluster_key, fresh.nodeid.encode(), fresh.serialize())

        stale_payload = self.module.orjson.dumps(
            {
                "ts": time.time() - (config.expire_period + 1),
                "role": "prefill",
                "load": 1,
                "host": "h",
                "disaggregated": {"transfer_protocol": ["ipc"]},
            }
        )
        scheduler.client.hset(scheduler.cluster_key, b"n2", stale_payload)

        pnodes, _, _ = scheduler.sync_cluster()
        self.assertEqual([node.nodeid for node in pnodes], ["n1"])

    def test_start_put_and_get_results(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)
        scheduler.start()

        reqs = [self.module.Request(f"req-{i}", prompt_token_ids_len=1) for i in range(2)]
        result = scheduler.put_requests(reqs)
        self.assertEqual(len(result), 2)

        fake_output = {"a": ["value"]}
        scheduler.readers = [types.SimpleNamespace(read=lambda: fake_output)]
        outputs = scheduler.get_results()
        self.assertEqual(outputs, fake_output)

    def test_select_pd_prefill_and_decode(self) -> None:
        config = self._make_config()
        scheduler = self.module.APIScheduler(config)

        req = self.module.Request("req-select", prompt_token_ids_len=50)
        prefill_nodes = [
            self.module.NodeInfo("a", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=5),
            self.module.NodeInfo("b", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=20),
        ]
        decode_nodes = [
            self.module.NodeInfo("c", "decode", "h", {"transfer_protocol": ["ipc"]}, load=1),
            self.module.NodeInfo("d", "decode", "h", {"transfer_protocol": ["ipc"]}, load=2),
        ]

        original_choice = self.module.random.choice
        self.module.random.choice = lambda seq: seq[-1]  # type: ignore[assignment]
        try:
            picked_prefill = scheduler.select_pd(req, prefill_nodes, "prefill")
            picked_decode = scheduler.select_pd(req, decode_nodes, "decode")
        finally:
            self.module.random.choice = original_choice

        self.assertEqual(picked_prefill.nodeid, "b")
        self.assertEqual(picked_decode.nodeid, "d")

        with self.assertRaises(Exception):
            scheduler.select_pd(req, prefill_nodes, "unknown")


class InferSchedulerTest(SplitWiseSchedulerTestCase):
    def _make_config(self, **overrides: Any) -> Any:
        base = dict(
            enable_chunked_prefill=True,
            max_num_partial_prefills=3,
            max_long_partial_prefills=1,
            max_model_len=200,
        )
        base.update(overrides)
        return self.module.SplitWiseSchedulerConfig(**base)

    def test_get_requests_limits_partial_prefills(self) -> None:
        config = self._make_config(long_prefill_token_threshold=5)
        infer = self.module.InferScheduler(config)
        infer.role = "prefill"
        infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0)

        long = self.module.Request("req-long", prompt_token_ids_len=10)
        longer = self.module.Request("req-longer", prompt_token_ids_len=12)
        infer.reqs_queue.extend([longer, long])

        picked = infer.get_requests(
            available_blocks=100,
            block_size=4,
            reserved_output_blocks=1,
            max_num_batched_tokens=100,
            batch=5,
        )
        self.assertEqual([req.request_id for req in picked], ["req-longer"])
        self.assertEqual([req.request_id for req in infer.reqs_queue], ["req-long"])

    def test_get_requests_non_chunked_uses_token_cap(self) -> None:
        config = self._make_config(enable_chunked_prefill=False)
        infer = self.module.InferScheduler(config)
        infer.role = "prefill"
        infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0)

        infer.reqs_queue.extend(
            [
                self.module.Request("req-1", prompt_token_ids_len=10),
                self.module.Request("req-2", prompt_token_ids_len=20),
            ]
        )

        picked = infer.get_requests(
            available_blocks=100,
            block_size=4,
            reserved_output_blocks=1,
            max_num_batched_tokens=15,
            batch=5,
        )
        self.assertEqual([req.request_id for req in picked], ["req-1"])
        self.assertEqual(len(infer.reqs_queue), 1)

    def test_put_results_groups_by_writer_index(self) -> None:
        config = self._make_config()
        infer = self.module.InferScheduler(config)
        infer.role = "prefill"
        infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0)

        infer.writers = [_Writer(), _Writer()]
        infer.node.add_req("req#0#g", 1)

        metrics = self.module.RequestMetrics(arrival_time=time.time())
        result = self.module.RequestOutput(
            request_id="req#0#g",
            prompt="",
            prompt_token_ids=[],
            outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]),
            metrics=metrics,
            finished=True,
        )

        infer.put_results([result])
        self.assertEqual(len(infer.writers[0].items), 1)
        key, payloads = infer.writers[0].items[0]
        self.assertEqual(key, "g")
        decoded = self.module.orjson.loads(payloads[0])
        self.assertFalse(decoded["finished"])

    def test_put_results_handles_errors(self) -> None:
        config = self._make_config()
        infer = self.module.InferScheduler(config)
        infer.role = "decode"
        infer.node = self.module.NodeInfo("n", "decode", "h", {"transfer_protocol": ["ipc"]}, load=0)

        infer.writers = [_Writer()]
        infer.node.add_req("bad#0#", 1)

        metrics = self.module.RequestMetrics(arrival_time=time.time())
        result = self.module.RequestOutput(
            request_id="bad#0#",
            prompt="",
            prompt_token_ids=[],
            outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[1]),
            metrics=metrics,
            finished=True,
            error_code=500,
        )

        infer.put_results([result])
        self.assertFalse(infer.node.reqs)

    def test_start_initializes_writers(self) -> None:
        config = self._make_config()
        infer = self.module.InferScheduler(config)
        infer.start("prefill", "host", {"transfer_protocol": ["ipc"]})
        self.assertEqual(len(infer.writers), config.writer_parallel)

    def test_get_requests_skips_expired_entries(self) -> None:
        config = self._make_config()
        infer = self.module.InferScheduler(config)
        infer.role = "prefill"
        infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0)

        expired = self.module.Request("expired", prompt_token_ids_len=1, arrival_time=time.time() - (infer.ttl + 1))
        infer.node.add_req("expired", 1)
        infer.reqs_queue.append(expired)

        picked = infer.get_requests(
            available_blocks=10,
            block_size=1,
            reserved_output_blocks=1,
            max_num_batched_tokens=10,
            batch=1,
        )

        self.assertEqual(picked, [])
        self.assertNotIn("expired", infer.node.reqs)

    def test_check_redis_version_requires_supported_version(self) -> None:
        config = self._make_config()
        infer = self.module.InferScheduler(config)
        infer.client.info = lambda: {"redis_version": "5.0.0"}  # type: ignore[assignment]

        with self.assertRaises(AssertionError):
            infer.check_redis_version()


class SplitWiseSchedulerFacadeTest(SplitWiseSchedulerTestCase):
    def test_facade_delegates_to_components(self) -> None:
        module = self.module

        class _FakeAPI:
            def __init__(self, _config: Any) -> None:
                self.started = False
                self.reqs: List[Any] = []

            def start(self) -> None:
                self.started = True

            def put_requests(self, reqs: List[Any]):
                self.reqs.extend(reqs)
                return [(req.request_id, None) for req in reqs]

            def get_results(self):
                return {"x": 1}

        class _FakeInfer:
            def __init__(self, _config: Any) -> None:
                self.started = False
                self.nodeid = None

            def start(self, role, host, disaggregated):
                self.started = True

            def get_requests(self, *args, **kwargs):
                return ["scheduled"]

            def put_results(self, results):
                return list(results)

        original_api = module.APIScheduler
        original_infer = module.InferScheduler
        module.APIScheduler = _FakeAPI  # type: ignore[assignment]
        module.InferScheduler = _FakeInfer  # type: ignore[assignment]

        try:
            config = module.SplitWiseSchedulerConfig(
                enable_chunked_prefill=True,
                max_num_partial_prefills=1,
                max_long_partial_prefills=1,
                max_model_len=10,
            )
            facade = module.SplitWiseScheduler(config)

            facade.start("prefill", "host", {"tp": "ipc"})
            self.assertTrue(facade.scheduler.started)
            self.assertTrue(facade.infer.started)

            reqs = [module.Request("req", prompt_token_ids_len=1)]
            result = facade.put_requests(reqs)
            self.assertEqual(result[0][0], "req")
            self.assertEqual(facade.get_results(), {"x": 1})

            scheduled = facade.get_requests(10, 1, 1, 10, batch=1)
            self.assertEqual(scheduled, ["scheduled"])

            outputs = facade.put_results([1, 2])
            self.assertEqual(outputs, [1, 2])
        finally:
            module.APIScheduler = original_api  # type: ignore[assignment]
            module.InferScheduler = original_infer  # type: ignore[assignment]

    def test_get_requests_with_insufficient_resources(self) -> None:
        module = self.module
        config = module.SplitWiseSchedulerConfig(
            enable_chunked_prefill=True,
            max_num_partial_prefills=1,
            max_long_partial_prefills=1,
            max_model_len=10,
        )
        facade = module.SplitWiseScheduler(config)
        facade.infer = types.SimpleNamespace(get_requests=lambda *args, **kwargs: ["should not reach"])
        facade.scheduler = types.SimpleNamespace()

        result = facade.get_requests(
            available_blocks=1, block_size=1, reserved_output_blocks=2, max_num_batched_tokens=10
        )
        self.assertEqual(result, [])

        result = facade.get_requests(
            available_blocks=10, block_size=1, reserved_output_blocks=2, max_num_batched_tokens=10, batch=0
        )
        self.assertEqual(result, [])

    def test_start_uses_real_components(self) -> None:
        module = self.module
        config = module.SplitWiseSchedulerConfig(
            enable_chunked_prefill=True,
            max_num_partial_prefills=1,
            max_long_partial_prefills=1,
            max_model_len=10,
        )
        facade = module.SplitWiseScheduler(config)

        infer_flags = {}
        scheduler_flags = {}

        facade.infer = types.SimpleNamespace(
            start=lambda role, host, disagg: infer_flags.setdefault("called", (role, host, disagg)),
        )
        facade.scheduler = types.SimpleNamespace(start=lambda: scheduler_flags.setdefault("called", True))

        facade.start("prefill", "host", {"mode": "ipc"})
        self.assertEqual(infer_flags["called"], ("prefill", "host", {"mode": "ipc"}))
        self.assertTrue(scheduler_flags["called"])
        facade.reset_nodeid("new-id")
        self.assertEqual(facade.scheduler.nodeid, "new-id")


class BackgroundWorkerTest(SplitWiseSchedulerTestCase):
    def test_result_writer_start_flags_thread(self) -> None:
        client = sys.modules["redis"].Redis()
        writer = self.module.ResultWriter(client, idx=0, batch=2, ttl=5)
        writer.start()
        self.assertTrue(writer.thread.started)

    def test_result_writer_run_single_iteration(self) -> None:
        client = sys.modules["redis"].Redis()
        writer = self.module.ResultWriter(client, idx=0, batch=5, ttl=10)
        with writer.cond:
            writer.data.appendleft(("key", b"payload"))

        class _Pipeline:
            def __init__(self, parent):
                self.parent = parent

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                return None

            def multi(self):
                return self

            def lpush(self, key, *items):
                self.parent.lpush(key, *items)
                return self

            def expire(self, key, ttl):
                raise SystemExit()

            def execute(self):
                return None

        client.pipeline = lambda: _Pipeline(client)  # type: ignore[assignment]

        with self.assertRaises(SystemExit):
            writer.run()

    def test_result_writer_run_groups_batches(self) -> None:
        client = sys.modules["redis"].Redis()
        writer = self.module.ResultWriter(client, idx=0, batch=10, ttl=5)

        with writer.cond:
            writer.data.appendleft(("k1", b"a"))
            writer.data.appendleft(("k1", b"b"))
            writer.data.appendleft(("k2", b"c"))

        def _pipeline():
            class _P:
                def __enter__(self):
                    return self

                def __exit__(self, exc_type, exc, tb):
                    return None

                def multi(self):
                    return self

                def lpush(self, key, *items):
                    client.lpush(key, *items)
                    return self

                def expire(self, key, ttl):
                    raise SystemExit()

                def execute(self):
                    return None

            return _P()

        client.pipeline = _pipeline  # type: ignore[assignment]
        with self.assertRaises(SystemExit):
            writer.run()

    def test_infer_scheduler_routine_report(self) -> None:
        config = self.module.SplitWiseSchedulerConfig(
            enable_chunked_prefill=True,
            max_num_partial_prefills=1,
            max_long_partial_prefills=1,
            max_model_len=10,
        )
        infer = self.module.InferScheduler(config)
        infer.node = self.module.NodeInfo("nid", "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0)

        def _fake_hset(*_args, **_kwargs):
            raise ValueError("fail")

        infer.client.hset = _fake_hset  # type: ignore[assignment]
        original_logger = self.module.logger.error
        self.module.logger.error = lambda *_args, **_kwargs: (_ for _ in ()).throw(SystemExit())

        try:
            with self.assertRaises(SystemExit):
                infer.routine_report()
        finally:
            self.module.logger.error = original_logger

    def test_infer_scheduler_loop_expire_reqs(self) -> None:
        config = self.module.SplitWiseSchedulerConfig(
            enable_chunked_prefill=True,
            max_num_partial_prefills=1,
            max_long_partial_prefills=1,
            max_model_len=10,
        )
        infer = self.module.InferScheduler(config)
        infer.node = self.module.NodeInfo("nid", "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0)

        def _raise_exit(ttl):
            raise SystemExit()

        infer.node.expire_reqs = _raise_exit  # type: ignore[assignment]

        with self.assertRaises(SystemExit):
            infer.loop_expire_reqs()

    def test_infer_scheduler_loop_get_reqs(self) -> None:
        config = self.module.SplitWiseSchedulerConfig(
            enable_chunked_prefill=True,
            max_num_partial_prefills=1,
            max_long_partial_prefills=1,
            max_model_len=10,
        )
        infer = self.module.InferScheduler(config)
        infer.role = "prefill"
        infer.node = self.module.NodeInfo(infer.nodeid, "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0)
        infer.writers = [types.SimpleNamespace(put=lambda key, items: None)]

        req = self.module.Request("rq", prompt_token_ids_len=3)
        payload = pickle.dumps(dict(req.to_dict(), group=""), protocol=5)
        key = f"ReqQ_{infer.nodeid}"
        infer.client.storage[key] = [payload]

        state = {"called": False}

        def _fake_rpop(k, batch):
            if not state["called"]:
                state["called"] = True
                return infer.client.storage[k][:]
            raise SystemExit()

        infer.client.rpop = _fake_rpop  # type: ignore[assignment]
        infer.client.brpop = lambda *_args, **_kwargs: None  # type: ignore[assignment]

        with self.assertRaises(SystemExit):
            infer.loop_get_reqs()

    def test_infer_scheduler_get_requests_limits(self) -> None:
        config = self.module.SplitWiseSchedulerConfig(
            enable_chunked_prefill=True,
            max_num_partial_prefills=1,
            max_long_partial_prefills=1,
            max_model_len=50,
        )
        infer = self.module.InferScheduler(config)
        infer.role = "prefill"
        infer.node = self.module.NodeInfo("nid", "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0)

        heavy = self.module.Request("heavy", prompt_token_ids_len=10)
        infer.reqs_queue.append(heavy)
        picked = infer.get_requests(
            available_blocks=1,
            block_size=4,
            reserved_output_blocks=1,
            max_num_batched_tokens=100,
            batch=1,
        )
        self.assertEqual(picked, [])

        infer.reqs_queue.clear()
        infer.reqs_queue.append(
            self.module.Request("long", prompt_token_ids_len=config.long_prefill_token_threshold + 10)
        )
        infer.reqs_queue.append(self.module.Request("short", prompt_token_ids_len=2))
        selected = infer.get_requests(
            available_blocks=100,
            block_size=4,
            reserved_output_blocks=1,
            max_num_batched_tokens=100,
            batch=2,
        )
        self.assertEqual(len(selected), 1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--print-coverage-command", action="store_true")
    known_args, remaining = parser.parse_known_args()

    if known_args.print_coverage_command:
        print("python -m coverage run -m unittest tests.scheduler.test_splitwise_scheduler")
        print("python -m coverage report -m --include='fastdeploy/scheduler/splitwise_scheduler.py'")

    unittest.main(argv=[sys.argv[0]] + remaining)