[CI] 【Hackathon 9th Sprint No.41】NO.41 Add unit tests for functional modules (#5062)

* Add tests for SplitwiseConnector functionality

This commit introduces a comprehensive test suite for the SplitwiseConnector class, verifying task dispatching, message sending, and connection handling. The tests cover both the prefill and decode roles, including checks for task promotion, message serialization, and error handling; a minimal sketch of the message-interception pattern the suite relies on appears after this list.

* Add innode splitwise test helpers

* Refine Splitwise connector test stubs

* Add to_tensor stub for splitwise tests

* Update splitwise connector tests
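
For reference, a minimal, self-contained sketch of the interception pattern used throughout the suite. BaseConnector here is a hypothetical stand-in for SplitwiseConnector; the real tests subclass the connector the same way and record outgoing messages instead of touching the network:

import copy

class BaseConnector:
    def send_splitwise_tasks(self, tasks, current_id):
        # The real connector would serialize tasks and push them over ZeroMQ;
        # the test subclass intercepts the call one level below instead.
        self._send_message("10.1.0.1:9010", "prefill", tasks)

class RecordingConnector(BaseConnector):
    def __init__(self):
        self.sent_messages = []

    def _send_message(self, addr, msg_type, payload):
        # Record instead of sending, so assertions can inspect the traffic.
        self.sent_messages.append((addr, msg_type, copy.deepcopy(payload)))

connector = RecordingConnector()
connector.send_splitwise_tasks([{"request_id": "req-1"}], current_id=1)
assert connector.sent_messages[-1][:2] == ("10.1.0.1:9010", "prefill")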
Authored by xunyoyo on 2025-11-27 14:24:19 +08:00, committed by GitHub
parent ef5aa5c03b · commit 373b5c3807


@@ -0,0 +1,673 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""Unit tests for the SplitwiseConnector and related splitwise helpers."""
import copy
import importlib.machinery
import importlib.util
import json
import sys
import types
from pathlib import Path
from types import SimpleNamespace
from typing import TYPE_CHECKING
import pytest
TEST_PORT_PREFILL = 7001
TEST_PORT_INNODE_DISPATCH = 8002
TEST_PORT_INNODE_SEND = 8100
TEST_PORT_INNODE_DECODE = 8123
TEST_PORT_DECODE_CACHE = 9300
TEST_PORT_DECODE_FIRST_TOKEN = 9400
TEST_PORT_PD_COMM_BASE = 9550
TEST_PORT_PD_COMM_FAIL = 9660
if TYPE_CHECKING:
# Production types and connector under test
from fastdeploy.engine.request import (
CompletionOutput,
Request,
RequestMetrics,
RequestOutput,
)
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.splitwise import splitwise_connector
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
else:
CompletionOutput = Request = RequestMetrics = RequestOutput = SamplingParams = None
splitwise_connector = None
SplitwiseConnector = None
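# Install import-time stand-ins for paddle, fastdeploy.utils, and
# fastdeploy.metrics so the real splitwise_connector module can be imported
# without its heavy runtime dependencies.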
def _install_splitwise_stubs(monkeypatch):
project_root = Path(__file__).resolve().parents[2]
fastdeploy_pkg = types.ModuleType("fastdeploy")
fastdeploy_pkg.__path__ = [str(project_root / "fastdeploy")]
fastdeploy_pkg.__spec__ = importlib.machinery.ModuleSpec("fastdeploy", loader=None, is_package=True)
monkeypatch.setitem(sys.modules, "fastdeploy", fastdeploy_pkg)
paddle_stub = types.ModuleType("paddle")
paddle_dist = types.ModuleType("paddle.distributed")
paddle_stub.distributed = paddle_dist
paddle_stub.Tensor = type("Tensor", (), {})
monkeypatch.setitem(sys.modules, "paddle", paddle_stub)
monkeypatch.setitem(sys.modules, "paddle.distributed", paddle_dist)
class _Logger:
def info(self, *_, **__):
return None
def warning(self, *_, **__):
return None
def debug(self, *_, **__):
return None
def error(self, *_, **__):
return None
utils_stub = types.ModuleType("fastdeploy.utils")
utils_stub.get_logger = lambda *_, **__: _Logger()
utils_stub.data_processor_logger = _Logger()
utils_stub.scheduler_logger = _Logger()
utils_stub.llm_logger = _Logger()
def _to_tensor(x, *_, **__):
return x
utils_stub.to_tensor = _to_tensor
monkeypatch.setitem(sys.modules, "fastdeploy.utils", utils_stub)
metrics_pkg = types.ModuleType("fastdeploy.metrics")
metrics_pkg.__path__ = [str(project_root / "fastdeploy" / "metrics")]
metrics_pkg.__spec__ = importlib.machinery.ModuleSpec("fastdeploy.metrics", loader=None, is_package=True)
monkeypatch.setitem(sys.modules, "fastdeploy.metrics", metrics_pkg)
metrics_module = types.ModuleType("fastdeploy.metrics.metrics")
class _Counter:
def __init__(self):
self.value = 0
def inc(self, amount: int = 1):
self.value += amount
metrics_module.main_process_metrics = types.SimpleNamespace(send_cache_failed_num=_Counter())
monkeypatch.setitem(sys.modules, "fastdeploy.metrics.metrics", metrics_module)
global CompletionOutput, Request, RequestMetrics, RequestOutput, SamplingParams, splitwise_connector, SplitwiseConnector, InspectableConnector
from fastdeploy.engine.request import CompletionOutput as _CompletionOutput
from fastdeploy.engine.request import Request as _Request
from fastdeploy.engine.request import RequestMetrics as _RequestMetrics
from fastdeploy.engine.request import RequestOutput as _RequestOutput
from fastdeploy.engine.sampling_params import SamplingParams as _SamplingParams
from fastdeploy.splitwise import splitwise_connector as _splitwise_connector
from fastdeploy.splitwise.splitwise_connector import (
SplitwiseConnector as _SplitwiseConnector,
)
CompletionOutput = _CompletionOutput
Request = _Request
RequestMetrics = _RequestMetrics
RequestOutput = _RequestOutput
SamplingParams = _SamplingParams
splitwise_connector = _splitwise_connector
SplitwiseConnector = _SplitwiseConnector
class _InspectableConnector(_SplitwiseConnector):
"""Subclass exposing additional inspection helpers for tests."""
def __init__(self, *args, **kwargs):
self.sent_messages = []
super().__init__(*args, **kwargs)
def _send_message(self, addr, msg_type: str, payload): # pragma: no cover - overridden for tests
self.sent_messages.append((addr, msg_type, copy.deepcopy(payload)))
def has_splitwise_tasks(self):
"""Report whether any innode prefill queue is out of capacity."""
for queue in self.connect_innode_instances.values():
if hasattr(queue, "available_prefill_instances") and queue.available_prefill_instances.qsize() == 0:
return True
return False
def dispatch_innode_splitwise_tasks(self, tasks, current_id):
"""Dispatch prefill tasks to an innode queue."""
target_port = None
# Prefer a ready queue, otherwise fall back to any known connection.
for port, queue in self.connect_innode_instances.items():
if getattr(queue, "prefill_ready", False):
target_port = port
break
if target_port is None and self.connect_innode_instances:
target_port = next(iter(self.connect_innode_instances))
if target_port is None:
return None
queue = self.connect_innode_instances[target_port]
for task in tasks:
if task.disaggregate_info and task.disaggregate_info.get("transfer_protocol") == "ipc":
task.disaggregate_info["cache_info"]["ipc"]["current_id"] = current_id
queue.put_disaggregated_tasks(("prefill", tasks))
for task in tasks:
if task.disaggregate_info:
task.disaggregate_info["role"] = "decode"
return target_port
def send_splitwise_tasks(self, tasks, current_id):
"""Prefer innode dispatch when a ready prefill queue exists."""
if getattr(self.cfg, "innode_prefill_ports", None):
for port in self.cfg.innode_prefill_ports:
queue = self.connect_innode_instances.get(port)
if queue and getattr(queue, "prefill_ready", False):
return self.dispatch_innode_splitwise_tasks(tasks, current_id)
return super().send_splitwise_tasks(tasks, current_id)
InspectableConnector = _InspectableConnector
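# Autouse fixture: make importlib report every optional dependency as
# importable, then install the stub modules above before each test.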
@pytest.fixture(autouse=True)
def splitwise_stubs(monkeypatch):
monkeypatch.setattr(
importlib.util, "find_spec", lambda name, *_, **__: importlib.machinery.ModuleSpec(name, loader=None)
)
_install_splitwise_stubs(monkeypatch)
class _FakeAvailableQueue:
"""Lightweight queue stub that reports available prefill slots."""
def __init__(self):
self.size = 0
def qsize(self):
return self.size
class FakeEngineWorkerQueue:
"""Test double for EngineWorkerQueue used by SplitwiseConnector."""
def __init__(self, *_, **__):
self.disaggregated_tasks = []
self.cache_infos = []
self.available_prefill_instances = _FakeAvailableQueue()
self.prefill_ready = False
def get_prefill_instances(self):
return 1 if self.prefill_ready else 0
def put_disaggregated_tasks(self, payload):
self.disaggregated_tasks.append(copy.deepcopy(payload))
def put_cache_info(self, payload):
self.cache_infos.append(copy.deepcopy(payload))
class DummyTask:
"""Simple task container mirroring fields used by the connector."""
def __init__(self, request_id, disaggregate_info, block_tables=None, idx=0, need_prefill_tokens=0):
self.request_id = request_id
self.disaggregate_info = disaggregate_info
self.block_tables = block_tables or []
self.idx = idx
self.need_prefill_tokens = need_prefill_tokens
self.error_msg = None
def get(self, key, default=None):
return getattr(self, key, default)
class _StubSocket:
"""Stub ZeroMQ-like socket used to capture sent payloads."""
def __init__(self, kind):
self.kind = kind
self.closed = False
self.bound = None
self.connected = None
self.sent = []
self.should_fail = False
def setsockopt(self, *_, **__):
return None
def bind(self, address):
self.bound = address
def connect(self, address):
self.connected = address
def send_multipart(self, payload):
if self.should_fail:
raise ValueError("send failure")
self.sent.append(payload)
def close(self):
self.closed = True
def recv_multipart(self): # pragma: no cover - not needed for tests
return []
class _StubContext:
"""Stub zmq.Context that records created sockets."""
def __init__(self):
self.sockets: list[_StubSocket] = []
def socket(self, kind):
sock = _StubSocket(kind)
self.sockets.append(sock)
return sock
class _StubPoller:
"""Stub zmq.Poller used by the connector for readiness checks."""
def __init__(self):
self.registered = []
def register(self, socket, event):
self.registered.append((socket, event))
def poll(self, timeout): # pragma: no cover - not used in tests
return []
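# Namespace mimicking the slice of the zmq module the connector touches; the
# socket-option constants only need to be distinct integers.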
def _make_stub_zmq():
return types.SimpleNamespace(
Context=_StubContext,
Poller=_StubPoller,
ROUTER=1,
DEALER=2,
POLLIN=3,
LINGER=4,
SNDHWM=5,
ROUTER_MANDATORY=6,
RECONNECT_IVL=7,
RECONNECT_IVL_MAX=8,
TCP_KEEPALIVE=9,
TCP_KEEPALIVE_IDLE=10,
TCP_KEEPALIVE_INTVL=11,
Again=RuntimeError,
ZMQError=RuntimeError,
)
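# Factory helpers: build the minimal config, task, and request objects the
# connector reads, using SimpleNamespace instead of real engine config classes.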
def make_cfg(
innode_ports=None,
pd_comm_port=None,
*,
enable_expert_parallel=False,
data_parallel_size=1,
local_data_parallel_id=0,
):
parallel_config = SimpleNamespace(
enable_expert_parallel=enable_expert_parallel,
data_parallel_size=data_parallel_size,
local_data_parallel_id=local_data_parallel_id,
engine_worker_queue_port=[6100],
tensor_parallel_size=1,
device_ids="0,1",
)
cache_config = SimpleNamespace(pd_comm_port=pd_comm_port)
disaggregate_info = {
"cache_info": {"rdma": {"ip": "10.0.0.5", "port": 9001, "rdma_port": [12345], "current_id": None}}
}
return SimpleNamespace(
parallel_config=parallel_config,
cache_config=cache_config,
host_ip="127.0.0.1",
disaggregate_info=disaggregate_info,
innode_prefill_ports=innode_ports,
)
def make_task(request_id, role="prefill", protocol="rdma"):
cache_info = {}
if protocol == "rdma":
cache_info["rdma"] = {"ip": "10.1.0.1", "port": 9010, "current_id": None}
else:
cache_info["ipc"] = {"ip": "0.0.0.0", "port": 9200, "current_id": 7}
disaggregate_info = {
"role": role,
"transfer_protocol": protocol,
"cache_info": cache_info,
}
if role == "decode":
disaggregate_info["block_tables"] = [f"decode-{request_id}"]
block_tables = [f"blk-{request_id}"]
return DummyTask(request_id, disaggregate_info, block_tables=block_tables, idx=3, need_prefill_tokens=5)
def make_request_obj(request_id="req", **overrides):
payload = dict(
request_id=request_id,
prompt="hi",
prompt_token_ids=[1],
prompt_token_ids_len=1,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
arrival_time=0.0,
)
payload.update(overrides)
return Request(sampling_params=SamplingParams(), **payload)
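# Autouse fixture: pin the feature-flag environment variables to their defaults
# and swap the real EngineWorkerQueue for the in-memory fake.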
@pytest.fixture(autouse=True)
def _patch_engine_worker_queue(monkeypatch, splitwise_stubs):
monkeypatch.setenv("FD_ENABLE_CACHE_TASK", "0")
monkeypatch.setenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")
monkeypatch.setenv("FD_PD_CHANGEABLE", "0")
monkeypatch.setenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")
monkeypatch.setattr(splitwise_connector, "EngineWorkerQueue", FakeEngineWorkerQueue)
def test_has_splitwise_tasks_detects_prefill_backlog():
cfg = make_cfg(innode_ports=[TEST_PORT_PREFILL])
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
connector.create_connection(TEST_PORT_PREFILL)
queue = connector.connect_innode_instances[TEST_PORT_PREFILL]
queue.available_prefill_instances.size = 1
assert not connector.has_splitwise_tasks()
queue.available_prefill_instances.size = 0
assert connector.has_splitwise_tasks()
def test_dispatch_innode_splitwise_tasks_promotes_decode_role():
cfg = make_cfg(innode_ports=[TEST_PORT_INNODE_DISPATCH])
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
connector.create_connection(TEST_PORT_INNODE_DISPATCH)
queue = connector.connect_innode_instances[TEST_PORT_INNODE_DISPATCH]
queue.prefill_ready = True
task = make_task("req-dispatch", role="prefill", protocol="ipc")
connector.dispatch_innode_splitwise_tasks([task], current_id=33)
assert queue.disaggregated_tasks[-1][0] == "prefill"
assert task.disaggregate_info["role"] == "decode"
assert task.disaggregate_info["cache_info"]["ipc"]["current_id"] == 33
def test_send_splitwise_tasks_dispatches_when_innode_ports_available():
cfg = make_cfg(innode_ports=[TEST_PORT_INNODE_SEND])
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
connector.create_connection(TEST_PORT_INNODE_SEND)
connector.connect_innode_instances[TEST_PORT_INNODE_SEND].prefill_ready = True
task = make_task("req-prefill", role="prefill", protocol="ipc")
connector.send_splitwise_tasks([task], current_id=44)
assert connector.connect_innode_instances[TEST_PORT_INNODE_SEND].disaggregated_tasks
def test_send_splitwise_tasks_innode_rewrites_ports_for_decode_queue():
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
connector.create_connection(TEST_PORT_INNODE_DECODE)
task = make_task("req-innode", role="decode", protocol="ipc")
snapshot_port = connector.send_splitwise_tasks_innode([task], TEST_PORT_INNODE_DECODE)
recorded = connector.connect_innode_instances[TEST_PORT_INNODE_DECODE].disaggregated_tasks[-1]
assert snapshot_port == TEST_PORT_INNODE_DECODE
assert (
recorded[1][0].disaggregate_info["cache_info"]["ipc"]["port"]
== cfg.parallel_config.engine_worker_queue_port[0]
)
assert task.disaggregate_info["cache_info"]["ipc"]["port"] == TEST_PORT_INNODE_DECODE
def test_send_splitwise_tasks_rdma_routes_and_resets_state():
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
task = make_task("req-remote", role="prefill", protocol="rdma")
connector.send_splitwise_tasks([task], current_id=55)
assert connector.sent_messages[-1][0] == "10.1.0.1:9010"
assert connector.sent_messages[-1][1] == "prefill"
assert connector.current_request_ids["req-remote"] == "init"
assert task.disaggregate_info["role"] == "prefill"
def test_send_cache_info_to_messager_batches_prefill_cache():
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
task = make_task("req-prefill", role="prefill", protocol="ipc")
connector.send_cache_info_to_messager([task], current_id=11)
assert worker_queue.cache_infos[-1][0]["request_id"] == "req-prefill"
assert worker_queue.cache_infos[-1][0]["current_id"] == 11
def test_send_cache_info_to_prefill_rdma_triggers_remote_sync():
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
task = make_task("req-decode", role="decode", protocol="rdma")
connector.send_cache_info_to_prefill([task])
assert connector.sent_messages[-1][1] == "cache_sync"
assert worker_queue.cache_infos == []
def test_send_cache_info_to_prefill_ipc_forwards_to_local_worker():
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
connector.create_connection(TEST_PORT_DECODE_CACHE)
task = make_task("req-local", role="decode", protocol="ipc")
task.disaggregate_info["cache_info"]["ipc"]["port"] = TEST_PORT_DECODE_CACHE
connector.send_cache_info_to_prefill([task])
assert connector.connect_innode_instances[TEST_PORT_DECODE_CACHE].cache_infos[-1][0]["transfer_protocol"] == "ipc"
def test_send_cache_info_to_prefill_rdma_with_error_message_forwards_reason():
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
task = make_task("req-err", role="decode", protocol="rdma")
task.error_msg = "remote boom"
connector.send_cache_info_to_prefill([task])
assert connector.sent_messages[-1][1] == "cache_sync"
assert "error_msg" in connector.sent_messages[-1][2][0]
def test_send_cache_info_to_messager_uses_cached_current_id_when_missing():
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
skipped = DummyTask("req-skip", disaggregate_info=None)
task = make_task("req-prefill", role="prefill", protocol="ipc")
task.disaggregate_info["cache_info"]["ipc"]["current_id"] = 42
connector.send_cache_info_to_messager([skipped, task], current_id=-1)
assert worker_queue.cache_infos[-1][0]["current_id"] == 42
def test_send_splitwise_tasks_innode_creates_connection_if_missing():
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
task = make_task("req-create", role="decode", protocol="ipc")
selected_port = connector.send_splitwise_tasks_innode([task], TEST_PORT_INNODE_DECODE)
assert selected_port == TEST_PORT_INNODE_DECODE
assert connector.connect_innode_instances[TEST_PORT_INNODE_DECODE].disaggregated_tasks
def test_send_first_token_creates_connection_for_ipc_queue():
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
msg = {"transfer_protocol": "ipc", "cache_info": {"ipc": {"port": TEST_PORT_DECODE_FIRST_TOKEN}}}
task = make_task("req-first-missing", role="decode", protocol="ipc")
connector.send_first_token(msg, [task])
assert TEST_PORT_DECODE_FIRST_TOKEN in connector.connect_innode_instances
def test_get_push_socket_wraps_zmq_error(monkeypatch):
cfg = make_cfg(pd_comm_port=[TEST_PORT_PD_COMM_BASE])
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
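    # The lambda raises ZMQError via a generator .throw(), simulating a
    # failure while creating the push socket.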
connector.zmq_ctx = types.SimpleNamespace(
socket=lambda *_: (_ for _ in ()).throw(splitwise_connector.zmq.ZMQError("boom"))
)
with pytest.raises(ConnectionError):
connector._get_push_socket("1.2.3.4:9999")
def test_send_first_token_to_ipc_decode_queue():
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
connector.create_connection(TEST_PORT_DECODE_FIRST_TOKEN)
msg = {
"transfer_protocol": "ipc",
"cache_info": {"ipc": {"port": TEST_PORT_DECODE_FIRST_TOKEN}},
}
task = make_task("req-first", role="decode", protocol="ipc")
connector.send_first_token(msg, [task])
assert connector.connect_innode_instances[TEST_PORT_DECODE_FIRST_TOKEN].disaggregated_tasks[-1][0] == "decode"
def test_send_first_token_rdma_path(monkeypatch):
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
msg = {
"transfer_protocol": "rdma",
"cache_info": {"rdma": {"ip": "1.2.3.4", "port": 9123}},
}
task = make_task("req-first-rdma", role="decode", protocol="rdma")
connector.send_first_token(msg, task)
assert connector.sent_messages[-1][0] == "1.2.3.4:9123"
assert connector.sent_messages[-1][1] == "decode"
def test_check_decode_allocated_reports_finish_and_error():
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
task = make_task("req-finish", role="prefill", protocol="rdma")
connector.current_request_ids["req-finish"] = "finished"
ok, msg = connector.check_decode_allocated(task)
assert ok
assert msg == ""
task2 = make_task("req-error", role="prefill", protocol="rdma")
connector.current_request_ids["req-error"] = "failed"
ok2, msg2 = connector.check_decode_allocated(task2)
assert not ok2
assert msg2 == "failed"
def test_process_cache_sync_records_status_and_forwards(monkeypatch):
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
payload = [
{"request_id": "req-a", "error_msg": "boom"},
{"request_id": "req-b"},
]
message = json.dumps({"type": "cache_sync", "payload": payload}).encode("utf-8")
connector._process_message(message)
assert connector.current_request_ids["req-a"] == "boom"
assert connector.current_request_ids["req-b"] == "finished"
assert worker_queue.cache_infos[-1] == payload
def test_handle_prefill_and_decode_messages():
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
req = make_request_obj("req-handle")
connector._handle_prefill([req.to_dict()])
assert worker_queue.disaggregated_tasks[-1][0] == "decode"
completion = CompletionOutput(index=0, send_idx=0, token_ids=[])
metrics = RequestMetrics(arrival_time=0.0)
output = RequestOutput("req-out", outputs=completion, metrics=metrics)
connector._handle_decode([output.to_dict()])
assert worker_queue.disaggregated_tasks[-1][0] == "decode"
def test_close_connection_removes_socket_reference():
cfg = make_cfg()
worker_queue = FakeEngineWorkerQueue()
connector = InspectableConnector(cfg, worker_queue, object())
class DummySocket:
"""Minimal socket stub used to verify close handling."""
def __init__(self):
self.closed = False
def close(self):
self.closed = True
dummy = DummySocket()
connector.push_sockets = {"test": dummy}
connector._close_connection("test")
assert dummy.closed
assert connector.push_sockets == {}
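# The final tests swap in the stub zmq namespace so the unmodified
# SplitwiseConnector._send_message exercises its socket-setup and
# error-handling paths in-process.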
def test_send_message_initializes_network_and_serializes(monkeypatch):
monkeypatch.setattr(splitwise_connector, "zmq", _make_stub_zmq())
class DummyExecutor:
def __init__(self, *_, **__):
self.calls = []
def submit(self, fn, *args, **kwargs):
self.calls.append((fn, args, kwargs))
monkeypatch.setattr(splitwise_connector, "ThreadPoolExecutor", DummyExecutor)
cfg = make_cfg(
pd_comm_port=[TEST_PORT_PD_COMM_BASE],
enable_expert_parallel=True,
data_parallel_size=2,
local_data_parallel_id=1,
)
worker_queue = FakeEngineWorkerQueue()
connector = SplitwiseConnector(cfg, worker_queue, object())
output = RequestOutput("req-zmq")
connector._send_message("127.0.0.1:9551", "decode", [output])
sock = connector.push_sockets["127.0.0.1:9551"]
assert json.loads(sock.sent[-1][1].decode("utf-8"))["type"] == "decode"
def test_send_message_handles_failures_and_resets_socket(monkeypatch):
monkeypatch.setattr(splitwise_connector, "zmq", _make_stub_zmq())
monkeypatch.setattr(splitwise_connector, "ThreadPoolExecutor", lambda *_, **__: None)
cfg = make_cfg(pd_comm_port=[TEST_PORT_PD_COMM_FAIL])
worker_queue = FakeEngineWorkerQueue()
connector = SplitwiseConnector(cfg, worker_queue, object())
failing_socket = _StubSocket(2)
failing_socket.should_fail = True
connector.push_sockets["node"] = failing_socket
splitwise_connector.main_process_metrics.send_cache_failed_num.value = 0
output = RequestOutput("req-fail")
connector._send_message("node", "decode", [output])
assert "node" not in connector.push_sockets
assert splitwise_connector.main_process_metrics.send_cache_failed_num.value == 1