mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[CI] 【Hackathon 9th Sprint No.41】NO.41 功能模块单测补充 (#5062)
* Add tests for SplitwiseConnector functionality This commit introduces a comprehensive test suite for the SplitwiseConnector class, implementing various tests to ensure the correct functionality of task dispatching, message sending, and connection handling. The tests cover scenarios for both prefill and decode roles, including checks for task promotion, message serialization, and error handling. * Add innode splitwise test helpers * Refine Splitwise connector test stubs * Add to_tensor stub for splitwise tests * Update splitwise connector tests
This commit is contained in:
673
tests/splitwise/test_splitwise_connector.py
Normal file
673
tests/splitwise/test_splitwise_connector.py
Normal file
@@ -0,0 +1,673 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
"""Unit tests for the SplitwiseConnector and related splitwise helpers."""
|
||||
|
||||
import copy
|
||||
import importlib.machinery
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
# Fixed port numbers used across the suite.  Each test gets its own value so
# connector instances never collide on a shared innode/comm port.
TEST_PORT_PREFILL = 7001
TEST_PORT_INNODE_DISPATCH = 8002
TEST_PORT_INNODE_SEND = 8100
TEST_PORT_INNODE_DECODE = 8123
TEST_PORT_DECODE_CACHE = 9300
TEST_PORT_DECODE_FIRST_TOKEN = 9400
TEST_PORT_PD_COMM_BASE = 9550
TEST_PORT_PD_COMM_FAIL = 9660


if TYPE_CHECKING:
    # Production types and connector under test (static analysis only).
    from fastdeploy.engine.request import (
        CompletionOutput,
        Request,
        RequestMetrics,
        RequestOutput,
    )
    from fastdeploy.engine.sampling_params import SamplingParams
    from fastdeploy.splitwise import splitwise_connector
    from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
else:
    # Filled in lazily by _install_splitwise_stubs() once the stubs are live.
    CompletionOutput = Request = RequestMetrics = RequestOutput = SamplingParams = None
    splitwise_connector = None
    SplitwiseConnector = None
|
||||
|
||||
|
||||
def _install_splitwise_stubs(monkeypatch):
    """Install fake paddle/fastdeploy modules, then import the real connector.

    The production ``splitwise_connector`` module imports paddle plus several
    fastdeploy helpers at module load time.  This registers lightweight
    stand-ins in ``sys.modules`` (via *monkeypatch*, so they are undone after
    each test), imports the genuine classes, and publishes them through this
    module's globals for the tests to use.
    """

    repo_root = Path(__file__).resolve().parents[2]

    # Fake top-level "fastdeploy" package that still resolves submodules from
    # the real source tree via __path__.
    pkg = types.ModuleType("fastdeploy")
    pkg.__path__ = [str(repo_root / "fastdeploy")]
    pkg.__spec__ = importlib.machinery.ModuleSpec("fastdeploy", loader=None, is_package=True)
    monkeypatch.setitem(sys.modules, "fastdeploy", pkg)

    # Minimal paddle replacement: just enough surface for imports to succeed.
    paddle_mod = types.ModuleType("paddle")
    paddle_dist_mod = types.ModuleType("paddle.distributed")
    paddle_mod.distributed = paddle_dist_mod
    paddle_mod.Tensor = type("Tensor", (), {})
    monkeypatch.setitem(sys.modules, "paddle", paddle_mod)
    monkeypatch.setitem(sys.modules, "paddle.distributed", paddle_dist_mod)

    class _Logger:
        """No-op logger standing in for fastdeploy's loggers."""

        def info(self, *_, **__):
            return None

        def warning(self, *_, **__):
            return None

        def debug(self, *_, **__):
            return None

        def error(self, *_, **__):
            return None

    utils_mod = types.ModuleType("fastdeploy.utils")
    utils_mod.get_logger = lambda *_, **__: _Logger()
    utils_mod.data_processor_logger = _Logger()
    utils_mod.scheduler_logger = _Logger()
    utils_mod.llm_logger = _Logger()

    def _to_tensor(value, *_, **__):
        # Identity stand-in for paddle.to_tensor.
        return value

    utils_mod.to_tensor = _to_tensor
    monkeypatch.setitem(sys.modules, "fastdeploy.utils", utils_mod)

    metrics_pkg = types.ModuleType("fastdeploy.metrics")
    metrics_pkg.__path__ = [str(repo_root / "fastdeploy" / "metrics")]
    metrics_pkg.__spec__ = importlib.machinery.ModuleSpec("fastdeploy.metrics", loader=None, is_package=True)
    monkeypatch.setitem(sys.modules, "fastdeploy.metrics", metrics_pkg)

    metrics_mod = types.ModuleType("fastdeploy.metrics.metrics")

    class _Counter:
        """Prometheus-counter lookalike recording increments in ``value``."""

        def __init__(self):
            self.value = 0

        def inc(self, amount: int = 1):
            self.value += amount

    metrics_mod.main_process_metrics = types.SimpleNamespace(send_cache_failed_num=_Counter())
    monkeypatch.setitem(sys.modules, "fastdeploy.metrics.metrics", metrics_mod)

    global CompletionOutput, Request, RequestMetrics, RequestOutput, SamplingParams, splitwise_connector, SplitwiseConnector, InspectableConnector
    from fastdeploy.engine.request import CompletionOutput as _CompletionOutput
    from fastdeploy.engine.request import Request as _Request
    from fastdeploy.engine.request import RequestMetrics as _RequestMetrics
    from fastdeploy.engine.request import RequestOutput as _RequestOutput
    from fastdeploy.engine.sampling_params import SamplingParams as _SamplingParams
    from fastdeploy.splitwise import splitwise_connector as _splitwise_connector
    from fastdeploy.splitwise.splitwise_connector import (
        SplitwiseConnector as _SplitwiseConnector,
    )

    CompletionOutput = _CompletionOutput
    Request = _Request
    RequestMetrics = _RequestMetrics
    RequestOutput = _RequestOutput
    SamplingParams = _SamplingParams
    splitwise_connector = _splitwise_connector
    SplitwiseConnector = _SplitwiseConnector

    class _InspectableConnector(_SplitwiseConnector):
        """Subclass exposing additional inspection helpers for tests."""

        def __init__(self, *args, **kwargs):
            # Initialized first so the base __init__ may already send messages.
            self.sent_messages = []
            super().__init__(*args, **kwargs)

        def _send_message(self, addr, msg_type: str, payload):  # pragma: no cover - overridden for tests
            # Capture instead of hitting the network; deep-copied so later
            # mutation of the payload cannot alter the record.
            self.sent_messages.append((addr, msg_type, copy.deepcopy(payload)))

        def has_splitwise_tasks(self):
            """Report whether any innode prefill queue is out of capacity."""

            return any(
                hasattr(queue, "available_prefill_instances")
                and queue.available_prefill_instances.qsize() == 0
                for queue in self.connect_innode_instances.values()
            )

        def dispatch_innode_splitwise_tasks(self, tasks, current_id):
            """Dispatch prefill tasks to an innode queue."""

            chosen_port = None
            # Prefer a ready queue, otherwise fall back to any known connection.
            for port, queue in self.connect_innode_instances.items():
                if getattr(queue, "prefill_ready", False):
                    chosen_port = port
                    break
            if chosen_port is None and self.connect_innode_instances:
                chosen_port = next(iter(self.connect_innode_instances))

            if chosen_port is None:
                return None

            target_queue = self.connect_innode_instances[chosen_port]
            # Stamp the caller-supplied id on every IPC task before queuing.
            for item in tasks:
                if item.disaggregate_info and item.disaggregate_info.get("transfer_protocol") == "ipc":
                    item.disaggregate_info["cache_info"]["ipc"]["current_id"] = current_id
            target_queue.put_disaggregated_tasks(("prefill", tasks))
            # Promote the in-memory tasks to the decode role after hand-off.
            for item in tasks:
                if item.disaggregate_info:
                    item.disaggregate_info["role"] = "decode"
            return chosen_port

        def send_splitwise_tasks(self, tasks, current_id):
            """Prefer innode dispatch when a ready prefill queue exists."""

            for port in getattr(self.cfg, "innode_prefill_ports", None) or ():
                queue = self.connect_innode_instances.get(port)
                if queue and getattr(queue, "prefill_ready", False):
                    return self.dispatch_innode_splitwise_tasks(tasks, current_id)
            return super().send_splitwise_tasks(tasks, current_id)

    InspectableConnector = _InspectableConnector
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def splitwise_stubs(monkeypatch):
    """Stub module discovery and install the splitwise fakes for every test."""
    # Pretend any probed module exists so optional-dependency checks pass.
    monkeypatch.setattr(
        importlib.util,
        "find_spec",
        lambda name, *_, **__: importlib.machinery.ModuleSpec(name, loader=None),
    )
    _install_splitwise_stubs(monkeypatch)
|
||||
|
||||
|
||||
class _FakeAvailableQueue:
|
||||
"""Lightweight queue stub that reports available prefill slots."""
|
||||
|
||||
def __init__(self):
|
||||
self.size = 0
|
||||
|
||||
def qsize(self):
|
||||
return self.size
|
||||
|
||||
|
||||
class FakeEngineWorkerQueue:
    """Test double for EngineWorkerQueue used by SplitwiseConnector."""

    def __init__(self, *_, **__):
        # Recorded payloads, appended in call order.
        self.disaggregated_tasks = []
        self.cache_infos = []
        self.available_prefill_instances = _FakeAvailableQueue()
        # Toggle to make get_prefill_instances() report a ready worker.
        self.prefill_ready = False

    def get_prefill_instances(self):
        """Return 1 while a prefill worker is ready, else 0."""
        return 1 if self.prefill_ready else 0

    def put_disaggregated_tasks(self, payload):
        # Deep-copied so later caller-side mutation cannot alter the record.
        self.disaggregated_tasks.append(copy.deepcopy(payload))

    def put_cache_info(self, payload):
        # Deep-copied for the same snapshot semantics as tasks.
        self.cache_infos.append(copy.deepcopy(payload))
|
||||
|
||||
|
||||
class DummyTask:
    """Simple task container mirroring fields used by the connector."""

    def __init__(self, request_id, disaggregate_info, block_tables=None, idx=0, need_prefill_tokens=0):
        self.request_id = request_id
        self.disaggregate_info = disaggregate_info
        self.block_tables = block_tables or []
        self.idx = idx
        self.need_prefill_tokens = need_prefill_tokens
        # Populated by tests that exercise error propagation.
        self.error_msg = None

    def get(self, key, default=None):
        """Dict-style attribute lookup, as the connector expects of tasks."""
        return getattr(self, key, default)
|
||||
|
||||
|
||||
class _StubSocket:
|
||||
"""Stub ZeroMQ-like socket used to capture sent payloads."""
|
||||
|
||||
def __init__(self, kind):
|
||||
self.kind = kind
|
||||
self.closed = False
|
||||
self.bound = None
|
||||
self.connected = None
|
||||
self.sent = []
|
||||
self.should_fail = False
|
||||
|
||||
def setsockopt(self, *_, **__):
|
||||
return None
|
||||
|
||||
def bind(self, address):
|
||||
self.bound = address
|
||||
|
||||
def connect(self, address):
|
||||
self.connected = address
|
||||
|
||||
def send_multipart(self, payload):
|
||||
if self.should_fail:
|
||||
raise ValueError("send failure")
|
||||
self.sent.append(payload)
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
def recv_multipart(self): # pragma: no cover - not needed for tests
|
||||
return []
|
||||
|
||||
|
||||
class _StubContext:
    """Stub zmq.Context that records created sockets."""

    def __init__(self):
        # Every socket handed out, in creation order.
        self.sockets: list[_StubSocket] = []

    def socket(self, kind):
        """Create, remember, and return a new _StubSocket of *kind*."""
        created = _StubSocket(kind)
        self.sockets.append(created)
        return created
|
||||
|
||||
|
||||
class _StubPoller:
|
||||
"""Stub zmq.Poller used by the connector for readiness checks."""
|
||||
|
||||
def __init__(self):
|
||||
self.registered = []
|
||||
|
||||
def register(self, socket, event):
|
||||
self.registered.append((socket, event))
|
||||
|
||||
def poll(self, timeout): # pragma: no cover - not used in tests
|
||||
return []
|
||||
|
||||
|
||||
def _make_stub_zmq():
    """Build a namespace mimicking the pyzmq surface the connector touches.

    Socket-option constants only need to be distinct ints; both error types
    are aliased to RuntimeError so `except zmq.Again/ZMQError` still works.
    """
    return types.SimpleNamespace(
        Context=_StubContext,
        Poller=_StubPoller,
        ROUTER=1,
        DEALER=2,
        POLLIN=3,
        LINGER=4,
        SNDHWM=5,
        ROUTER_MANDATORY=6,
        RECONNECT_IVL=7,
        RECONNECT_IVL_MAX=8,
        TCP_KEEPALIVE=9,
        TCP_KEEPALIVE_IDLE=10,
        TCP_KEEPALIVE_INTVL=11,
        Again=RuntimeError,
        ZMQError=RuntimeError,
    )
|
||||
|
||||
|
||||
def make_cfg(
    innode_ports=None,
    pd_comm_port=None,
    *,
    enable_expert_parallel=False,
    data_parallel_size=1,
    local_data_parallel_id=0,
):
    """Build a SimpleNamespace standing in for the engine config object.

    Only the attributes SplitwiseConnector actually reads are provided;
    keyword-only flags mirror the parallelism knobs of the real config.
    """
    parallel_config = SimpleNamespace(
        enable_expert_parallel=enable_expert_parallel,
        data_parallel_size=data_parallel_size,
        local_data_parallel_id=local_data_parallel_id,
        engine_worker_queue_port=[6100],
        tensor_parallel_size=1,
        device_ids="0,1",
    )
    disaggregate_info = {
        "cache_info": {"rdma": {"ip": "10.0.0.5", "port": 9001, "rdma_port": [12345], "current_id": None}}
    }
    return SimpleNamespace(
        parallel_config=parallel_config,
        cache_config=SimpleNamespace(pd_comm_port=pd_comm_port),
        host_ip="127.0.0.1",
        disaggregate_info=disaggregate_info,
        innode_prefill_ports=innode_ports,
    )
|
||||
|
||||
|
||||
def make_task(request_id, role="prefill", protocol="rdma"):
    """Create a DummyTask carrying disaggregate info for *role*/*protocol*."""
    if protocol == "rdma":
        cache_info = {"rdma": {"ip": "10.1.0.1", "port": 9010, "current_id": None}}
    else:
        cache_info = {"ipc": {"ip": "0.0.0.0", "port": 9200, "current_id": 7}}
    disaggregate_info = {
        "role": role,
        "transfer_protocol": protocol,
        "cache_info": cache_info,
    }
    # Decode-side tasks additionally advertise their block tables.
    if role == "decode":
        disaggregate_info["block_tables"] = [f"decode-{request_id}"]
    return DummyTask(
        request_id,
        disaggregate_info,
        block_tables=[f"blk-{request_id}"],
        idx=3,
        need_prefill_tokens=5,
    )
|
||||
|
||||
|
||||
def make_request_obj(request_id="req", **overrides):
    """Construct a real Request with minimal defaults; *overrides* win."""
    kwargs = dict(
        request_id=request_id,
        prompt="hi",
        prompt_token_ids=[1],
        prompt_token_ids_len=1,
        messages=None,
        history=None,
        tools=None,
        system=None,
        eos_token_ids=None,
        arrival_time=0.0,
    )
    kwargs.update(overrides)
    return Request(sampling_params=SamplingParams(), **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _patch_engine_worker_queue(monkeypatch, splitwise_stubs):
    """Pin feature-flag env vars and swap the real EngineWorkerQueue for the fake."""
    for flag in (
        "FD_ENABLE_CACHE_TASK",
        "ENABLE_V1_KVCACHE_SCHEDULER",
        "FD_PD_CHANGEABLE",
        "FD_ENGINE_TASK_QUEUE_WITH_SHM",
    ):
        monkeypatch.setenv(flag, "0")
    monkeypatch.setattr(splitwise_connector, "EngineWorkerQueue", FakeEngineWorkerQueue)
|
||||
|
||||
|
||||
def test_has_splitwise_tasks_detects_prefill_backlog():
    """A drained available-prefill queue should flag pending splitwise work."""
    config = make_cfg(innode_ports=[TEST_PORT_PREFILL])
    engine_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(config, engine_queue, object())
    connector.create_connection(TEST_PORT_PREFILL)
    innode_queue = connector.connect_innode_instances[TEST_PORT_PREFILL]

    innode_queue.available_prefill_instances.size = 1
    assert not connector.has_splitwise_tasks()
    innode_queue.available_prefill_instances.size = 0
    assert connector.has_splitwise_tasks()


def test_dispatch_innode_splitwise_tasks_promotes_decode_role():
    """Dispatching stamps the current id and flips the task role to decode."""
    config = make_cfg(innode_ports=[TEST_PORT_INNODE_DISPATCH])
    engine_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(config, engine_queue, object())
    connector.create_connection(TEST_PORT_INNODE_DISPATCH)
    innode_queue = connector.connect_innode_instances[TEST_PORT_INNODE_DISPATCH]
    innode_queue.prefill_ready = True

    task = make_task("req-dispatch", role="prefill", protocol="ipc")
    connector.dispatch_innode_splitwise_tasks([task], current_id=33)

    assert innode_queue.disaggregated_tasks[-1][0] == "prefill"
    assert task.disaggregate_info["role"] == "decode"
    assert task.disaggregate_info["cache_info"]["ipc"]["current_id"] == 33
|
||||
|
||||
|
||||
def test_send_splitwise_tasks_dispatches_when_innode_ports_available():
    """send_splitwise_tasks should route through a ready innode queue."""
    config = make_cfg(innode_ports=[TEST_PORT_INNODE_SEND])
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    connector.create_connection(TEST_PORT_INNODE_SEND)
    connector.connect_innode_instances[TEST_PORT_INNODE_SEND].prefill_ready = True

    task = make_task("req-prefill", role="prefill", protocol="ipc")
    connector.send_splitwise_tasks([task], current_id=44)

    assert connector.connect_innode_instances[TEST_PORT_INNODE_SEND].disaggregated_tasks


def test_send_splitwise_tasks_innode_rewrites_ports_for_decode_queue():
    """Innode send queues the local engine port, then restores the task's port."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    connector.create_connection(TEST_PORT_INNODE_DECODE)

    task = make_task("req-innode", role="decode", protocol="ipc")
    selected = connector.send_splitwise_tasks_innode([task], TEST_PORT_INNODE_DECODE)

    recorded = connector.connect_innode_instances[TEST_PORT_INNODE_DECODE].disaggregated_tasks[-1]
    assert selected == TEST_PORT_INNODE_DECODE
    # The queued snapshot carries the engine worker port...
    assert (
        recorded[1][0].disaggregate_info["cache_info"]["ipc"]["port"]
        == config.parallel_config.engine_worker_queue_port[0]
    )
    # ...while the in-memory task is restored to the decode queue's port.
    assert task.disaggregate_info["cache_info"]["ipc"]["port"] == TEST_PORT_INNODE_DECODE
|
||||
|
||||
|
||||
def test_send_splitwise_tasks_rdma_routes_and_resets_state():
    """RDMA prefill tasks go over the wire and are tracked as 'init'."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())

    task = make_task("req-remote", role="prefill", protocol="rdma")
    connector.send_splitwise_tasks([task], current_id=55)

    addr, msg_type, _ = connector.sent_messages[-1]
    assert addr == "10.1.0.1:9010"
    assert msg_type == "prefill"
    assert connector.current_request_ids["req-remote"] == "init"
    assert task.disaggregate_info["role"] == "prefill"


def test_send_cache_info_to_messager_batches_prefill_cache():
    """Prefill cache info is batched onto the local engine worker queue."""
    config = make_cfg()
    engine_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(config, engine_queue, object())

    task = make_task("req-prefill", role="prefill", protocol="ipc")
    connector.send_cache_info_to_messager([task], current_id=11)

    latest = engine_queue.cache_infos[-1][0]
    assert latest["request_id"] == "req-prefill"
    assert latest["current_id"] == 11
|
||||
|
||||
|
||||
def test_send_cache_info_to_prefill_rdma_triggers_remote_sync():
    """Decode-side RDMA cache info is pushed to prefill via cache_sync."""
    config = make_cfg()
    engine_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(config, engine_queue, object())

    connector.send_cache_info_to_prefill([make_task("req-decode", role="decode", protocol="rdma")])

    assert connector.sent_messages[-1][1] == "cache_sync"
    # The RDMA path must not queue anything on the local worker.
    assert engine_queue.cache_infos == []


def test_send_cache_info_to_prefill_ipc_forwards_to_local_worker():
    """Decode-side IPC cache info lands on the matching innode worker queue."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    connector.create_connection(TEST_PORT_DECODE_CACHE)

    task = make_task("req-local", role="decode", protocol="ipc")
    task.disaggregate_info["cache_info"]["ipc"]["port"] = TEST_PORT_DECODE_CACHE
    connector.send_cache_info_to_prefill([task])

    recorded = connector.connect_innode_instances[TEST_PORT_DECODE_CACHE].cache_infos[-1][0]
    assert recorded["transfer_protocol"] == "ipc"
|
||||
|
||||
|
||||
def test_send_cache_info_to_prefill_rdma_with_error_message_forwards_reason():
    """An error message on the task must travel along with the cache sync."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())

    task = make_task("req-err", role="decode", protocol="rdma")
    task.error_msg = "remote boom"
    connector.send_cache_info_to_prefill([task])

    _, msg_type, payload = connector.sent_messages[-1]
    assert msg_type == "cache_sync"
    assert "error_msg" in payload[0]


def test_send_cache_info_to_messager_uses_cached_current_id_when_missing():
    """With current_id=-1 the id already stored on the task is used instead."""
    config = make_cfg()
    engine_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(config, engine_queue, object())

    skipped = DummyTask("req-skip", disaggregate_info=None)
    task = make_task("req-prefill", role="prefill", protocol="ipc")
    task.disaggregate_info["cache_info"]["ipc"]["current_id"] = 42
    connector.send_cache_info_to_messager([skipped, task], current_id=-1)

    assert engine_queue.cache_infos[-1][0]["current_id"] == 42
|
||||
|
||||
|
||||
def test_send_splitwise_tasks_innode_creates_connection_if_missing():
    """send_splitwise_tasks_innode lazily opens the innode connection."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())

    task = make_task("req-create", role="decode", protocol="ipc")
    chosen = connector.send_splitwise_tasks_innode([task], TEST_PORT_INNODE_DECODE)

    assert chosen == TEST_PORT_INNODE_DECODE
    assert connector.connect_innode_instances[TEST_PORT_INNODE_DECODE].disaggregated_tasks


def test_send_first_token_creates_connection_for_ipc_queue():
    """send_first_token lazily opens a connection for an unseen IPC port."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())

    msg = {"transfer_protocol": "ipc", "cache_info": {"ipc": {"port": TEST_PORT_DECODE_FIRST_TOKEN}}}
    connector.send_first_token(msg, [make_task("req-first-missing", role="decode", protocol="ipc")])

    assert TEST_PORT_DECODE_FIRST_TOKEN in connector.connect_innode_instances
|
||||
|
||||
|
||||
def test_get_push_socket_wraps_zmq_error(monkeypatch):
    """ZMQ errors during socket creation surface as ConnectionError."""
    config = make_cfg(pd_comm_port=[TEST_PORT_PD_COMM_BASE])
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    # A context whose socket() always raises the stubbed ZMQError.
    connector.zmq_ctx = types.SimpleNamespace(
        socket=lambda *_: (_ for _ in ()).throw(splitwise_connector.zmq.ZMQError("boom"))
    )

    with pytest.raises(ConnectionError):
        connector._get_push_socket("1.2.3.4:9999")


def test_send_first_token_to_ipc_decode_queue():
    """First-token IPC messages enqueue a decode payload on the innode queue."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    connector.create_connection(TEST_PORT_DECODE_FIRST_TOKEN)

    msg = {
        "transfer_protocol": "ipc",
        "cache_info": {"ipc": {"port": TEST_PORT_DECODE_FIRST_TOKEN}},
    }
    connector.send_first_token(msg, [make_task("req-first", role="decode", protocol="ipc")])

    recorded = connector.connect_innode_instances[TEST_PORT_DECODE_FIRST_TOKEN].disaggregated_tasks[-1]
    assert recorded[0] == "decode"
|
||||
|
||||
|
||||
def test_send_first_token_rdma_path(monkeypatch):
    """First-token RDMA messages are sent straight to the decode endpoint."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())

    msg = {
        "transfer_protocol": "rdma",
        "cache_info": {"rdma": {"ip": "1.2.3.4", "port": 9123}},
    }
    task = make_task("req-first-rdma", role="decode", protocol="rdma")
    connector.send_first_token(msg, task)

    addr, msg_type, _ = connector.sent_messages[-1]
    assert addr == "1.2.3.4:9123"
    assert msg_type == "decode"


def test_check_decode_allocated_reports_finish_and_error():
    """check_decode_allocated maps 'finished' to success and other text to failure."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())

    done = make_task("req-finish", role="prefill", protocol="rdma")
    connector.current_request_ids["req-finish"] = "finished"
    ok, reason = connector.check_decode_allocated(done)
    assert ok
    assert reason == ""

    failed = make_task("req-error", role="prefill", protocol="rdma")
    connector.current_request_ids["req-error"] = "failed"
    ok, reason = connector.check_decode_allocated(failed)
    assert not ok
    assert reason == "failed"
|
||||
|
||||
|
||||
def test_process_cache_sync_records_status_and_forwards(monkeypatch):
    """cache_sync messages update per-request status and forward the payload."""
    config = make_cfg()
    engine_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(config, engine_queue, object())

    payload = [
        {"request_id": "req-a", "error_msg": "boom"},
        {"request_id": "req-b"},
    ]
    raw = json.dumps({"type": "cache_sync", "payload": payload}).encode("utf-8")
    connector._process_message(raw)

    assert connector.current_request_ids["req-a"] == "boom"
    assert connector.current_request_ids["req-b"] == "finished"
    assert engine_queue.cache_infos[-1] == payload


def test_handle_prefill_and_decode_messages():
    """Prefill and decode payloads both land on the disaggregated task queue."""
    config = make_cfg()
    engine_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(config, engine_queue, object())

    req = make_request_obj("req-handle")
    connector._handle_prefill([req.to_dict()])
    assert engine_queue.disaggregated_tasks[-1][0] == "decode"

    completion = CompletionOutput(index=0, send_idx=0, token_ids=[])
    metrics = RequestMetrics(arrival_time=0.0)
    output = RequestOutput("req-out", outputs=completion, metrics=metrics)
    connector._handle_decode([output.to_dict()])
    assert engine_queue.disaggregated_tasks[-1][0] == "decode"
|
||||
|
||||
|
||||
def test_close_connection_removes_socket_reference():
    """_close_connection closes the socket and drops it from the registry."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())

    class DummySocket:
        """Minimal socket stub used to verify close handling."""

        def __init__(self):
            self.closed = False

        def close(self):
            self.closed = True

    sock = DummySocket()
    connector.push_sockets = {"test": sock}
    connector._close_connection("test")

    assert sock.closed
    assert connector.push_sockets == {}
|
||||
|
||||
|
||||
def test_send_message_initializes_network_and_serializes(monkeypatch):
    """_send_message lazily builds the ZMQ network and JSON-serializes payloads."""
    monkeypatch.setattr(splitwise_connector, "zmq", _make_stub_zmq())

    class DummyExecutor:
        """Executor that records submissions instead of running them."""

        def __init__(self, *_, **__):
            self.calls = []

        def submit(self, fn, *args, **kwargs):
            # Keep the test single-threaded by never invoking fn.
            self.calls.append((fn, args, kwargs))

    monkeypatch.setattr(splitwise_connector, "ThreadPoolExecutor", DummyExecutor)

    config = make_cfg(
        pd_comm_port=[TEST_PORT_PD_COMM_BASE],
        enable_expert_parallel=True,
        data_parallel_size=2,
        local_data_parallel_id=1,
    )
    connector = SplitwiseConnector(config, FakeEngineWorkerQueue(), object())

    connector._send_message("127.0.0.1:9551", "decode", [RequestOutput("req-zmq")])

    sock = connector.push_sockets["127.0.0.1:9551"]
    assert json.loads(sock.sent[-1][1].decode("utf-8"))["type"] == "decode"
|
||||
|
||||
|
||||
def test_send_message_handles_failures_and_resets_socket(monkeypatch):
    """A failing socket is discarded and the failure counter is incremented."""
    monkeypatch.setattr(splitwise_connector, "zmq", _make_stub_zmq())
    monkeypatch.setattr(splitwise_connector, "ThreadPoolExecutor", lambda *_, **__: None)

    config = make_cfg(pd_comm_port=[TEST_PORT_PD_COMM_FAIL])
    connector = SplitwiseConnector(config, FakeEngineWorkerQueue(), object())

    broken = _StubSocket(2)
    broken.should_fail = True
    connector.push_sockets["node"] = broken
    splitwise_connector.main_process_metrics.send_cache_failed_num.value = 0

    connector._send_message("node", "decode", [RequestOutput("req-fail")])

    assert "node" not in connector.push_sockets
    assert splitwise_connector.main_process_metrics.send_cache_failed_num.value == 1
|
||||
Reference in New Issue
Block a user