diff --git a/benchmarks/benchmark_fmq.py b/benchmarks/benchmark_fmq.py new file mode 100644 index 000000000..3878f790c --- /dev/null +++ b/benchmarks/benchmark_fmq.py @@ -0,0 +1,233 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import asyncio +import multiprocessing as mp +import os +import statistics +import time + +from tqdm import tqdm + +from fastdeploy.inter_communicator.fmq import FMQ + + +# ============================================================ +# Producer Task +# ============================================================ +async def producer_task(proc_id, msg_count, payload_size, shm_threshold, result_q): + fmq = FMQ() + q = fmq.queue("mp_bench_latency", role="producer") + payload = b"x" * payload_size + + # tqdm 进度条 + pbar = tqdm(total=msg_count, desc=f"Producer-{proc_id}", position=proc_id, leave=True, disable=False) + + t0 = time.perf_counter() + for i in range(msg_count): + send_ts = time.perf_counter() + await q.put(data={"pid": proc_id, "i": i, "send_ts": send_ts, "payload": payload}, shm_threshold=shm_threshold) + pbar.update(1) + # pbar.write(f"send {i}") + t1 = time.perf_counter() + result_q.put({"producer_id": proc_id, "count": msg_count, "time": t1 - t0}) + + pbar.close() + + # wait for 2 seconds before closing + await asyncio.sleep(5) + + +def producer_process(proc_id, msg_count, payload_size, shm_threshold, result_q): + async def run(): + await producer_task(proc_id, msg_count, payload_size, shm_threshold, result_q) + + asyncio.run(run()) + + +# ============================================================ +# Consumer Task +# ============================================================ +async def consumer_task(consumer_id, total_msgs, result_q, consumer_event): + fmq = FMQ() + q = fmq.queue("mp_bench_latency", role="consumer") + consumer_event.set() + + latencies = [] + recv = 0 + + # tqdm 显示进度 + pbar = tqdm(total=total_msgs, desc=f"Consumer-{consumer_id}", position=consumer_id + 1, leave=True, disable=False) + + first_recv = None + last_recv = None + + while recv < total_msgs: + msg = await q.get() + recv_ts = time.perf_counter() + if msg is None: + pbar.write("recv None") + continue + if first_recv is None: + first_recv = recv_ts + last_recv = recv_ts + send_ts = msg.payload["send_ts"] + latencies.append((recv_ts - send_ts) * 1000) # ms + pbar.update(1) + recv += 1 + + pbar.close() + + result_q.put( + {"consumer_id": consumer_id, "latencies": latencies, "first_recv": first_recv, "last_recv": last_recv} + ) + + +def consumer_process(consumer_id, total_msgs, result_q, consumer_event): + async def run(): + await consumer_task(consumer_id, total_msgs, result_q, consumer_event) + + asyncio.run(run()) + + +# ============================================================ +# MAIN benchmark +# ============================================================ +def run_benchmark( + NUM_PRODUCERS=1, + NUM_CONSUMERS=1, + NUM_MESSAGES_PER_PRODUCER=1000, + PAYLOAD_SIZE=1 * 1024 * 1024, + SHM_THRESHOLD=1 * 1024 * 1024, +): + total_messages = NUM_PRODUCERS * NUM_MESSAGES_PER_PRODUCER + total_bytes = total_messages * PAYLOAD_SIZE + + print(f"\nFastDeploy Message Queue Benchmark, pid:{os.getpid()}") + print(f"Producers: {NUM_PRODUCERS}") + print(f"Consumers: {NUM_CONSUMERS}") + print(f"Messages per producer: {NUM_MESSAGES_PER_PRODUCER}") + print(f"Total bytes: {total_bytes / 1024 / 1024 / 1024:.2f} GB") + print(f"Total messages: {total_messages:,}") + print(f"Payload per message: {PAYLOAD_SIZE / 1024 / 1024:.2f} MB") + + mp.set_start_method("fork") + manager = mp.Manager() + result_q = manager.Queue() + + # 两个信号事件 + consumer_event = manager.Event() + + procs = [] + + # Start Consumers + msgs_per_consumer = total_messages // NUM_CONSUMERS + for i in range(NUM_CONSUMERS): + p = mp.Process(target=consumer_process, args=(i, msgs_per_consumer, result_q, consumer_event)) + procs.append(p) + p.start() + + consumer_event.wait() + + # Start Producers + for i in range(NUM_PRODUCERS): + p = mp.Process( + target=producer_process, args=(i, NUM_MESSAGES_PER_PRODUCER, PAYLOAD_SIZE, SHM_THRESHOLD, result_q) + ) + procs.append(p) + p.start() + + # Join + for p in procs: + p.join() + + # Collect results + producer_stats = [] + consumer_stats = {} + + while not result_q.empty(): + item = result_q.get() + if "producer_id" in item: + producer_stats.append(item) + if "consumer_id" in item: + consumer_stats[item["consumer_id"]] = item + + # Producer stats + print("\nProducer Stats:") + for p in producer_stats: + throughput = p["count"] / p["time"] + bandwidth = (p["count"] * PAYLOAD_SIZE) / (1024**2 * p["time"]) + print( + f"[Producer-{p['producer_id']}] Sent {p['count']:,} msgs " + f"in {p['time']:.3f} s | Throughput: {throughput:,.0f} msg/s | Bandwidth: {bandwidth:.2f} MB/s" + ) + + # Consumer latency stats + print("\nConsumer Latency Stats:") + all_latencies = [] + first_recv_times = [] + last_recv_times = [] + + for cid, data in consumer_stats.items(): + lats = data["latencies"] + if len(lats) == 0: + continue + all_latencies.extend(lats) + first_recv_times.append(data["first_recv"]) + last_recv_times.append(data["last_recv"]) + + avg = statistics.mean(lats) + p50 = statistics.median(lats) + p95 = statistics.quantiles(lats, n=20)[18] + p99 = statistics.quantiles(lats, n=100)[98] + + print( + f"[Consumer-{cid}] msgs={len(lats):5d} | avg={avg:.3f} ms | " + f"P50={p50:.3f} ms | P95={p95:.3f} ms | P99={p99:.3f} ms" + ) + + # Global summary + if first_recv_times and last_recv_times: + total_time = max(last_recv_times) - min(first_recv_times) + global_throughput = total_messages / total_time + global_bandwidth = total_bytes / (1024**2 * total_time) + + if all_latencies: + avg_latency = statistics.mean(all_latencies) + min_latency = min(all_latencies) + max_latency = max(all_latencies) + p50_latency = statistics.median(all_latencies) + p95_latency = statistics.quantiles(all_latencies, n=20)[18] + p99_latency = statistics.quantiles(all_latencies, n=100)[98] + else: + avg_latency = min_latency = max_latency = p50_latency = p95_latency = p99_latency = 0.0 + + print("\nGlobal Summary:") + print(f"Total messages : {total_messages:,}") + print(f"Total data : {total_bytes / 1024**2:.2f} MB") + print(f"Total time : {total_time:.3f} s") + print(f"Global throughput: {global_throughput:,.0f} msg/s") + print(f"Global bandwidth : {global_bandwidth:.2f} MB/s") + print( + f"Latency (ms) : avg={avg_latency:.3f} " + f"| min={min_latency:.3f} | max={max_latency:.3f} " + f"| P50={p50_latency:.3f} | P95={p95_latency:.3f} | P99={p99_latency:.3f}\n" + ) + + +# Entry +if __name__ == "__main__": + run_benchmark() diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 93f135d09..dc734af5e 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -151,6 +151,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU" "FD_HPU_CHUNK_SIZE": lambda: int(os.getenv("FD_HPU_CHUNK_SIZE", "64")), "FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")), + "FMQ_CONFIG_JSON": lambda: os.getenv("FMQ_CONFIG_JSON", None), } diff --git a/fastdeploy/inter_communicator/fmq.py b/fastdeploy/inter_communicator/fmq.py new file mode 100644 index 000000000..f2c98196c --- /dev/null +++ b/fastdeploy/inter_communicator/fmq.py @@ -0,0 +1,347 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import asyncio +import json +import time +import uuid +from dataclasses import dataclass, field +from enum import Enum +from multiprocessing import shared_memory +from multiprocessing.reduction import ForkingPickler +from typing import Any, Callable, Dict, Optional + +import zmq +import zmq.asyncio + +from fastdeploy import envs +from fastdeploy.utils import fmq_logger + +# ========================== +# Config & Enum Definitions +# ========================== + + +class EndpointType(Enum): + QUEUE = "queue" + TOPIC = "topic" + + +class Role(Enum): + PRODUCER = "producer" + CONSUMER = "consumer" + + +@dataclass +class SocketOptions: + sndhwm: int = 0 + rcvhwm: int = 0 + linger: int = -1 + sndbuf: int = 32 * 1024 * 1024 + rcvbuf: int = 32 * 1024 * 1024 + immediate: int = 1 + + def apply(self, socket: zmq.Socket, is_producer: bool): + # Apply socket-level configurations + socket.setsockopt(zmq.LINGER, self.linger) + socket.setsockopt(zmq.IMMEDIATE, self.immediate) + + if is_producer: + socket.setsockopt(zmq.SNDHWM, self.sndhwm) + socket.setsockopt(zmq.SNDBUF, self.sndbuf) + else: + socket.setsockopt(zmq.RCVHWM, self.rcvhwm) + socket.setsockopt(zmq.RCVBUF, self.rcvbuf) + + +@dataclass +class Endpoint: + # Represents a single endpoint with protocol, address, io_threads, and copy behavior + protocol: str + address: str + io_threads: int = 1 + copy: bool = False + + +@dataclass +class Config: + ipc_root: str = "/dev/shm" + io_threads: int = 1 + copy: bool = False + endpoints: Dict[str, Endpoint] = field(default_factory=dict) + socket_config: SocketOptions = SocketOptions() + + +# ========================== +# Endpoint Manager +# ========================== + + +class EndpointManager: + config: Config = Config() + + @classmethod + def load_config(cls, _ignored_file_path: str = None): + cfg_str = envs.FMQ_CONFIG_JSON + if cfg_str: + try: + custom_cfg = json.loads(cfg_str) + for key, value in vars(custom_cfg).items(): + if value is not None: + setattr(cls.config, key, value) + except Exception as e: + fmq_logger.error(f"Failed to load FMQ config: {e}") + fmq_logger.info(f"Loaded FMQ config: {cls.config}") + + @classmethod + def get_endpoint(cls, name: str) -> Endpoint: + # Retrieve endpoint object + if name in cls.config.endpoints: + return cls.config.endpoints[name] + + # Fallback: auto-generate endpoint + address = f"{cls.config.ipc_root}/fmq_{name}.ipc" + return Endpoint(protocol="ipc", address=address) + + +# ========================== +# Shared Memory Descriptor +# ========================== + + +@dataclass +class Descriptor: + shm_name: str + size: int + + @staticmethod + def create(data_bytes: bytes) -> "Descriptor": + # Create shared memory buffer and store payload + name = f"fmq_shm_{uuid.uuid4().hex}" + shm = shared_memory.SharedMemory(create=True, size=len(data_bytes), name=name) + shm.buf[: len(data_bytes)] = data_bytes + shm.close() + return Descriptor(shm_name=name, size=len(data_bytes)) + + def read_and_unlink(self) -> bytes: + # Read and cleanup shared memory + try: + shm = shared_memory.SharedMemory(name=self.shm_name) + data = bytes(shm.buf[: self.size]) + shm.close() + shm.unlink() + return data + except FileNotFoundError: + return b"" + + +# ========================== +# Message Wrapper +# ========================== + + +@dataclass +class Message: + payload: Any + msg_id: int = None + timestamp: float = field(default_factory=time.time) + descriptor: Optional[Descriptor] = None + + def serialize(self) -> bytes: + # Serialize message + return ForkingPickler.dumps(self) + + @staticmethod + def deserialize(data: bytes) -> "Message": + # Deserialize message + return ForkingPickler.loads(data) + + +# ========================== +# Base Component +# ========================== + + +class BaseComponent: + def __init__(self, context: zmq.asyncio.Context, endpoint: Endpoint): + self.context = context + self.endpoint = endpoint + self.socket = None + self.lock = asyncio.Lock() + + async def close(self): + # Close socket + if self.socket: + self.socket.close() + + +# ========================== +# FIFO Queue +# ========================== + + +class Queue(BaseComponent): + def __init__(self, context, name: str, role: str = "producer"): + endpoint = EndpointManager.get_endpoint(name) + super().__init__(context, endpoint) + + self.name = name + self.role = Role(role) + self.copy = endpoint.copy + self.socket_conf = EndpointManager.config.socket_config + self._msg_id = 0 + + full_ep = f"{endpoint.protocol}://{endpoint.address}" + + self.socket = self.context.socket(zmq.PUSH if self.role == Role.PRODUCER else zmq.PULL) + self.socket_conf.apply(self.socket, self.role == Role.PRODUCER) + + if self.role == Role.PRODUCER: + self.socket.connect(full_ep) + else: + self.socket.bind(full_ep) + + fmq_logger.info(f"Queue {name} initialized on {full_ep}") + + async def put(self, data: Any, shm_threshold: int = 1024 * 1024): + """ + Send data to the queue. + + Args: + data: The data to send. Can be any serializable object or bytes. + shm_threshold: Size threshold in bytes. If the data is of type bytes and its size is + greater than or equal to this threshold, shared memory will be used to send the message. + Default is 1MB (1024 * 1024 bytes). + + Raises: + PermissionError: If called by a non-producer role. + """ + if self.role != Role.PRODUCER: + raise PermissionError("Only producers can send messages.") + + desc = None + payload = data + + if isinstance(data, bytes) and len(data) >= shm_threshold: + desc = Descriptor.create(data) + payload = None + + msg = Message(msg_id=self._msg_id, payload=payload, descriptor=desc) + raw = msg.serialize() + + async with self.lock: + await self.socket.send(raw, copy=self.copy) + self._msg_id += 1 + + async def get(self, timeout: int = None) -> Optional[Message]: + # Receive data from queue + if self.role != Role.CONSUMER: + raise PermissionError("Only consumers can get messages.") + + try: + if timeout: + raw = await asyncio.wait_for(self.socket.recv(), timeout / 1000) + else: + raw = await self.socket.recv(copy=self.copy) + except asyncio.TimeoutError: + fmq_logger.error(f"Timeout receiving message on {self.name}") + return None + + msg = Message.deserialize(raw) + if msg.descriptor: + msg.payload = msg.descriptor.read_and_unlink() + + self._msg_id += 1 + return msg + + +# ========================== +# Pub/Sub Topic +# ========================== + + +class Topic(BaseComponent): + def __init__(self, context, name: str): + endpoint = EndpointManager.get_endpoint(name) + super().__init__(context, endpoint) + self.name = name + self._pub_socket = None + self._sub_socket = None + self._task = None + + async def pub(self, data: Any): + # Publish a message + if not self._pub_socket: + ep = f"{self.endpoint.protocol}://{self.endpoint.address}" + self._pub_socket = self.context.socket(zmq.PUB) + self._pub_socket.bind(ep) + await asyncio.sleep(0.05) + + msg = Message(payload=data) + async with self.lock: + await self._pub_socket.send(msg.serialize()) + + async def sub(self, callback: Callable[[Message], Any]): + # Subscribe and handle messages + if not self._sub_socket: + ep = f"{self.endpoint.protocol}://{self.endpoint.address}" + self._sub_socket = self.context.socket(zmq.SUB) + self._sub_socket.connect(ep) + self._sub_socket.setsockopt_string(zmq.SUBSCRIBE, "") + + async def loop(): + while True: + raw = await self._sub_socket.recv() + msg = Message.deserialize(raw) + result = callback(msg) + if asyncio.iscoroutine(result): + await result + + self._task = asyncio.create_task(loop()) + + +# ========================== +# FMQ Main Interface +# ========================== + + +class FMQ: + _instance = None + _context = None + + def __new__(cls, config_path="fmq_config.json"): + if cls._instance is None: + cls._instance = super().__new__(cls) + EndpointManager.load_config() + + # Determine IO threads based on global defaults + io_threads = 1 + if EndpointManager.config.endpoints: + # Use max io_threads among all endpoints + io_threads = max(ep.io_threads for ep in EndpointManager.config.endpoints.values()) + + cls._context = zmq.asyncio.Context(io_threads=io_threads) + return cls._instance + + def queue(self, name: str, role="producer") -> Queue: + return Queue(self._context, name, role) + + def topic(self, name: str) -> Topic: + return Topic(self._context, name) + + async def destroy(self): + # Destroy ZeroMQ context + self._context.term() diff --git a/fastdeploy/inter_communicator/fmq_factory.py b/fastdeploy/inter_communicator/fmq_factory.py new file mode 100644 index 000000000..d1c8e4dd2 --- /dev/null +++ b/fastdeploy/inter_communicator/fmq_factory.py @@ -0,0 +1,83 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from fastdeploy.inter_communicator.fmq import FMQ + + +class FMQFactory: + """ + Static factory for creating the four standard FMQ queues: + 1. q_a2e: api server --> engine + 2. q_e2w: engine --> worker + 3. q_w2e: worker --> engine + 4. q_e2a: engine --> api server + API Server: q_a2e producer / q_e2a consumer + Engine: q_a2e consumer / q_e2w producer / q_w2e consumer / q_e2a producer + Worker: q_e2w consumer / q_w2e producer + """ + + _fmq = FMQ() + + # ------------------------------ + # API → Engine + # ------------------------------ + @classmethod + def q_a2e_producer(cls): + return cls._fmq.queue("q_a2e", role="producer") + + @classmethod + def q_a2e_consumer(cls): + return cls._fmq.queue("q_a2e", role="consumer") + + # ------------------------------ + # Engine → Worker + # ------------------------------ + @classmethod + def q_e2w_producer(cls): + return cls._fmq.queue("q_e2w", role="producer") + + @classmethod + def q_e2w_consumer(cls): + return cls._fmq.queue("q_e2w", role="consumer") + + # ------------------------------ + # Worker → Engine + # ------------------------------ + @classmethod + def q_w2e_producer(cls): + return cls._fmq.queue("q_w2e", role="producer") + + @classmethod + def q_w2e_consumer(cls): + return cls._fmq.queue("q_w2e", role="consumer") + + # ------------------------------ + # Engine → API + # ------------------------------ + @classmethod + def q_e2a_producer(cls): + return cls._fmq.queue("q_e2a", role="producer") + + @classmethod + def q_e2a_consumer(cls): + return cls._fmq.queue("q_e2a", role="consumer") + + # ------------------------------ + # Destroy context + # ------------------------------ + @classmethod + async def destroy(cls): + await cls._fmq.destroy() diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index a0878fa7c..97a975f4e 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -1051,6 +1051,7 @@ spec_logger = get_logger("speculate", "speculate.log") zmq_client_logger = get_logger("zmq_client", "zmq_client.log") trace_logger = FastDeployLogger().get_trace_logger("trace_logger", "trace_logger.log") router_logger = get_logger("router", "router.log") +fmq_logger = get_logger("fmq", "fmq.log") def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: diff --git a/tests/inter_communicator/test_fmq.py b/tests/inter_communicator/test_fmq.py new file mode 100644 index 000000000..a7d6a8153 --- /dev/null +++ b/tests/inter_communicator/test_fmq.py @@ -0,0 +1,92 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import asyncio +import json +import os +import unittest + +from fastdeploy.inter_communicator.fmq import FMQ, Message + +# Prepare environment config for testing +cfg = { + "ipc_root": "/dev/shm", + "io_threads": 1, + "copy": False, + "endpoints": { + "test_queue": {"protocol": "ipc", "address": "/dev/shm/fmq_test_queue.ipc", "io_threads": 1, "copy": False}, + "test_topic": {"protocol": "ipc", "address": "/dev/shm/fmq_test_topic.ipc", "io_threads": 1, "copy": False}, + }, +} +os.environ["FMQ_CONFIG_JSON"] = json.dumps(cfg) + + +class TestFMQ(unittest.TestCase): + + def setUp(self): + self.fmq = FMQ() + + def test_queue_send_receive(self): + async def run_test(): + producer = self.fmq.queue("test_queue", role="producer") + consumer = self.fmq.queue("test_queue", role="consumer") + + test_data = b"hello world" + await producer.put(test_data) + msg = await consumer.get(timeout=1000) + + self.assertIsNotNone(msg) + self.assertEqual(msg.payload, test_data) + + asyncio.run(run_test()) + + def test_queue_large_shm_transfer(self): + async def run_test(): + producer = self.fmq.queue("test_queue", role="producer") + consumer = self.fmq.queue("test_queue", role="consumer") + + large_data = b"x" * (2 * 1024 * 1024) # > 1MB + await producer.put(large_data) + msg = await consumer.get(timeout=1000) + + self.assertIsNotNone(msg) + self.assertEqual(msg.payload, large_data) + self.assertIsNotNone(msg.descriptor) + + asyncio.run(run_test()) + + def test_topic_pub_sub(self): + received = [] + + async def run_test(): + topic = self.fmq.topic("test_topic") + + async def callback(msg: Message): + received.append(msg.payload) + + await topic.sub(callback) + await asyncio.sleep(0.1) # allow SUB to connect + + await topic.pub("hello") + await asyncio.sleep(0.2) + + self.assertIn("hello", received) + + asyncio.run(run_test()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/inter_communicator/test_fmq_factory.py b/tests/inter_communicator/test_fmq_factory.py new file mode 100644 index 000000000..50da82f4f --- /dev/null +++ b/tests/inter_communicator/test_fmq_factory.py @@ -0,0 +1,91 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import unittest + +from fastdeploy.inter_communicator.fmq import Message +from fastdeploy.inter_communicator.fmq_factory import FMQFactory as factory + + +class TestFMQFactory(unittest.IsolatedAsyncioTestCase): + + async def test_create_queues(self): + """Test whether all producer/consumer queues can be created.""" + q1 = factory.q_a2e_producer() + q2 = factory.q_a2e_consumer() + q3 = factory.q_e2w_producer() + q4 = factory.q_e2w_consumer() + q5 = factory.q_w2e_producer() + q6 = factory.q_w2e_consumer() + q7 = factory.q_e2a_producer() + q8 = factory.q_e2a_consumer() + + self.assertEqual(q1.name, "q_a2e") + self.assertEqual(q2.name, "q_a2e") + self.assertEqual(q3.name, "q_e2w") + self.assertEqual(q4.name, "q_e2w") + self.assertEqual(q5.name, "q_w2e") + self.assertEqual(q6.name, "q_w2e") + self.assertEqual(q7.name, "q_e2a") + self.assertEqual(q8.name, "q_e2a") + + # 同一进程内 context 应相同 + self.assertIs(q1.context, q2.context) + self.assertIs(q1.context, q3.context) + + async def test_message_roundtrip(self): + """测试 producer → consumer 消息流转""" + producer = factory.q_a2e_producer() + consumer = factory.q_a2e_consumer() + + payload = {"k": "v"} + + await producer.put(payload) + msg = await consumer.get(timeout=1500) + + self.assertIsInstance(msg, Message) + self.assertEqual(msg.payload, payload) + + async def test_multi_queue_independence(self): + """测试多个队列互不干扰""" + + prod_a2e = factory.q_a2e_producer() + cons_a2e = factory.q_a2e_consumer() + + prod_e2w = factory.q_e2w_producer() + cons_e2w = factory.q_e2w_consumer() + + await prod_a2e.put("msg_api") + await prod_e2w.put("msg_worker") + + msg1 = await cons_a2e.get(timeout=1500) + msg2 = await cons_e2w.get(timeout=1500) + + self.assertEqual(msg1.payload, "msg_api") + self.assertEqual(msg2.payload, "msg_worker") + + async def test_shared_context(self): + """验证 FMQFactory 始终返回同一个 context (单进程)""" + q1 = factory.q_a2e_producer() + q2 = factory.q_e2w_consumer() + q3 = factory.q_e2a_producer() + + self.assertIs(q1.context, q2.context) + self.assertIs(q1.context, q3.context) + + +if __name__ == "__main__": + unittest.main()