mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] [Benchmark]: add ZMQ-based FMQ implementation and benchmark tools (#5418)
* feat(fmq): add ZMQ-based FMQ implementation and benchmark tools * move FMQ_CONFIG_JSON to envs * fix top_p_candidates (#5400) Co-authored-by: freeliuzc <lzc842650834@gmail.com> * [RL] Support Rollout Routing Replay (#5321) * [RL] Support Rollout Routing Replay * add routing indices cache * fix config bug and moe forward bug * R3 Support GLM * support eb4.5 * fix merge bug * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * add routing replay ci * support glm topk * support orther top_k * fix ci bug * pre-commit * only support chatcmpl --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Yuanle Liu <yuanlehome@163.com> * [Bug fix] Fix the multi-input accuracy issue in the pooling model. (#5374) * fix multi-inputs * fix threshold * fix threshold * fix * [BugFix]remove _execute_empty_input (#5396) * Revert "[RL] Support Rollout Routing Replay (#5321)" (#5402) This reverts commit96d2d4877b. * [New][RL] Support Rollout Routing Replay (#5405) * [RL] Support Rollout Routing Replay * add routing indices cache * fix config bug and moe forward bug * R3 Support GLM * support eb4.5 * fix merge bug * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * add routing replay ci * support glm topk * support orther top_k * fix ci bug * pre-commit * only support chatcmpl * Revert "Revert "[RL] Support Rollout Routing Replay (#5321)" (#5402)" This reverts commitc45e064f3d. * Fix XPU and NPU bug --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Yuanle Liu <yuanlehome@163.com> * bf16 deepseek (#5379) * fix deepseek (#5410) * Update tests/inter_communicator/test_fmq_factory.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update benchmarks/benchmark_fmq.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fastdeploy/inter_communicator/fmq.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: GoldPancake <56388518+Deleter-D@users.noreply.github.com> Co-authored-by: freeliuzc <lzc842650834@gmail.com> Co-authored-by: RAM <gstian5555@outlook.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Yuanle Liu <yuanlehome@163.com> Co-authored-by: lizexu123 <39205361+lizexu123@users.noreply.github.com> Co-authored-by: 周周周 <39978853+zhoutianzi666@users.noreply.github.com> Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com>
This commit is contained in:
233
benchmarks/benchmark_fmq.py
Normal file
233
benchmarks/benchmark_fmq.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import statistics
|
||||||
|
import time
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from fastdeploy.inter_communicator.fmq import FMQ
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Producer Task
|
||||||
|
# ============================================================
|
||||||
|
async def producer_task(proc_id, msg_count, payload_size, shm_threshold, result_q):
|
||||||
|
fmq = FMQ()
|
||||||
|
q = fmq.queue("mp_bench_latency", role="producer")
|
||||||
|
payload = b"x" * payload_size
|
||||||
|
|
||||||
|
# tqdm 进度条
|
||||||
|
pbar = tqdm(total=msg_count, desc=f"Producer-{proc_id}", position=proc_id, leave=True, disable=False)
|
||||||
|
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
for i in range(msg_count):
|
||||||
|
send_ts = time.perf_counter()
|
||||||
|
await q.put(data={"pid": proc_id, "i": i, "send_ts": send_ts, "payload": payload}, shm_threshold=shm_threshold)
|
||||||
|
pbar.update(1)
|
||||||
|
# pbar.write(f"send {i}")
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
result_q.put({"producer_id": proc_id, "count": msg_count, "time": t1 - t0})
|
||||||
|
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
# wait for 2 seconds before closing
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
|
||||||
|
def producer_process(proc_id, msg_count, payload_size, shm_threshold, result_q):
|
||||||
|
async def run():
|
||||||
|
await producer_task(proc_id, msg_count, payload_size, shm_threshold, result_q)
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Consumer Task
|
||||||
|
# ============================================================
|
||||||
|
async def consumer_task(consumer_id, total_msgs, result_q, consumer_event):
|
||||||
|
fmq = FMQ()
|
||||||
|
q = fmq.queue("mp_bench_latency", role="consumer")
|
||||||
|
consumer_event.set()
|
||||||
|
|
||||||
|
latencies = []
|
||||||
|
recv = 0
|
||||||
|
|
||||||
|
# tqdm 显示进度
|
||||||
|
pbar = tqdm(total=total_msgs, desc=f"Consumer-{consumer_id}", position=consumer_id + 1, leave=True, disable=False)
|
||||||
|
|
||||||
|
first_recv = None
|
||||||
|
last_recv = None
|
||||||
|
|
||||||
|
while recv < total_msgs:
|
||||||
|
msg = await q.get()
|
||||||
|
recv_ts = time.perf_counter()
|
||||||
|
if msg is None:
|
||||||
|
pbar.write("recv None")
|
||||||
|
continue
|
||||||
|
if first_recv is None:
|
||||||
|
first_recv = recv_ts
|
||||||
|
last_recv = recv_ts
|
||||||
|
send_ts = msg.payload["send_ts"]
|
||||||
|
latencies.append((recv_ts - send_ts) * 1000) # ms
|
||||||
|
pbar.update(1)
|
||||||
|
recv += 1
|
||||||
|
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
result_q.put(
|
||||||
|
{"consumer_id": consumer_id, "latencies": latencies, "first_recv": first_recv, "last_recv": last_recv}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def consumer_process(consumer_id, total_msgs, result_q, consumer_event):
|
||||||
|
async def run():
|
||||||
|
await consumer_task(consumer_id, total_msgs, result_q, consumer_event)
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# MAIN benchmark
|
||||||
|
# ============================================================
|
||||||
|
def run_benchmark(
|
||||||
|
NUM_PRODUCERS=1,
|
||||||
|
NUM_CONSUMERS=1,
|
||||||
|
NUM_MESSAGES_PER_PRODUCER=1000,
|
||||||
|
PAYLOAD_SIZE=1 * 1024 * 1024,
|
||||||
|
SHM_THRESHOLD=1 * 1024 * 1024,
|
||||||
|
):
|
||||||
|
total_messages = NUM_PRODUCERS * NUM_MESSAGES_PER_PRODUCER
|
||||||
|
total_bytes = total_messages * PAYLOAD_SIZE
|
||||||
|
|
||||||
|
print(f"\nFastDeploy Message Queue Benchmark, pid:{os.getpid()}")
|
||||||
|
print(f"Producers: {NUM_PRODUCERS}")
|
||||||
|
print(f"Consumers: {NUM_CONSUMERS}")
|
||||||
|
print(f"Messages per producer: {NUM_MESSAGES_PER_PRODUCER}")
|
||||||
|
print(f"Total bytes: {total_bytes / 1024 / 1024 / 1024:.2f} GB")
|
||||||
|
print(f"Total messages: {total_messages:,}")
|
||||||
|
print(f"Payload per message: {PAYLOAD_SIZE / 1024 / 1024:.2f} MB")
|
||||||
|
|
||||||
|
mp.set_start_method("fork")
|
||||||
|
manager = mp.Manager()
|
||||||
|
result_q = manager.Queue()
|
||||||
|
|
||||||
|
# 两个信号事件
|
||||||
|
consumer_event = manager.Event()
|
||||||
|
|
||||||
|
procs = []
|
||||||
|
|
||||||
|
# Start Consumers
|
||||||
|
msgs_per_consumer = total_messages // NUM_CONSUMERS
|
||||||
|
for i in range(NUM_CONSUMERS):
|
||||||
|
p = mp.Process(target=consumer_process, args=(i, msgs_per_consumer, result_q, consumer_event))
|
||||||
|
procs.append(p)
|
||||||
|
p.start()
|
||||||
|
|
||||||
|
consumer_event.wait()
|
||||||
|
|
||||||
|
# Start Producers
|
||||||
|
for i in range(NUM_PRODUCERS):
|
||||||
|
p = mp.Process(
|
||||||
|
target=producer_process, args=(i, NUM_MESSAGES_PER_PRODUCER, PAYLOAD_SIZE, SHM_THRESHOLD, result_q)
|
||||||
|
)
|
||||||
|
procs.append(p)
|
||||||
|
p.start()
|
||||||
|
|
||||||
|
# Join
|
||||||
|
for p in procs:
|
||||||
|
p.join()
|
||||||
|
|
||||||
|
# Collect results
|
||||||
|
producer_stats = []
|
||||||
|
consumer_stats = {}
|
||||||
|
|
||||||
|
while not result_q.empty():
|
||||||
|
item = result_q.get()
|
||||||
|
if "producer_id" in item:
|
||||||
|
producer_stats.append(item)
|
||||||
|
if "consumer_id" in item:
|
||||||
|
consumer_stats[item["consumer_id"]] = item
|
||||||
|
|
||||||
|
# Producer stats
|
||||||
|
print("\nProducer Stats:")
|
||||||
|
for p in producer_stats:
|
||||||
|
throughput = p["count"] / p["time"]
|
||||||
|
bandwidth = (p["count"] * PAYLOAD_SIZE) / (1024**2 * p["time"])
|
||||||
|
print(
|
||||||
|
f"[Producer-{p['producer_id']}] Sent {p['count']:,} msgs "
|
||||||
|
f"in {p['time']:.3f} s | Throughput: {throughput:,.0f} msg/s | Bandwidth: {bandwidth:.2f} MB/s"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Consumer latency stats
|
||||||
|
print("\nConsumer Latency Stats:")
|
||||||
|
all_latencies = []
|
||||||
|
first_recv_times = []
|
||||||
|
last_recv_times = []
|
||||||
|
|
||||||
|
for cid, data in consumer_stats.items():
|
||||||
|
lats = data["latencies"]
|
||||||
|
if len(lats) == 0:
|
||||||
|
continue
|
||||||
|
all_latencies.extend(lats)
|
||||||
|
first_recv_times.append(data["first_recv"])
|
||||||
|
last_recv_times.append(data["last_recv"])
|
||||||
|
|
||||||
|
avg = statistics.mean(lats)
|
||||||
|
p50 = statistics.median(lats)
|
||||||
|
p95 = statistics.quantiles(lats, n=20)[18]
|
||||||
|
p99 = statistics.quantiles(lats, n=100)[98]
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"[Consumer-{cid}] msgs={len(lats):5d} | avg={avg:.3f} ms | "
|
||||||
|
f"P50={p50:.3f} ms | P95={p95:.3f} ms | P99={p99:.3f} ms"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Global summary
|
||||||
|
if first_recv_times and last_recv_times:
|
||||||
|
total_time = max(last_recv_times) - min(first_recv_times)
|
||||||
|
global_throughput = total_messages / total_time
|
||||||
|
global_bandwidth = total_bytes / (1024**2 * total_time)
|
||||||
|
|
||||||
|
if all_latencies:
|
||||||
|
avg_latency = statistics.mean(all_latencies)
|
||||||
|
min_latency = min(all_latencies)
|
||||||
|
max_latency = max(all_latencies)
|
||||||
|
p50_latency = statistics.median(all_latencies)
|
||||||
|
p95_latency = statistics.quantiles(all_latencies, n=20)[18]
|
||||||
|
p99_latency = statistics.quantiles(all_latencies, n=100)[98]
|
||||||
|
else:
|
||||||
|
avg_latency = min_latency = max_latency = p50_latency = p95_latency = p99_latency = 0.0
|
||||||
|
|
||||||
|
print("\nGlobal Summary:")
|
||||||
|
print(f"Total messages : {total_messages:,}")
|
||||||
|
print(f"Total data : {total_bytes / 1024**2:.2f} MB")
|
||||||
|
print(f"Total time : {total_time:.3f} s")
|
||||||
|
print(f"Global throughput: {global_throughput:,.0f} msg/s")
|
||||||
|
print(f"Global bandwidth : {global_bandwidth:.2f} MB/s")
|
||||||
|
print(
|
||||||
|
f"Latency (ms) : avg={avg_latency:.3f} "
|
||||||
|
f"| min={min_latency:.3f} | max={max_latency:.3f} "
|
||||||
|
f"| P50={p50_latency:.3f} | P95={p95_latency:.3f} | P99={p99_latency:.3f}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Entry
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_benchmark()
|
||||||
@@ -151,6 +151,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU"
|
# "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU"
|
||||||
"FD_HPU_CHUNK_SIZE": lambda: int(os.getenv("FD_HPU_CHUNK_SIZE", "64")),
|
"FD_HPU_CHUNK_SIZE": lambda: int(os.getenv("FD_HPU_CHUNK_SIZE", "64")),
|
||||||
"FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")),
|
"FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")),
|
||||||
|
"FMQ_CONFIG_JSON": lambda: os.getenv("FMQ_CONFIG_JSON", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
347
fastdeploy/inter_communicator/fmq.py
Normal file
347
fastdeploy/inter_communicator/fmq.py
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from multiprocessing import shared_memory
|
||||||
|
from multiprocessing.reduction import ForkingPickler
|
||||||
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio
|
||||||
|
|
||||||
|
from fastdeploy import envs
|
||||||
|
from fastdeploy.utils import fmq_logger
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# Config & Enum Definitions
|
||||||
|
# ==========================
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointType(Enum):
|
||||||
|
QUEUE = "queue"
|
||||||
|
TOPIC = "topic"
|
||||||
|
|
||||||
|
|
||||||
|
class Role(Enum):
|
||||||
|
PRODUCER = "producer"
|
||||||
|
CONSUMER = "consumer"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SocketOptions:
|
||||||
|
sndhwm: int = 0
|
||||||
|
rcvhwm: int = 0
|
||||||
|
linger: int = -1
|
||||||
|
sndbuf: int = 32 * 1024 * 1024
|
||||||
|
rcvbuf: int = 32 * 1024 * 1024
|
||||||
|
immediate: int = 1
|
||||||
|
|
||||||
|
def apply(self, socket: zmq.Socket, is_producer: bool):
|
||||||
|
# Apply socket-level configurations
|
||||||
|
socket.setsockopt(zmq.LINGER, self.linger)
|
||||||
|
socket.setsockopt(zmq.IMMEDIATE, self.immediate)
|
||||||
|
|
||||||
|
if is_producer:
|
||||||
|
socket.setsockopt(zmq.SNDHWM, self.sndhwm)
|
||||||
|
socket.setsockopt(zmq.SNDBUF, self.sndbuf)
|
||||||
|
else:
|
||||||
|
socket.setsockopt(zmq.RCVHWM, self.rcvhwm)
|
||||||
|
socket.setsockopt(zmq.RCVBUF, self.rcvbuf)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Endpoint:
|
||||||
|
# Represents a single endpoint with protocol, address, io_threads, and copy behavior
|
||||||
|
protocol: str
|
||||||
|
address: str
|
||||||
|
io_threads: int = 1
|
||||||
|
copy: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Config:
|
||||||
|
ipc_root: str = "/dev/shm"
|
||||||
|
io_threads: int = 1
|
||||||
|
copy: bool = False
|
||||||
|
endpoints: Dict[str, Endpoint] = field(default_factory=dict)
|
||||||
|
socket_config: SocketOptions = SocketOptions()
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# Endpoint Manager
|
||||||
|
# ==========================
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointManager:
|
||||||
|
config: Config = Config()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_config(cls, _ignored_file_path: str = None):
|
||||||
|
cfg_str = envs.FMQ_CONFIG_JSON
|
||||||
|
if cfg_str:
|
||||||
|
try:
|
||||||
|
custom_cfg = json.loads(cfg_str)
|
||||||
|
for key, value in vars(custom_cfg).items():
|
||||||
|
if value is not None:
|
||||||
|
setattr(cls.config, key, value)
|
||||||
|
except Exception as e:
|
||||||
|
fmq_logger.error(f"Failed to load FMQ config: {e}")
|
||||||
|
fmq_logger.info(f"Loaded FMQ config: {cls.config}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_endpoint(cls, name: str) -> Endpoint:
|
||||||
|
# Retrieve endpoint object
|
||||||
|
if name in cls.config.endpoints:
|
||||||
|
return cls.config.endpoints[name]
|
||||||
|
|
||||||
|
# Fallback: auto-generate endpoint
|
||||||
|
address = f"{cls.config.ipc_root}/fmq_{name}.ipc"
|
||||||
|
return Endpoint(protocol="ipc", address=address)
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# Shared Memory Descriptor
|
||||||
|
# ==========================
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Descriptor:
|
||||||
|
shm_name: str
|
||||||
|
size: int
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create(data_bytes: bytes) -> "Descriptor":
|
||||||
|
# Create shared memory buffer and store payload
|
||||||
|
name = f"fmq_shm_{uuid.uuid4().hex}"
|
||||||
|
shm = shared_memory.SharedMemory(create=True, size=len(data_bytes), name=name)
|
||||||
|
shm.buf[: len(data_bytes)] = data_bytes
|
||||||
|
shm.close()
|
||||||
|
return Descriptor(shm_name=name, size=len(data_bytes))
|
||||||
|
|
||||||
|
def read_and_unlink(self) -> bytes:
|
||||||
|
# Read and cleanup shared memory
|
||||||
|
try:
|
||||||
|
shm = shared_memory.SharedMemory(name=self.shm_name)
|
||||||
|
data = bytes(shm.buf[: self.size])
|
||||||
|
shm.close()
|
||||||
|
shm.unlink()
|
||||||
|
return data
|
||||||
|
except FileNotFoundError:
|
||||||
|
return b""
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# Message Wrapper
|
||||||
|
# ==========================
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Message:
|
||||||
|
payload: Any
|
||||||
|
msg_id: int = None
|
||||||
|
timestamp: float = field(default_factory=time.time)
|
||||||
|
descriptor: Optional[Descriptor] = None
|
||||||
|
|
||||||
|
def serialize(self) -> bytes:
|
||||||
|
# Serialize message
|
||||||
|
return ForkingPickler.dumps(self)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def deserialize(data: bytes) -> "Message":
|
||||||
|
# Deserialize message
|
||||||
|
return ForkingPickler.loads(data)
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# Base Component
|
||||||
|
# ==========================
|
||||||
|
|
||||||
|
|
||||||
|
class BaseComponent:
|
||||||
|
def __init__(self, context: zmq.asyncio.Context, endpoint: Endpoint):
|
||||||
|
self.context = context
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.socket = None
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
# Close socket
|
||||||
|
if self.socket:
|
||||||
|
self.socket.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# FIFO Queue
|
||||||
|
# ==========================
|
||||||
|
|
||||||
|
|
||||||
|
class Queue(BaseComponent):
|
||||||
|
def __init__(self, context, name: str, role: str = "producer"):
|
||||||
|
endpoint = EndpointManager.get_endpoint(name)
|
||||||
|
super().__init__(context, endpoint)
|
||||||
|
|
||||||
|
self.name = name
|
||||||
|
self.role = Role(role)
|
||||||
|
self.copy = endpoint.copy
|
||||||
|
self.socket_conf = EndpointManager.config.socket_config
|
||||||
|
self._msg_id = 0
|
||||||
|
|
||||||
|
full_ep = f"{endpoint.protocol}://{endpoint.address}"
|
||||||
|
|
||||||
|
self.socket = self.context.socket(zmq.PUSH if self.role == Role.PRODUCER else zmq.PULL)
|
||||||
|
self.socket_conf.apply(self.socket, self.role == Role.PRODUCER)
|
||||||
|
|
||||||
|
if self.role == Role.PRODUCER:
|
||||||
|
self.socket.connect(full_ep)
|
||||||
|
else:
|
||||||
|
self.socket.bind(full_ep)
|
||||||
|
|
||||||
|
fmq_logger.info(f"Queue {name} initialized on {full_ep}")
|
||||||
|
|
||||||
|
async def put(self, data: Any, shm_threshold: int = 1024 * 1024):
|
||||||
|
"""
|
||||||
|
Send data to the queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: The data to send. Can be any serializable object or bytes.
|
||||||
|
shm_threshold: Size threshold in bytes. If the data is of type bytes and its size is
|
||||||
|
greater than or equal to this threshold, shared memory will be used to send the message.
|
||||||
|
Default is 1MB (1024 * 1024 bytes).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PermissionError: If called by a non-producer role.
|
||||||
|
"""
|
||||||
|
if self.role != Role.PRODUCER:
|
||||||
|
raise PermissionError("Only producers can send messages.")
|
||||||
|
|
||||||
|
desc = None
|
||||||
|
payload = data
|
||||||
|
|
||||||
|
if isinstance(data, bytes) and len(data) >= shm_threshold:
|
||||||
|
desc = Descriptor.create(data)
|
||||||
|
payload = None
|
||||||
|
|
||||||
|
msg = Message(msg_id=self._msg_id, payload=payload, descriptor=desc)
|
||||||
|
raw = msg.serialize()
|
||||||
|
|
||||||
|
async with self.lock:
|
||||||
|
await self.socket.send(raw, copy=self.copy)
|
||||||
|
self._msg_id += 1
|
||||||
|
|
||||||
|
async def get(self, timeout: int = None) -> Optional[Message]:
|
||||||
|
# Receive data from queue
|
||||||
|
if self.role != Role.CONSUMER:
|
||||||
|
raise PermissionError("Only consumers can get messages.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if timeout:
|
||||||
|
raw = await asyncio.wait_for(self.socket.recv(), timeout / 1000)
|
||||||
|
else:
|
||||||
|
raw = await self.socket.recv(copy=self.copy)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
fmq_logger.error(f"Timeout receiving message on {self.name}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
msg = Message.deserialize(raw)
|
||||||
|
if msg.descriptor:
|
||||||
|
msg.payload = msg.descriptor.read_and_unlink()
|
||||||
|
|
||||||
|
self._msg_id += 1
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# Pub/Sub Topic
|
||||||
|
# ==========================
|
||||||
|
|
||||||
|
|
||||||
|
class Topic(BaseComponent):
|
||||||
|
def __init__(self, context, name: str):
|
||||||
|
endpoint = EndpointManager.get_endpoint(name)
|
||||||
|
super().__init__(context, endpoint)
|
||||||
|
self.name = name
|
||||||
|
self._pub_socket = None
|
||||||
|
self._sub_socket = None
|
||||||
|
self._task = None
|
||||||
|
|
||||||
|
async def pub(self, data: Any):
|
||||||
|
# Publish a message
|
||||||
|
if not self._pub_socket:
|
||||||
|
ep = f"{self.endpoint.protocol}://{self.endpoint.address}"
|
||||||
|
self._pub_socket = self.context.socket(zmq.PUB)
|
||||||
|
self._pub_socket.bind(ep)
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
msg = Message(payload=data)
|
||||||
|
async with self.lock:
|
||||||
|
await self._pub_socket.send(msg.serialize())
|
||||||
|
|
||||||
|
async def sub(self, callback: Callable[[Message], Any]):
|
||||||
|
# Subscribe and handle messages
|
||||||
|
if not self._sub_socket:
|
||||||
|
ep = f"{self.endpoint.protocol}://{self.endpoint.address}"
|
||||||
|
self._sub_socket = self.context.socket(zmq.SUB)
|
||||||
|
self._sub_socket.connect(ep)
|
||||||
|
self._sub_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||||
|
|
||||||
|
async def loop():
|
||||||
|
while True:
|
||||||
|
raw = await self._sub_socket.recv()
|
||||||
|
msg = Message.deserialize(raw)
|
||||||
|
result = callback(msg)
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
await result
|
||||||
|
|
||||||
|
self._task = asyncio.create_task(loop())
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# FMQ Main Interface
|
||||||
|
# ==========================
|
||||||
|
|
||||||
|
|
||||||
|
class FMQ:
|
||||||
|
_instance = None
|
||||||
|
_context = None
|
||||||
|
|
||||||
|
def __new__(cls, config_path="fmq_config.json"):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
EndpointManager.load_config()
|
||||||
|
|
||||||
|
# Determine IO threads based on global defaults
|
||||||
|
io_threads = 1
|
||||||
|
if EndpointManager.config.endpoints:
|
||||||
|
# Use max io_threads among all endpoints
|
||||||
|
io_threads = max(ep.io_threads for ep in EndpointManager.config.endpoints.values())
|
||||||
|
|
||||||
|
cls._context = zmq.asyncio.Context(io_threads=io_threads)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def queue(self, name: str, role="producer") -> Queue:
|
||||||
|
return Queue(self._context, name, role)
|
||||||
|
|
||||||
|
def topic(self, name: str) -> Topic:
|
||||||
|
return Topic(self._context, name)
|
||||||
|
|
||||||
|
async def destroy(self):
|
||||||
|
# Destroy ZeroMQ context
|
||||||
|
self._context.term()
|
||||||
83
fastdeploy/inter_communicator/fmq_factory.py
Normal file
83
fastdeploy/inter_communicator/fmq_factory.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastdeploy.inter_communicator.fmq import FMQ
|
||||||
|
|
||||||
|
|
||||||
|
class FMQFactory:
|
||||||
|
"""
|
||||||
|
Static factory for creating the four standard FMQ queues:
|
||||||
|
1. q_a2e: api server --> engine
|
||||||
|
2. q_e2w: engine --> worker
|
||||||
|
3. q_w2e: worker --> engine
|
||||||
|
4. q_e2a: engine --> api server
|
||||||
|
API Server: q_a2e producer / q_e2a consumer
|
||||||
|
Engine: q_a2e consumer / q_e2w producer / q_w2e consumer / q_e2a producer
|
||||||
|
Worker: q_e2w consumer / q_w2e producer
|
||||||
|
"""
|
||||||
|
|
||||||
|
_fmq = FMQ()
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# API → Engine
|
||||||
|
# ------------------------------
|
||||||
|
@classmethod
|
||||||
|
def q_a2e_producer(cls):
|
||||||
|
return cls._fmq.queue("q_a2e", role="producer")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def q_a2e_consumer(cls):
|
||||||
|
return cls._fmq.queue("q_a2e", role="consumer")
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# Engine → Worker
|
||||||
|
# ------------------------------
|
||||||
|
@classmethod
|
||||||
|
def q_e2w_producer(cls):
|
||||||
|
return cls._fmq.queue("q_e2w", role="producer")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def q_e2w_consumer(cls):
|
||||||
|
return cls._fmq.queue("q_e2w", role="consumer")
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# Worker → Engine
|
||||||
|
# ------------------------------
|
||||||
|
@classmethod
|
||||||
|
def q_w2e_producer(cls):
|
||||||
|
return cls._fmq.queue("q_w2e", role="producer")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def q_w2e_consumer(cls):
|
||||||
|
return cls._fmq.queue("q_w2e", role="consumer")
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# Engine → API
|
||||||
|
# ------------------------------
|
||||||
|
@classmethod
|
||||||
|
def q_e2a_producer(cls):
|
||||||
|
return cls._fmq.queue("q_e2a", role="producer")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def q_e2a_consumer(cls):
|
||||||
|
return cls._fmq.queue("q_e2a", role="consumer")
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# Destroy context
|
||||||
|
# ------------------------------
|
||||||
|
@classmethod
|
||||||
|
async def destroy(cls):
|
||||||
|
await cls._fmq.destroy()
|
||||||
@@ -1051,6 +1051,7 @@ spec_logger = get_logger("speculate", "speculate.log")
|
|||||||
zmq_client_logger = get_logger("zmq_client", "zmq_client.log")
|
zmq_client_logger = get_logger("zmq_client", "zmq_client.log")
|
||||||
trace_logger = FastDeployLogger().get_trace_logger("trace_logger", "trace_logger.log")
|
trace_logger = FastDeployLogger().get_trace_logger("trace_logger", "trace_logger.log")
|
||||||
router_logger = get_logger("router", "router.log")
|
router_logger = get_logger("router", "router.log")
|
||||||
|
fmq_logger = get_logger("fmq", "fmq.log")
|
||||||
|
|
||||||
|
|
||||||
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
|
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
|
||||||
|
|||||||
92
tests/inter_communicator/test_fmq.py
Normal file
92
tests/inter_communicator/test_fmq.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from fastdeploy.inter_communicator.fmq import FMQ, Message
|
||||||
|
|
||||||
|
# Prepare environment config for testing
|
||||||
|
cfg = {
|
||||||
|
"ipc_root": "/dev/shm",
|
||||||
|
"io_threads": 1,
|
||||||
|
"copy": False,
|
||||||
|
"endpoints": {
|
||||||
|
"test_queue": {"protocol": "ipc", "address": "/dev/shm/fmq_test_queue.ipc", "io_threads": 1, "copy": False},
|
||||||
|
"test_topic": {"protocol": "ipc", "address": "/dev/shm/fmq_test_topic.ipc", "io_threads": 1, "copy": False},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
os.environ["FMQ_CONFIG_JSON"] = json.dumps(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFMQ(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.fmq = FMQ()
|
||||||
|
|
||||||
|
def test_queue_send_receive(self):
|
||||||
|
async def run_test():
|
||||||
|
producer = self.fmq.queue("test_queue", role="producer")
|
||||||
|
consumer = self.fmq.queue("test_queue", role="consumer")
|
||||||
|
|
||||||
|
test_data = b"hello world"
|
||||||
|
await producer.put(test_data)
|
||||||
|
msg = await consumer.get(timeout=1000)
|
||||||
|
|
||||||
|
self.assertIsNotNone(msg)
|
||||||
|
self.assertEqual(msg.payload, test_data)
|
||||||
|
|
||||||
|
asyncio.run(run_test())
|
||||||
|
|
||||||
|
def test_queue_large_shm_transfer(self):
|
||||||
|
async def run_test():
|
||||||
|
producer = self.fmq.queue("test_queue", role="producer")
|
||||||
|
consumer = self.fmq.queue("test_queue", role="consumer")
|
||||||
|
|
||||||
|
large_data = b"x" * (2 * 1024 * 1024) # > 1MB
|
||||||
|
await producer.put(large_data)
|
||||||
|
msg = await consumer.get(timeout=1000)
|
||||||
|
|
||||||
|
self.assertIsNotNone(msg)
|
||||||
|
self.assertEqual(msg.payload, large_data)
|
||||||
|
self.assertIsNotNone(msg.descriptor)
|
||||||
|
|
||||||
|
asyncio.run(run_test())
|
||||||
|
|
||||||
|
def test_topic_pub_sub(self):
|
||||||
|
received = []
|
||||||
|
|
||||||
|
async def run_test():
|
||||||
|
topic = self.fmq.topic("test_topic")
|
||||||
|
|
||||||
|
async def callback(msg: Message):
|
||||||
|
received.append(msg.payload)
|
||||||
|
|
||||||
|
await topic.sub(callback)
|
||||||
|
await asyncio.sleep(0.1) # allow SUB to connect
|
||||||
|
|
||||||
|
await topic.pub("hello")
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
|
self.assertIn("hello", received)
|
||||||
|
|
||||||
|
asyncio.run(run_test())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
91
tests/inter_communicator/test_fmq_factory.py
Normal file
91
tests/inter_communicator/test_fmq_factory.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from fastdeploy.inter_communicator.fmq import Message
|
||||||
|
from fastdeploy.inter_communicator.fmq_factory import FMQFactory as factory
|
||||||
|
|
||||||
|
|
||||||
|
class TestFMQFactory(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
|
async def test_create_queues(self):
|
||||||
|
"""Test whether all producer/consumer queues can be created."""
|
||||||
|
q1 = factory.q_a2e_producer()
|
||||||
|
q2 = factory.q_a2e_consumer()
|
||||||
|
q3 = factory.q_e2w_producer()
|
||||||
|
q4 = factory.q_e2w_consumer()
|
||||||
|
q5 = factory.q_w2e_producer()
|
||||||
|
q6 = factory.q_w2e_consumer()
|
||||||
|
q7 = factory.q_e2a_producer()
|
||||||
|
q8 = factory.q_e2a_consumer()
|
||||||
|
|
||||||
|
self.assertEqual(q1.name, "q_a2e")
|
||||||
|
self.assertEqual(q2.name, "q_a2e")
|
||||||
|
self.assertEqual(q3.name, "q_e2w")
|
||||||
|
self.assertEqual(q4.name, "q_e2w")
|
||||||
|
self.assertEqual(q5.name, "q_w2e")
|
||||||
|
self.assertEqual(q6.name, "q_w2e")
|
||||||
|
self.assertEqual(q7.name, "q_e2a")
|
||||||
|
self.assertEqual(q8.name, "q_e2a")
|
||||||
|
|
||||||
|
# 同一进程内 context 应相同
|
||||||
|
self.assertIs(q1.context, q2.context)
|
||||||
|
self.assertIs(q1.context, q3.context)
|
||||||
|
|
||||||
|
async def test_message_roundtrip(self):
|
||||||
|
"""测试 producer → consumer 消息流转"""
|
||||||
|
producer = factory.q_a2e_producer()
|
||||||
|
consumer = factory.q_a2e_consumer()
|
||||||
|
|
||||||
|
payload = {"k": "v"}
|
||||||
|
|
||||||
|
await producer.put(payload)
|
||||||
|
msg = await consumer.get(timeout=1500)
|
||||||
|
|
||||||
|
self.assertIsInstance(msg, Message)
|
||||||
|
self.assertEqual(msg.payload, payload)
|
||||||
|
|
||||||
|
async def test_multi_queue_independence(self):
|
||||||
|
"""测试多个队列互不干扰"""
|
||||||
|
|
||||||
|
prod_a2e = factory.q_a2e_producer()
|
||||||
|
cons_a2e = factory.q_a2e_consumer()
|
||||||
|
|
||||||
|
prod_e2w = factory.q_e2w_producer()
|
||||||
|
cons_e2w = factory.q_e2w_consumer()
|
||||||
|
|
||||||
|
await prod_a2e.put("msg_api")
|
||||||
|
await prod_e2w.put("msg_worker")
|
||||||
|
|
||||||
|
msg1 = await cons_a2e.get(timeout=1500)
|
||||||
|
msg2 = await cons_e2w.get(timeout=1500)
|
||||||
|
|
||||||
|
self.assertEqual(msg1.payload, "msg_api")
|
||||||
|
self.assertEqual(msg2.payload, "msg_worker")
|
||||||
|
|
||||||
|
async def test_shared_context(self):
|
||||||
|
"""验证 FMQFactory 始终返回同一个 context (单进程)"""
|
||||||
|
q1 = factory.q_a2e_producer()
|
||||||
|
q2 = factory.q_e2w_consumer()
|
||||||
|
q3 = factory.q_e2a_producer()
|
||||||
|
|
||||||
|
self.assertIs(q1.context, q2.context)
|
||||||
|
self.assertIs(q1.context, q3.context)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user