Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[Feature] [Benchmark]: add ZMQ-based FMQ implementation and benchmark tools (#5418)
* feat(fmq): add ZMQ-based FMQ implementation and benchmark tools
* move FMQ_CONFIG_JSON to envs
* fix top_p_candidates (#5400)
* [RL] Support Rollout Routing Replay (#5321)
  - add routing indices cache
  - fix config bug and moe forward bug
  - R3 support GLM and eb4.5
  - fix merge bug
  - apply suggestions from @Copilot
  - add routing replay ci
  - support glm topk and other top_k
  - fix ci bug, pre-commit
  - only support chatcmpl
* [Bug fix] Fix the multi-input accuracy issue in the pooling model. (#5374)
* [BugFix] remove _execute_empty_input (#5396)
* Revert "[RL] Support Rollout Routing Replay (#5321)" (#5402)
  This reverts commit 96d2d4877b.
* [New][RL] Support Rollout Routing Replay (#5405)
  Re-lands the routing replay change (reverts commit c45e064f3d) and fixes XPU and NPU bugs.
* bf16 deepseek (#5379)
* fix deepseek (#5410)
* Apply review suggestions from @Copilot to tests/inter_communicator/test_fmq_factory.py, benchmarks/benchmark_fmq.py, and fastdeploy/inter_communicator/fmq.py

Co-authored-by: GoldPancake <56388518+Deleter-D@users.noreply.github.com>
Co-authored-by: freeliuzc <lzc842650834@gmail.com>
Co-authored-by: RAM <gstian5555@outlook.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>
Co-authored-by: lizexu123 <39205361+lizexu123@users.noreply.github.com>
Co-authored-by: 周周周 <39978853+zhoutianzi666@users.noreply.github.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com>
New file: benchmarks/benchmark_fmq.py (233 lines)
@@ -0,0 +1,233 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import asyncio
import multiprocessing as mp
import os
import statistics
import time

from tqdm import tqdm

from fastdeploy.inter_communicator.fmq import FMQ


# ============================================================
# Producer Task
# ============================================================
async def producer_task(proc_id, msg_count, payload_size, shm_threshold, result_q):
    fmq = FMQ()
    q = fmq.queue("mp_bench_latency", role="producer")
    payload = b"x" * payload_size

    # tqdm progress bar
    pbar = tqdm(total=msg_count, desc=f"Producer-{proc_id}", position=proc_id, leave=True, disable=False)

    t0 = time.perf_counter()
    for i in range(msg_count):
        send_ts = time.perf_counter()
        await q.put(data={"pid": proc_id, "i": i, "send_ts": send_ts, "payload": payload}, shm_threshold=shm_threshold)
        pbar.update(1)
        # pbar.write(f"send {i}")
    t1 = time.perf_counter()
    result_q.put({"producer_id": proc_id, "count": msg_count, "time": t1 - t0})

    pbar.close()

    # wait a few seconds before exiting so in-flight messages can still be delivered
    await asyncio.sleep(5)


def producer_process(proc_id, msg_count, payload_size, shm_threshold, result_q):
    async def run():
        await producer_task(proc_id, msg_count, payload_size, shm_threshold, result_q)

    asyncio.run(run())


# ============================================================
# Consumer Task
# ============================================================
async def consumer_task(consumer_id, total_msgs, result_q, consumer_event):
    fmq = FMQ()
    q = fmq.queue("mp_bench_latency", role="consumer")
    consumer_event.set()

    latencies = []
    recv = 0

    # tqdm progress bar
    pbar = tqdm(total=total_msgs, desc=f"Consumer-{consumer_id}", position=consumer_id + 1, leave=True, disable=False)

    first_recv = None
    last_recv = None

    while recv < total_msgs:
        msg = await q.get()
        recv_ts = time.perf_counter()
        if msg is None:
            pbar.write("recv None")
            continue
        if first_recv is None:
            first_recv = recv_ts
        last_recv = recv_ts
        send_ts = msg.payload["send_ts"]
        latencies.append((recv_ts - send_ts) * 1000)  # ms
        pbar.update(1)
        recv += 1

    pbar.close()

    result_q.put(
        {"consumer_id": consumer_id, "latencies": latencies, "first_recv": first_recv, "last_recv": last_recv}
    )


def consumer_process(consumer_id, total_msgs, result_q, consumer_event):
    async def run():
        await consumer_task(consumer_id, total_msgs, result_q, consumer_event)

    asyncio.run(run())


# ============================================================
# MAIN benchmark
# ============================================================
def run_benchmark(
    NUM_PRODUCERS=1,
    NUM_CONSUMERS=1,
    NUM_MESSAGES_PER_PRODUCER=1000,
    PAYLOAD_SIZE=1 * 1024 * 1024,
    SHM_THRESHOLD=1 * 1024 * 1024,
):
    total_messages = NUM_PRODUCERS * NUM_MESSAGES_PER_PRODUCER
    total_bytes = total_messages * PAYLOAD_SIZE

    print(f"\nFastDeploy Message Queue Benchmark, pid:{os.getpid()}")
    print(f"Producers: {NUM_PRODUCERS}")
    print(f"Consumers: {NUM_CONSUMERS}")
    print(f"Messages per producer: {NUM_MESSAGES_PER_PRODUCER}")
    print(f"Total bytes: {total_bytes / 1024 / 1024 / 1024:.2f} GB")
    print(f"Total messages: {total_messages:,}")
    print(f"Payload per message: {PAYLOAD_SIZE / 1024 / 1024:.2f} MB")

    mp.set_start_method("fork")
    manager = mp.Manager()
    result_q = manager.Queue()

    # Signal event: set by consumers once their sockets are ready
    consumer_event = manager.Event()

    procs = []

    # Start Consumers
    msgs_per_consumer = total_messages // NUM_CONSUMERS
    for i in range(NUM_CONSUMERS):
        p = mp.Process(target=consumer_process, args=(i, msgs_per_consumer, result_q, consumer_event))
        procs.append(p)
        p.start()

    consumer_event.wait()

    # Start Producers
    for i in range(NUM_PRODUCERS):
        p = mp.Process(
            target=producer_process, args=(i, NUM_MESSAGES_PER_PRODUCER, PAYLOAD_SIZE, SHM_THRESHOLD, result_q)
        )
        procs.append(p)
        p.start()

    # Join
    for p in procs:
        p.join()

    # Collect results
    producer_stats = []
    consumer_stats = {}

    while not result_q.empty():
        item = result_q.get()
        if "producer_id" in item:
            producer_stats.append(item)
        if "consumer_id" in item:
            consumer_stats[item["consumer_id"]] = item

    # Producer stats
    print("\nProducer Stats:")
    for p in producer_stats:
        throughput = p["count"] / p["time"]
        bandwidth = (p["count"] * PAYLOAD_SIZE) / (1024**2 * p["time"])
        print(
            f"[Producer-{p['producer_id']}] Sent {p['count']:,} msgs "
            f"in {p['time']:.3f} s | Throughput: {throughput:,.0f} msg/s | Bandwidth: {bandwidth:.2f} MB/s"
        )

    # Consumer latency stats
    print("\nConsumer Latency Stats:")
    all_latencies = []
    first_recv_times = []
    last_recv_times = []

    for cid, data in consumer_stats.items():
        lats = data["latencies"]
        if len(lats) == 0:
            continue
        all_latencies.extend(lats)
        first_recv_times.append(data["first_recv"])
        last_recv_times.append(data["last_recv"])

        avg = statistics.mean(lats)
        p50 = statistics.median(lats)
        p95 = statistics.quantiles(lats, n=20)[18]
        p99 = statistics.quantiles(lats, n=100)[98]

        print(
            f"[Consumer-{cid}] msgs={len(lats):5d} | avg={avg:.3f} ms | "
            f"P50={p50:.3f} ms | P95={p95:.3f} ms | P99={p99:.3f} ms"
        )

    # Global summary
    if first_recv_times and last_recv_times:
        total_time = max(last_recv_times) - min(first_recv_times)
        global_throughput = total_messages / total_time
        global_bandwidth = total_bytes / (1024**2 * total_time)

        if all_latencies:
            avg_latency = statistics.mean(all_latencies)
            min_latency = min(all_latencies)
            max_latency = max(all_latencies)
            p50_latency = statistics.median(all_latencies)
            p95_latency = statistics.quantiles(all_latencies, n=20)[18]
            p99_latency = statistics.quantiles(all_latencies, n=100)[98]
        else:
            avg_latency = min_latency = max_latency = p50_latency = p95_latency = p99_latency = 0.0

        print("\nGlobal Summary:")
        print(f"Total messages : {total_messages:,}")
        print(f"Total data : {total_bytes / 1024**2:.2f} MB")
        print(f"Total time : {total_time:.3f} s")
        print(f"Global throughput: {global_throughput:,.0f} msg/s")
        print(f"Global bandwidth : {global_bandwidth:.2f} MB/s")
        print(
            f"Latency (ms) : avg={avg_latency:.3f} "
            f"| min={min_latency:.3f} | max={max_latency:.3f} "
            f"| P50={p50_latency:.3f} | P95={p95_latency:.3f} | P99={p99_latency:.3f}\n"
        )


# Entry
if __name__ == "__main__":
    run_benchmark()
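The benchmark runs with its defaults via `python benchmarks/benchmark_fmq.py`. As a minimal sketch only (it assumes benchmarks/ is importable as a package from the repository root; the parameter values are illustrative, not recommendations), run_benchmark can also be driven with custom settings:

# Sketch: drive the benchmark above with non-default, illustrative settings.
from benchmarks.benchmark_fmq import run_benchmark

if __name__ == "__main__":
    run_benchmark(
        NUM_PRODUCERS=2,  # producer processes
        NUM_CONSUMERS=2,  # consumer processes
        NUM_MESSAGES_PER_PRODUCER=500,  # messages sent by each producer
        PAYLOAD_SIZE=256 * 1024,  # filler bytes attached to each message
        SHM_THRESHOLD=1 * 1024 * 1024,  # put() uses shared memory for bytes payloads at or above this size
    )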
@@ -151,6 +151,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
    # "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU"
    "FD_HPU_CHUNK_SIZE": lambda: int(os.getenv("FD_HPU_CHUNK_SIZE", "64")),
    "FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")),
    "FMQ_CONFIG_JSON": lambda: os.getenv("FMQ_CONFIG_JSON", None),
}
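FMQ_CONFIG_JSON lets a deployment override the default endpoint map without a config file. As a hedged illustration (the JSON shape mirrors the config used in tests/inter_communicator/test_fmq.py; the endpoint name "my_queue" is made up), the variable can be set before the first FMQ() instance is created:

import json
import os

# Illustrative only: "my_queue" is a made-up endpoint name; keys follow the
# config used in tests/inter_communicator/test_fmq.py.
os.environ["FMQ_CONFIG_JSON"] = json.dumps(
    {
        "ipc_root": "/dev/shm",
        "io_threads": 1,
        "copy": False,
        "endpoints": {
            "my_queue": {"protocol": "ipc", "address": "/dev/shm/fmq_my_queue.ipc", "io_threads": 1, "copy": False},
        },
    }
)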
New file: fastdeploy/inter_communicator/fmq.py (347 lines)
@@ -0,0 +1,347 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import asyncio
import json
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from multiprocessing import shared_memory
from multiprocessing.reduction import ForkingPickler
from typing import Any, Callable, Dict, Optional

import zmq
import zmq.asyncio

from fastdeploy import envs
from fastdeploy.utils import fmq_logger

# ==========================
# Config & Enum Definitions
# ==========================


class EndpointType(Enum):
    QUEUE = "queue"
    TOPIC = "topic"


class Role(Enum):
    PRODUCER = "producer"
    CONSUMER = "consumer"


@dataclass
class SocketOptions:
    sndhwm: int = 0
    rcvhwm: int = 0
    linger: int = -1
    sndbuf: int = 32 * 1024 * 1024
    rcvbuf: int = 32 * 1024 * 1024
    immediate: int = 1

    def apply(self, socket: zmq.Socket, is_producer: bool):
        # Apply socket-level configurations
        socket.setsockopt(zmq.LINGER, self.linger)
        socket.setsockopt(zmq.IMMEDIATE, self.immediate)

        if is_producer:
            socket.setsockopt(zmq.SNDHWM, self.sndhwm)
            socket.setsockopt(zmq.SNDBUF, self.sndbuf)
        else:
            socket.setsockopt(zmq.RCVHWM, self.rcvhwm)
            socket.setsockopt(zmq.RCVBUF, self.rcvbuf)


@dataclass
class Endpoint:
    # Represents a single endpoint with protocol, address, io_threads, and copy behavior
    protocol: str
    address: str
    io_threads: int = 1
    copy: bool = False


@dataclass
class Config:
    ipc_root: str = "/dev/shm"
    io_threads: int = 1
    copy: bool = False
    endpoints: Dict[str, Endpoint] = field(default_factory=dict)
    socket_config: SocketOptions = field(default_factory=SocketOptions)


# ==========================
# Endpoint Manager
# ==========================


class EndpointManager:
    config: Config = Config()

    @classmethod
    def load_config(cls, _ignored_file_path: str = None):
        cfg_str = envs.FMQ_CONFIG_JSON
        if cfg_str:
            try:
                custom_cfg = json.loads(cfg_str)
                for key, value in custom_cfg.items():
                    if value is None:
                        continue
                    if key == "endpoints":
                        # Convert plain dicts from JSON into Endpoint objects
                        value = {name: Endpoint(**ep) for name, ep in value.items()}
                    setattr(cls.config, key, value)
            except Exception as e:
                fmq_logger.error(f"Failed to load FMQ config: {e}")
        fmq_logger.info(f"Loaded FMQ config: {cls.config}")

    @classmethod
    def get_endpoint(cls, name: str) -> Endpoint:
        # Retrieve endpoint object
        if name in cls.config.endpoints:
            return cls.config.endpoints[name]

        # Fallback: auto-generate endpoint
        address = f"{cls.config.ipc_root}/fmq_{name}.ipc"
        return Endpoint(protocol="ipc", address=address)


# ==========================
# Shared Memory Descriptor
# ==========================


@dataclass
class Descriptor:
    shm_name: str
    size: int

    @staticmethod
    def create(data_bytes: bytes) -> "Descriptor":
        # Create shared memory buffer and store payload
        name = f"fmq_shm_{uuid.uuid4().hex}"
        shm = shared_memory.SharedMemory(create=True, size=len(data_bytes), name=name)
        shm.buf[: len(data_bytes)] = data_bytes
        shm.close()
        return Descriptor(shm_name=name, size=len(data_bytes))

    def read_and_unlink(self) -> bytes:
        # Read and cleanup shared memory
        try:
            shm = shared_memory.SharedMemory(name=self.shm_name)
            data = bytes(shm.buf[: self.size])
            shm.close()
            shm.unlink()
            return data
        except FileNotFoundError:
            return b""


# ==========================
# Message Wrapper
# ==========================


@dataclass
class Message:
    payload: Any
    msg_id: int = None
    timestamp: float = field(default_factory=time.time)
    descriptor: Optional[Descriptor] = None

    def serialize(self) -> bytes:
        # Serialize message
        return ForkingPickler.dumps(self)

    @staticmethod
    def deserialize(data: bytes) -> "Message":
        # Deserialize message
        return ForkingPickler.loads(data)


# ==========================
# Base Component
# ==========================


class BaseComponent:
    def __init__(self, context: zmq.asyncio.Context, endpoint: Endpoint):
        self.context = context
        self.endpoint = endpoint
        self.socket = None
        self.lock = asyncio.Lock()

    async def close(self):
        # Close socket
        if self.socket:
            self.socket.close()


# ==========================
# FIFO Queue
# ==========================


class Queue(BaseComponent):
    def __init__(self, context, name: str, role: str = "producer"):
        endpoint = EndpointManager.get_endpoint(name)
        super().__init__(context, endpoint)

        self.name = name
        self.role = Role(role)
        self.copy = endpoint.copy
        self.socket_conf = EndpointManager.config.socket_config
        self._msg_id = 0

        full_ep = f"{endpoint.protocol}://{endpoint.address}"

        self.socket = self.context.socket(zmq.PUSH if self.role == Role.PRODUCER else zmq.PULL)
        self.socket_conf.apply(self.socket, self.role == Role.PRODUCER)

        if self.role == Role.PRODUCER:
            self.socket.connect(full_ep)
        else:
            self.socket.bind(full_ep)

        fmq_logger.info(f"Queue {name} initialized on {full_ep}")

    async def put(self, data: Any, shm_threshold: int = 1024 * 1024):
        """
        Send data to the queue.

        Args:
            data: The data to send. Can be any serializable object or bytes.
            shm_threshold: Size threshold in bytes. If the data is of type bytes and its size is
                greater than or equal to this threshold, shared memory will be used to send the message.
                Default is 1MB (1024 * 1024 bytes).

        Raises:
            PermissionError: If called by a non-producer role.
        """
        if self.role != Role.PRODUCER:
            raise PermissionError("Only producers can send messages.")

        desc = None
        payload = data

        if isinstance(data, bytes) and len(data) >= shm_threshold:
            desc = Descriptor.create(data)
            payload = None

        msg = Message(msg_id=self._msg_id, payload=payload, descriptor=desc)
        raw = msg.serialize()

        async with self.lock:
            await self.socket.send(raw, copy=self.copy)
            self._msg_id += 1

    async def get(self, timeout: int = None) -> Optional[Message]:
        # Receive data from queue
        if self.role != Role.CONSUMER:
            raise PermissionError("Only consumers can get messages.")

        try:
            if timeout:
                raw = await asyncio.wait_for(self.socket.recv(), timeout / 1000)
            else:
                raw = await self.socket.recv(copy=self.copy)
        except asyncio.TimeoutError:
            fmq_logger.error(f"Timeout receiving message on {self.name}")
            return None

        msg = Message.deserialize(raw)
        if msg.descriptor:
            msg.payload = msg.descriptor.read_and_unlink()

        self._msg_id += 1
        return msg


# ==========================
# Pub/Sub Topic
# ==========================


class Topic(BaseComponent):
    def __init__(self, context, name: str):
        endpoint = EndpointManager.get_endpoint(name)
        super().__init__(context, endpoint)
        self.name = name
        self._pub_socket = None
        self._sub_socket = None
        self._task = None

    async def pub(self, data: Any):
        # Publish a message
        if not self._pub_socket:
            ep = f"{self.endpoint.protocol}://{self.endpoint.address}"
            self._pub_socket = self.context.socket(zmq.PUB)
            self._pub_socket.bind(ep)
            await asyncio.sleep(0.05)

        msg = Message(payload=data)
        async with self.lock:
            await self._pub_socket.send(msg.serialize())

    async def sub(self, callback: Callable[[Message], Any]):
        # Subscribe and handle messages
        if not self._sub_socket:
            ep = f"{self.endpoint.protocol}://{self.endpoint.address}"
            self._sub_socket = self.context.socket(zmq.SUB)
            self._sub_socket.connect(ep)
            self._sub_socket.setsockopt_string(zmq.SUBSCRIBE, "")

        async def loop():
            while True:
                raw = await self._sub_socket.recv()
                msg = Message.deserialize(raw)
                result = callback(msg)
                if asyncio.iscoroutine(result):
                    await result

        self._task = asyncio.create_task(loop())


# ==========================
# FMQ Main Interface
# ==========================


class FMQ:
    _instance = None
    _context = None

    def __new__(cls, config_path="fmq_config.json"):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            EndpointManager.load_config()

            # Determine IO threads based on global defaults
            io_threads = 1
            if EndpointManager.config.endpoints:
                # Use max io_threads among all endpoints
                io_threads = max(ep.io_threads for ep in EndpointManager.config.endpoints.values())

            cls._context = zmq.asyncio.Context(io_threads=io_threads)
        return cls._instance

    def queue(self, name: str, role="producer") -> Queue:
        return Queue(self._context, name, role)

    def topic(self, name: str) -> Topic:
        return Topic(self._context, name)

    async def destroy(self):
        # Destroy ZeroMQ context
        self._context.term()
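To make the queue API above concrete, here is a minimal single-process sketch (the queue name "demo_queue" is made up; without an FMQ_CONFIG_JSON entry it falls back to an auto-generated ipc endpoint under /dev/shm):

# Sketch: minimal single-process FMQ queue usage, based on the API above.
import asyncio

from fastdeploy.inter_communicator.fmq import FMQ


async def main():
    fmq = FMQ()
    consumer = fmq.queue("demo_queue", role="consumer")  # consumer binds the endpoint
    producer = fmq.queue("demo_queue", role="producer")  # producer connects to it

    await producer.put({"hello": "world"})          # small objects are pickled inline
    await producer.put(b"x" * (2 * 1024 * 1024))    # bytes >= shm_threshold travel via shared memory

    for _ in range(2):
        msg = await consumer.get(timeout=1000)
        print(msg.msg_id, type(msg.payload))


if __name__ == "__main__":
    asyncio.run(main())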
New file: fastdeploy/inter_communicator/fmq_factory.py (83 lines)
@@ -0,0 +1,83 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from fastdeploy.inter_communicator.fmq import FMQ


class FMQFactory:
    """
    Static factory for creating the four standard FMQ queues:
        1. q_a2e: api server --> engine
        2. q_e2w: engine --> worker
        3. q_w2e: worker --> engine
        4. q_e2a: engine --> api server

    API Server: q_a2e producer / q_e2a consumer
    Engine:     q_a2e consumer / q_e2w producer / q_w2e consumer / q_e2a producer
    Worker:     q_e2w consumer / q_w2e producer
    """

    _fmq = FMQ()

    # ------------------------------
    # API → Engine
    # ------------------------------
    @classmethod
    def q_a2e_producer(cls):
        return cls._fmq.queue("q_a2e", role="producer")

    @classmethod
    def q_a2e_consumer(cls):
        return cls._fmq.queue("q_a2e", role="consumer")

    # ------------------------------
    # Engine → Worker
    # ------------------------------
    @classmethod
    def q_e2w_producer(cls):
        return cls._fmq.queue("q_e2w", role="producer")

    @classmethod
    def q_e2w_consumer(cls):
        return cls._fmq.queue("q_e2w", role="consumer")

    # ------------------------------
    # Worker → Engine
    # ------------------------------
    @classmethod
    def q_w2e_producer(cls):
        return cls._fmq.queue("q_w2e", role="producer")

    @classmethod
    def q_w2e_consumer(cls):
        return cls._fmq.queue("q_w2e", role="consumer")

    # ------------------------------
    # Engine → API
    # ------------------------------
    @classmethod
    def q_e2a_producer(cls):
        return cls._fmq.queue("q_e2a", role="producer")

    @classmethod
    def q_e2a_consumer(cls):
        return cls._fmq.queue("q_e2a", role="consumer")

    # ------------------------------
    # Destroy context
    # ------------------------------
    @classmethod
    async def destroy(cls):
        await cls._fmq.destroy()
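As a hedged illustration of the role mapping in the docstring above, each component would acquire its queues roughly as follows (the setup function names are invented for this sketch):

# Sketch: which FMQFactory queues each component would hold, per the docstring.
from fastdeploy.inter_communicator.fmq_factory import FMQFactory


def api_server_setup():
    to_engine = FMQFactory.q_a2e_producer()    # api server --> engine
    from_engine = FMQFactory.q_e2a_consumer()  # engine --> api server
    return to_engine, from_engine


def engine_setup():
    from_api = FMQFactory.q_a2e_consumer()
    to_worker = FMQFactory.q_e2w_producer()
    from_worker = FMQFactory.q_w2e_consumer()
    to_api = FMQFactory.q_e2a_producer()
    return from_api, to_worker, from_worker, to_api


def worker_setup():
    from_engine = FMQFactory.q_e2w_consumer()  # engine --> worker
    to_engine = FMQFactory.q_w2e_producer()    # worker --> engine
    return from_engine, to_engine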
@@ -1051,6 +1051,7 @@ spec_logger = get_logger("speculate", "speculate.log")
zmq_client_logger = get_logger("zmq_client", "zmq_client.log")
trace_logger = FastDeployLogger().get_trace_logger("trace_logger", "trace_logger.log")
router_logger = get_logger("router", "router.log")
fmq_logger = get_logger("fmq", "fmq.log")


def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
New file: tests/inter_communicator/test_fmq.py (92 lines)
@@ -0,0 +1,92 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import asyncio
import json
import os
import unittest

from fastdeploy.inter_communicator.fmq import FMQ, Message

# Prepare environment config for testing
cfg = {
    "ipc_root": "/dev/shm",
    "io_threads": 1,
    "copy": False,
    "endpoints": {
        "test_queue": {"protocol": "ipc", "address": "/dev/shm/fmq_test_queue.ipc", "io_threads": 1, "copy": False},
        "test_topic": {"protocol": "ipc", "address": "/dev/shm/fmq_test_topic.ipc", "io_threads": 1, "copy": False},
    },
}
os.environ["FMQ_CONFIG_JSON"] = json.dumps(cfg)


class TestFMQ(unittest.TestCase):

    def setUp(self):
        self.fmq = FMQ()

    def test_queue_send_receive(self):
        async def run_test():
            producer = self.fmq.queue("test_queue", role="producer")
            consumer = self.fmq.queue("test_queue", role="consumer")

            test_data = b"hello world"
            await producer.put(test_data)
            msg = await consumer.get(timeout=1000)

            self.assertIsNotNone(msg)
            self.assertEqual(msg.payload, test_data)

        asyncio.run(run_test())

    def test_queue_large_shm_transfer(self):
        async def run_test():
            producer = self.fmq.queue("test_queue", role="producer")
            consumer = self.fmq.queue("test_queue", role="consumer")

            large_data = b"x" * (2 * 1024 * 1024)  # > 1MB
            await producer.put(large_data)
            msg = await consumer.get(timeout=1000)

            self.assertIsNotNone(msg)
            self.assertEqual(msg.payload, large_data)
            self.assertIsNotNone(msg.descriptor)

        asyncio.run(run_test())

    def test_topic_pub_sub(self):
        received = []

        async def run_test():
            topic = self.fmq.topic("test_topic")

            async def callback(msg: Message):
                received.append(msg.payload)

            await topic.sub(callback)
            await asyncio.sleep(0.1)  # allow SUB to connect

            await topic.pub("hello")
            await asyncio.sleep(0.2)

            self.assertIn("hello", received)

        asyncio.run(run_test())


if __name__ == "__main__":
    unittest.main()
New file: tests/inter_communicator/test_fmq_factory.py (91 lines)
@@ -0,0 +1,91 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import unittest

from fastdeploy.inter_communicator.fmq import Message
from fastdeploy.inter_communicator.fmq_factory import FMQFactory as factory


class TestFMQFactory(unittest.IsolatedAsyncioTestCase):

    async def test_create_queues(self):
        """Test whether all producer/consumer queues can be created."""
        q1 = factory.q_a2e_producer()
        q2 = factory.q_a2e_consumer()
        q3 = factory.q_e2w_producer()
        q4 = factory.q_e2w_consumer()
        q5 = factory.q_w2e_producer()
        q6 = factory.q_w2e_consumer()
        q7 = factory.q_e2a_producer()
        q8 = factory.q_e2a_consumer()

        self.assertEqual(q1.name, "q_a2e")
        self.assertEqual(q2.name, "q_a2e")
        self.assertEqual(q3.name, "q_e2w")
        self.assertEqual(q4.name, "q_e2w")
        self.assertEqual(q5.name, "q_w2e")
        self.assertEqual(q6.name, "q_w2e")
        self.assertEqual(q7.name, "q_e2a")
        self.assertEqual(q8.name, "q_e2a")

        # Queues created within the same process should share one context
        self.assertIs(q1.context, q2.context)
        self.assertIs(q1.context, q3.context)

    async def test_message_roundtrip(self):
        """Test the producer → consumer message round trip."""
        producer = factory.q_a2e_producer()
        consumer = factory.q_a2e_consumer()

        payload = {"k": "v"}

        await producer.put(payload)
        msg = await consumer.get(timeout=1500)

        self.assertIsInstance(msg, Message)
        self.assertEqual(msg.payload, payload)

    async def test_multi_queue_independence(self):
        """Test that multiple queues do not interfere with each other."""
        prod_a2e = factory.q_a2e_producer()
        cons_a2e = factory.q_a2e_consumer()

        prod_e2w = factory.q_e2w_producer()
        cons_e2w = factory.q_e2w_consumer()

        await prod_a2e.put("msg_api")
        await prod_e2w.put("msg_worker")

        msg1 = await cons_a2e.get(timeout=1500)
        msg2 = await cons_e2w.get(timeout=1500)

        self.assertEqual(msg1.payload, "msg_api")
        self.assertEqual(msg2.payload, "msg_worker")

    async def test_shared_context(self):
        """Verify that FMQFactory always returns the same context (single process)."""
        q1 = factory.q_a2e_producer()
        q2 = factory.q_e2w_consumer()
        q3 = factory.q_e2a_producer()

        self.assertIs(q1.context, q2.context)
        self.assertIs(q1.context, q3.context)


if __name__ == "__main__":
    unittest.main()