[Feature] [Benchmark]: add ZMQ-based FMQ implementation and benchmark tools (#5418)

* feat(fmq): add ZMQ-based FMQ implementation and benchmark tools

* move FMQ_CONFIG_JSON to envs

* fix top_p_candidates (#5400)

Co-authored-by: freeliuzc <lzc842650834@gmail.com>

* [RL] Support Rollout Routing Replay (#5321)

* [RL] Support Rollout Routing Replay

* add routing indices cache

* fix config bug and moe forward bug

* R3 Support GLM

* support eb4.5

* fix merge bug

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* add routing replay ci

* support glm topk

* support other top_k

* fix ci bug

* pre-commit

* only support chatcmpl

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>

* [Bug fix] Fix the multi-input accuracy issue in the pooling model. (#5374)

* fix multi-inputs

* fix threshold

* fix threshold

* fix

* [BugFix]remove _execute_empty_input (#5396)

* Revert "[RL] Support Rollout Routing Replay (#5321)" (#5402)

This reverts commit 96d2d4877b.

* [New][RL] Support Rollout Routing Replay (#5405)

* [RL] Support Rollout Routing Replay

* add routing indices cache

* fix config bug and moe forward bug

* R3 Support GLM

* support eb4.5

* fix merge bug

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* add routing replay ci

* support glm topk

* support other top_k

* fix ci bug

* pre-commit

* only support chatcmpl

* Revert "Revert "[RL] Support Rollout Routing Replay (#5321)" (#5402)"

This reverts commit c45e064f3d.

* Fix XPU and NPU bug

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>

* bf16 deepseek (#5379)

* fix deepseek (#5410)

* Update tests/inter_communicator/test_fmq_factory.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update benchmarks/benchmark_fmq.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/inter_communicator/fmq.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: GoldPancake <56388518+Deleter-D@users.noreply.github.com>
Co-authored-by: freeliuzc <lzc842650834@gmail.com>
Co-authored-by: RAM <gstian5555@outlook.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>
Co-authored-by: lizexu123 <39205361+lizexu123@users.noreply.github.com>
Co-authored-by: 周周周 <39978853+zhoutianzi666@users.noreply.github.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com>
Author: SunLei
Date: 2025-12-08 22:04:49 +08:00
Committed by: GitHub
Parent: 364197c4b5
Commit: 5fb93d84f5
7 changed files with 848 additions and 0 deletions

benchmarks/benchmark_fmq.py (new file, 233 lines)

@@ -0,0 +1,233 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
import multiprocessing as mp
import os
import statistics
import time
from tqdm import tqdm
from fastdeploy.inter_communicator.fmq import FMQ
# ============================================================
# Producer Task
# ============================================================
async def producer_task(proc_id, msg_count, payload_size, shm_threshold, result_q):
fmq = FMQ()
q = fmq.queue("mp_bench_latency", role="producer")
payload = b"x" * payload_size
# tqdm progress bar
pbar = tqdm(total=msg_count, desc=f"Producer-{proc_id}", position=proc_id, leave=True, disable=False)
t0 = time.perf_counter()
for i in range(msg_count):
send_ts = time.perf_counter()
await q.put(data={"pid": proc_id, "i": i, "send_ts": send_ts, "payload": payload}, shm_threshold=shm_threshold)
pbar.update(1)
# pbar.write(f"send {i}")
t1 = time.perf_counter()
result_q.put({"producer_id": proc_id, "count": msg_count, "time": t1 - t0})
pbar.close()
# linger for a few seconds before exiting so in-flight messages can be delivered
await asyncio.sleep(5)
def producer_process(proc_id, msg_count, payload_size, shm_threshold, result_q):
async def run():
await producer_task(proc_id, msg_count, payload_size, shm_threshold, result_q)
asyncio.run(run())
# ============================================================
# Consumer Task
# ============================================================
async def consumer_task(consumer_id, total_msgs, result_q, consumer_event):
fmq = FMQ()
q = fmq.queue("mp_bench_latency", role="consumer")
consumer_event.set()
latencies = []
recv = 0
# tqdm progress bar
pbar = tqdm(total=total_msgs, desc=f"Consumer-{consumer_id}", position=consumer_id + 1, leave=True, disable=False)
first_recv = None
last_recv = None
while recv < total_msgs:
msg = await q.get()
recv_ts = time.perf_counter()
if msg is None:
pbar.write("recv None")
continue
if first_recv is None:
first_recv = recv_ts
last_recv = recv_ts
send_ts = msg.payload["send_ts"]
latencies.append((recv_ts - send_ts) * 1000) # ms
pbar.update(1)
recv += 1
pbar.close()
result_q.put(
{"consumer_id": consumer_id, "latencies": latencies, "first_recv": first_recv, "last_recv": last_recv}
)
def consumer_process(consumer_id, total_msgs, result_q, consumer_event):
async def run():
await consumer_task(consumer_id, total_msgs, result_q, consumer_event)
asyncio.run(run())
# ============================================================
# MAIN benchmark
# ============================================================
def run_benchmark(
NUM_PRODUCERS=1,
NUM_CONSUMERS=1,
NUM_MESSAGES_PER_PRODUCER=1000,
PAYLOAD_SIZE=1 * 1024 * 1024,
SHM_THRESHOLD=1 * 1024 * 1024,
):
total_messages = NUM_PRODUCERS * NUM_MESSAGES_PER_PRODUCER
total_bytes = total_messages * PAYLOAD_SIZE
print(f"\nFastDeploy Message Queue Benchmark, pid:{os.getpid()}")
print(f"Producers: {NUM_PRODUCERS}")
print(f"Consumers: {NUM_CONSUMERS}")
print(f"Messages per producer: {NUM_MESSAGES_PER_PRODUCER}")
print(f"Total bytes: {total_bytes / 1024 / 1024 / 1024:.2f} GB")
print(f"Total messages: {total_messages:,}")
print(f"Payload per message: {PAYLOAD_SIZE / 1024 / 1024:.2f} MB")
mp.set_start_method("fork")
manager = mp.Manager()
result_q = manager.Queue()
# Readiness event: consumers signal when they are ready to receive
consumer_event = manager.Event()
procs = []
# Start Consumers
msgs_per_consumer = total_messages // NUM_CONSUMERS
for i in range(NUM_CONSUMERS):
p = mp.Process(target=consumer_process, args=(i, msgs_per_consumer, result_q, consumer_event))
procs.append(p)
p.start()
consumer_event.wait()
# Start Producers
for i in range(NUM_PRODUCERS):
p = mp.Process(
target=producer_process, args=(i, NUM_MESSAGES_PER_PRODUCER, PAYLOAD_SIZE, SHM_THRESHOLD, result_q)
)
procs.append(p)
p.start()
# Join
for p in procs:
p.join()
# Collect results
producer_stats = []
consumer_stats = {}
while not result_q.empty():
item = result_q.get()
if "producer_id" in item:
producer_stats.append(item)
if "consumer_id" in item:
consumer_stats[item["consumer_id"]] = item
# Producer stats
print("\nProducer Stats:")
for p in producer_stats:
throughput = p["count"] / p["time"]
bandwidth = (p["count"] * PAYLOAD_SIZE) / (1024**2 * p["time"])
print(
f"[Producer-{p['producer_id']}] Sent {p['count']:,} msgs "
f"in {p['time']:.3f} s | Throughput: {throughput:,.0f} msg/s | Bandwidth: {bandwidth:.2f} MB/s"
)
# Consumer latency stats
print("\nConsumer Latency Stats:")
all_latencies = []
first_recv_times = []
last_recv_times = []
for cid, data in consumer_stats.items():
lats = data["latencies"]
if len(lats) == 0:
continue
all_latencies.extend(lats)
first_recv_times.append(data["first_recv"])
last_recv_times.append(data["last_recv"])
avg = statistics.mean(lats)
p50 = statistics.median(lats)
p95 = statistics.quantiles(lats, n=20)[18]
p99 = statistics.quantiles(lats, n=100)[98]
print(
f"[Consumer-{cid}] msgs={len(lats):5d} | avg={avg:.3f} ms | "
f"P50={p50:.3f} ms | P95={p95:.3f} ms | P99={p99:.3f} ms"
)
# Global summary
if first_recv_times and last_recv_times:
total_time = max(last_recv_times) - min(first_recv_times)
global_throughput = total_messages / total_time
global_bandwidth = total_bytes / (1024**2 * total_time)
if all_latencies:
avg_latency = statistics.mean(all_latencies)
min_latency = min(all_latencies)
max_latency = max(all_latencies)
p50_latency = statistics.median(all_latencies)
p95_latency = statistics.quantiles(all_latencies, n=20)[18]
p99_latency = statistics.quantiles(all_latencies, n=100)[98]
else:
avg_latency = min_latency = max_latency = p50_latency = p95_latency = p99_latency = 0.0
print("\nGlobal Summary:")
print(f"Total messages : {total_messages:,}")
print(f"Total data : {total_bytes / 1024**2:.2f} MB")
print(f"Total time : {total_time:.3f} s")
print(f"Global throughput: {global_throughput:,.0f} msg/s")
print(f"Global bandwidth : {global_bandwidth:.2f} MB/s")
print(
f"Latency (ms) : avg={avg_latency:.3f} "
f"| min={min_latency:.3f} | max={max_latency:.3f} "
f"| P50={p50_latency:.3f} | P95={p95_latency:.3f} | P99={p99_latency:.3f}\n"
)
# Entry
if __name__ == "__main__":
run_benchmark()
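
For reference, run_benchmark can also be called with explicit parameters; a minimal sketch of a heavier configuration (the values below are illustrative, not tuned recommendations):

# Illustrative invocation; note that SHM_THRESHOLD only applies to raw bytes
# payloads in Queue.put, and this benchmark wraps its payload in a dict.
if __name__ == "__main__":
    run_benchmark(
        NUM_PRODUCERS=2,
        NUM_CONSUMERS=2,
        NUM_MESSAGES_PER_PRODUCER=500,
        PAYLOAD_SIZE=256 * 1024,        # 256 KB per message
        SHM_THRESHOLD=1 * 1024 * 1024,  # 1 MB
    )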


@@ -151,6 +151,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU"
"FD_HPU_CHUNK_SIZE": lambda: int(os.getenv("FD_HPU_CHUNK_SIZE", "64")),
"FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")),
"FMQ_CONFIG_JSON": lambda: os.getenv("FMQ_CONFIG_JSON", None),
}
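
For reference, FMQ_CONFIG_JSON carries a JSON object whose fields mirror the Config/Endpoint dataclasses defined in fmq.py below; a minimal sketch of exporting it from Python (the endpoint name and ipc path are placeholders, shaped after the config used in the unit tests):

import json
import os

# Placeholder endpoint name/path; the schema mirrors the test config further below.
fmq_cfg = {
    "ipc_root": "/dev/shm",
    "io_threads": 1,
    "copy": False,
    "endpoints": {
        "demo_queue": {"protocol": "ipc", "address": "/dev/shm/fmq_demo_queue.ipc", "io_threads": 1, "copy": False},
    },
}
os.environ["FMQ_CONFIG_JSON"] = json.dumps(fmq_cfg)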

fastdeploy/inter_communicator/fmq.py (new file, 347 lines)

@@ -0,0 +1,347 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
import json
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from multiprocessing import shared_memory
from multiprocessing.reduction import ForkingPickler
from typing import Any, Callable, Dict, Optional
import zmq
import zmq.asyncio
from fastdeploy import envs
from fastdeploy.utils import fmq_logger
# ==========================
# Config & Enum Definitions
# ==========================
class EndpointType(Enum):
QUEUE = "queue"
TOPIC = "topic"
class Role(Enum):
PRODUCER = "producer"
CONSUMER = "consumer"
@dataclass
class SocketOptions:
sndhwm: int = 0
rcvhwm: int = 0
linger: int = -1
sndbuf: int = 32 * 1024 * 1024
rcvbuf: int = 32 * 1024 * 1024
immediate: int = 1
def apply(self, socket: zmq.Socket, is_producer: bool):
# Apply socket-level configurations
socket.setsockopt(zmq.LINGER, self.linger)
socket.setsockopt(zmq.IMMEDIATE, self.immediate)
if is_producer:
socket.setsockopt(zmq.SNDHWM, self.sndhwm)
socket.setsockopt(zmq.SNDBUF, self.sndbuf)
else:
socket.setsockopt(zmq.RCVHWM, self.rcvhwm)
socket.setsockopt(zmq.RCVBUF, self.rcvbuf)
@dataclass
class Endpoint:
# Represents a single endpoint with protocol, address, io_threads, and copy behavior
protocol: str
address: str
io_threads: int = 1
copy: bool = False
@dataclass
class Config:
ipc_root: str = "/dev/shm"
io_threads: int = 1
copy: bool = False
endpoints: Dict[str, Endpoint] = field(default_factory=dict)
socket_config: SocketOptions = field(default_factory=SocketOptions)
# ==========================
# Endpoint Manager
# ==========================
class EndpointManager:
config: Config = Config()
@classmethod
def load_config(cls, _ignored_file_path: str = None):
cfg_str = envs.FMQ_CONFIG_JSON
if cfg_str:
try:
custom_cfg = json.loads(cfg_str)
for key, value in custom_cfg.items():
if value is not None:
setattr(cls.config, key, value)
except Exception as e:
fmq_logger.error(f"Failed to load FMQ config: {e}")
fmq_logger.info(f"Loaded FMQ config: {cls.config}")
@classmethod
def get_endpoint(cls, name: str) -> Endpoint:
# Retrieve endpoint object
if name in cls.config.endpoints:
return cls.config.endpoints[name]
# Fallback: auto-generate endpoint
address = f"{cls.config.ipc_root}/fmq_{name}.ipc"
return Endpoint(protocol="ipc", address=address)
# ==========================
# Shared Memory Descriptor
# ==========================
@dataclass
class Descriptor:
shm_name: str
size: int
@staticmethod
def create(data_bytes: bytes) -> "Descriptor":
# Create shared memory buffer and store payload
name = f"fmq_shm_{uuid.uuid4().hex}"
shm = shared_memory.SharedMemory(create=True, size=len(data_bytes), name=name)
shm.buf[: len(data_bytes)] = data_bytes
shm.close()
return Descriptor(shm_name=name, size=len(data_bytes))
def read_and_unlink(self) -> bytes:
# Read and cleanup shared memory
try:
shm = shared_memory.SharedMemory(name=self.shm_name)
data = bytes(shm.buf[: self.size])
shm.close()
shm.unlink()
return data
except FileNotFoundError:
return b""
# ==========================
# Message Wrapper
# ==========================
@dataclass
class Message:
payload: Any
msg_id: Optional[int] = None
timestamp: float = field(default_factory=time.time)
descriptor: Optional[Descriptor] = None
def serialize(self) -> bytes:
# Serialize message
return ForkingPickler.dumps(self)
@staticmethod
def deserialize(data: bytes) -> "Message":
# Deserialize message
return ForkingPickler.loads(data)
# ==========================
# Base Component
# ==========================
class BaseComponent:
def __init__(self, context: zmq.asyncio.Context, endpoint: Endpoint):
self.context = context
self.endpoint = endpoint
self.socket = None
self.lock = asyncio.Lock()
async def close(self):
# Close socket
if self.socket:
self.socket.close()
# ==========================
# FIFO Queue
# ==========================
class Queue(BaseComponent):
def __init__(self, context, name: str, role: str = "producer"):
endpoint = EndpointManager.get_endpoint(name)
super().__init__(context, endpoint)
self.name = name
self.role = Role(role)
self.copy = endpoint.copy
self.socket_conf = EndpointManager.config.socket_config
self._msg_id = 0
full_ep = f"{endpoint.protocol}://{endpoint.address}"
self.socket = self.context.socket(zmq.PUSH if self.role == Role.PRODUCER else zmq.PULL)
self.socket_conf.apply(self.socket, self.role == Role.PRODUCER)
if self.role == Role.PRODUCER:
self.socket.connect(full_ep)
else:
self.socket.bind(full_ep)
fmq_logger.info(f"Queue {name} initialized on {full_ep}")
async def put(self, data: Any, shm_threshold: int = 1024 * 1024):
"""
Send data to the queue.
Args:
data: The data to send. Can be any serializable object or bytes.
shm_threshold: Size threshold in bytes. If the data is of type bytes and its size is
greater than or equal to this threshold, shared memory will be used to send the message.
Default is 1MB (1024 * 1024 bytes).
Raises:
PermissionError: If called by a non-producer role.
"""
if self.role != Role.PRODUCER:
raise PermissionError("Only producers can send messages.")
desc = None
payload = data
if isinstance(data, bytes) and len(data) >= shm_threshold:
desc = Descriptor.create(data)
payload = None
msg = Message(msg_id=self._msg_id, payload=payload, descriptor=desc)
raw = msg.serialize()
async with self.lock:
await self.socket.send(raw, copy=self.copy)
self._msg_id += 1
async def get(self, timeout: int = None) -> Optional[Message]:
# Receive data from queue
if self.role != Role.CONSUMER:
raise PermissionError("Only consumers can get messages.")
try:
if timeout:
raw = await asyncio.wait_for(self.socket.recv(copy=self.copy), timeout / 1000)
else:
raw = await self.socket.recv(copy=self.copy)
except asyncio.TimeoutError:
fmq_logger.error(f"Timeout receiving message on {self.name}")
return None
msg = Message.deserialize(raw)
if msg.descriptor:
msg.payload = msg.descriptor.read_and_unlink()
self._msg_id += 1
return msg
# ==========================
# Pub/Sub Topic
# ==========================
class Topic(BaseComponent):
def __init__(self, context, name: str):
endpoint = EndpointManager.get_endpoint(name)
super().__init__(context, endpoint)
self.name = name
self._pub_socket = None
self._sub_socket = None
self._task = None
async def pub(self, data: Any):
# Publish a message
if not self._pub_socket:
ep = f"{self.endpoint.protocol}://{self.endpoint.address}"
self._pub_socket = self.context.socket(zmq.PUB)
self._pub_socket.bind(ep)
await asyncio.sleep(0.05)
msg = Message(payload=data)
async with self.lock:
await self._pub_socket.send(msg.serialize())
async def sub(self, callback: Callable[[Message], Any]):
# Subscribe and handle messages
if not self._sub_socket:
ep = f"{self.endpoint.protocol}://{self.endpoint.address}"
self._sub_socket = self.context.socket(zmq.SUB)
self._sub_socket.connect(ep)
self._sub_socket.setsockopt_string(zmq.SUBSCRIBE, "")
async def loop():
while True:
raw = await self._sub_socket.recv()
msg = Message.deserialize(raw)
result = callback(msg)
if asyncio.iscoroutine(result):
await result
self._task = asyncio.create_task(loop())
# ==========================
# FMQ Main Interface
# ==========================
class FMQ:
_instance = None
_context = None
def __new__(cls, config_path="fmq_config.json"):
if cls._instance is None:
cls._instance = super().__new__(cls)
EndpointManager.load_config()
# Determine IO threads based on global defaults
io_threads = 1
if EndpointManager.config.endpoints:
# Use max io_threads among all endpoints
io_threads = max(ep.io_threads for ep in EndpointManager.config.endpoints.values())
cls._context = zmq.asyncio.Context(io_threads=io_threads)
return cls._instance
def queue(self, name: str, role="producer") -> Queue:
return Queue(self._context, name, role)
def topic(self, name: str) -> Topic:
return Topic(self._context, name)
async def destroy(self):
# Destroy ZeroMQ context
self._context.term()
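
As a quick usage reference, a minimal single-process producer/consumer sketch against this API (the queue name and payloads are illustrative; it mirrors the patterns exercised by the unit tests below):

import asyncio

from fastdeploy.inter_communicator.fmq import FMQ


async def main():
    fmq = FMQ()
    producer = fmq.queue("demo_queue", role="producer")  # illustrative name
    consumer = fmq.queue("demo_queue", role="consumer")

    # Small objects travel inline over ZMQ; raw bytes at or above shm_threshold
    # are handed over through a shared-memory Descriptor instead.
    await producer.put({"hello": "world"})
    await producer.put(b"x" * (2 * 1024 * 1024), shm_threshold=1024 * 1024)

    for _ in range(2):
        msg = await consumer.get(timeout=1000)  # timeout is in milliseconds
        print(msg.msg_id, type(msg.payload))


asyncio.run(main())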

fastdeploy/inter_communicator/fmq_factory.py (new file, 83 lines)

@@ -0,0 +1,83 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from fastdeploy.inter_communicator.fmq import FMQ
class FMQFactory:
"""
Static factory for creating the four standard FMQ queues:
1. q_a2e: api server --> engine
2. q_e2w: engine --> worker
3. q_w2e: worker --> engine
4. q_e2a: engine --> api server
API Server: q_a2e producer / q_e2a consumer
Engine: q_a2e consumer / q_e2w producer / q_w2e consumer / q_e2a producer
Worker: q_e2w consumer / q_w2e producer
"""
_fmq = FMQ()
# ------------------------------
# API → Engine
# ------------------------------
@classmethod
def q_a2e_producer(cls):
return cls._fmq.queue("q_a2e", role="producer")
@classmethod
def q_a2e_consumer(cls):
return cls._fmq.queue("q_a2e", role="consumer")
# ------------------------------
# Engine → Worker
# ------------------------------
@classmethod
def q_e2w_producer(cls):
return cls._fmq.queue("q_e2w", role="producer")
@classmethod
def q_e2w_consumer(cls):
return cls._fmq.queue("q_e2w", role="consumer")
# ------------------------------
# Worker → Engine
# ------------------------------
@classmethod
def q_w2e_producer(cls):
return cls._fmq.queue("q_w2e", role="producer")
@classmethod
def q_w2e_consumer(cls):
return cls._fmq.queue("q_w2e", role="consumer")
# ------------------------------
# Engine → API
# ------------------------------
@classmethod
def q_e2a_producer(cls):
return cls._fmq.queue("q_e2a", role="producer")
@classmethod
def q_e2a_consumer(cls):
return cls._fmq.queue("q_e2a", role="consumer")
# ------------------------------
# Destroy context
# ------------------------------
@classmethod
async def destroy(cls):
await cls._fmq.destroy()
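
A minimal sketch of the intended wiring, collapsed into one process for illustration (in a real deployment the api server, engine, and worker each hold their own end of these queues; the payload here is a placeholder):

import asyncio

from fastdeploy.inter_communicator.fmq_factory import FMQFactory


async def demo():
    # API-server side: push a request towards the engine.
    to_engine = FMQFactory.q_a2e_producer()
    # Engine side: pull the request off the same queue.
    from_api = FMQFactory.q_a2e_consumer()

    await to_engine.put({"request_id": "demo-0"})  # placeholder payload
    msg = await from_api.get(timeout=1500)
    print(msg.payload)


asyncio.run(demo())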


@@ -1051,6 +1051,7 @@ spec_logger = get_logger("speculate", "speculate.log")
zmq_client_logger = get_logger("zmq_client", "zmq_client.log")
trace_logger = FastDeployLogger().get_trace_logger("trace_logger", "trace_logger.log")
router_logger = get_logger("router", "router.log")
fmq_logger = get_logger("fmq", "fmq.log")
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
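
The new fmq_logger writes to fmq.log; a one-line sketch of emitting from another module (the log message itself is illustrative):

from fastdeploy.utils import fmq_logger

fmq_logger.info("FMQ endpoint ready")  # illustrative message, lands in fmq.log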


@@ -0,0 +1,92 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
import json
import os
import unittest
from fastdeploy.inter_communicator.fmq import FMQ, Message
# Prepare environment config for testing
cfg = {
"ipc_root": "/dev/shm",
"io_threads": 1,
"copy": False,
"endpoints": {
"test_queue": {"protocol": "ipc", "address": "/dev/shm/fmq_test_queue.ipc", "io_threads": 1, "copy": False},
"test_topic": {"protocol": "ipc", "address": "/dev/shm/fmq_test_topic.ipc", "io_threads": 1, "copy": False},
},
}
os.environ["FMQ_CONFIG_JSON"] = json.dumps(cfg)
class TestFMQ(unittest.TestCase):
def setUp(self):
self.fmq = FMQ()
def test_queue_send_receive(self):
async def run_test():
producer = self.fmq.queue("test_queue", role="producer")
consumer = self.fmq.queue("test_queue", role="consumer")
test_data = b"hello world"
await producer.put(test_data)
msg = await consumer.get(timeout=1000)
self.assertIsNotNone(msg)
self.assertEqual(msg.payload, test_data)
asyncio.run(run_test())
def test_queue_large_shm_transfer(self):
async def run_test():
producer = self.fmq.queue("test_queue", role="producer")
consumer = self.fmq.queue("test_queue", role="consumer")
large_data = b"x" * (2 * 1024 * 1024) # > 1MB
await producer.put(large_data)
msg = await consumer.get(timeout=1000)
self.assertIsNotNone(msg)
self.assertEqual(msg.payload, large_data)
self.assertIsNotNone(msg.descriptor)
asyncio.run(run_test())
def test_topic_pub_sub(self):
received = []
async def run_test():
topic = self.fmq.topic("test_topic")
async def callback(msg: Message):
received.append(msg.payload)
await topic.sub(callback)
await asyncio.sleep(0.1) # allow SUB to connect
await topic.pub("hello")
await asyncio.sleep(0.2)
self.assertIn("hello", received)
asyncio.run(run_test())
if __name__ == "__main__":
unittest.main()

tests/inter_communicator/test_fmq_factory.py (new file, 91 lines)

@@ -0,0 +1,91 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from fastdeploy.inter_communicator.fmq import Message
from fastdeploy.inter_communicator.fmq_factory import FMQFactory as factory
class TestFMQFactory(unittest.IsolatedAsyncioTestCase):
async def test_create_queues(self):
"""Test whether all producer/consumer queues can be created."""
q1 = factory.q_a2e_producer()
q2 = factory.q_a2e_consumer()
q3 = factory.q_e2w_producer()
q4 = factory.q_e2w_consumer()
q5 = factory.q_w2e_producer()
q6 = factory.q_w2e_consumer()
q7 = factory.q_e2a_producer()
q8 = factory.q_e2a_consumer()
self.assertEqual(q1.name, "q_a2e")
self.assertEqual(q2.name, "q_a2e")
self.assertEqual(q3.name, "q_e2w")
self.assertEqual(q4.name, "q_e2w")
self.assertEqual(q5.name, "q_w2e")
self.assertEqual(q6.name, "q_w2e")
self.assertEqual(q7.name, "q_e2a")
self.assertEqual(q8.name, "q_e2a")
# Within a single process, all queues should share the same context
self.assertIs(q1.context, q2.context)
self.assertIs(q1.context, q3.context)
async def test_message_roundtrip(self):
"""测试 producer → consumer 消息流转"""
producer = factory.q_a2e_producer()
consumer = factory.q_a2e_consumer()
payload = {"k": "v"}
await producer.put(payload)
msg = await consumer.get(timeout=1500)
self.assertIsInstance(msg, Message)
self.assertEqual(msg.payload, payload)
async def test_multi_queue_independence(self):
"""测试多个队列互不干扰"""
prod_a2e = factory.q_a2e_producer()
cons_a2e = factory.q_a2e_consumer()
prod_e2w = factory.q_e2w_producer()
cons_e2w = factory.q_e2w_consumer()
await prod_a2e.put("msg_api")
await prod_e2w.put("msg_worker")
msg1 = await cons_a2e.get(timeout=1500)
msg2 = await cons_e2w.get(timeout=1500)
self.assertEqual(msg1.payload, "msg_api")
self.assertEqual(msg2.payload, "msg_worker")
async def test_shared_context(self):
"""验证 FMQFactory 始终返回同一个 context (单进程)"""
q1 = factory.q_a2e_producer()
q2 = factory.q_e2w_consumer()
q3 = factory.q_e2a_producer()
self.assertIs(q1.context, q2.context)
self.assertIs(q1.context, q3.context)
if __name__ == "__main__":
unittest.main()