Sync v2.0 version of code to github repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -0,0 +1,25 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .zmq_client import ZmqClient
from .ipc_signal import IPCSignal
from .engine_worker_queue import EngineWorkerQueue
from .engine_cache_queue import EngineCacheQueue
__all__ = [
'ZmqClient', 'IPCSignal', 'EngineWorkerQueue', 'CacheQueueManager'
]

View File

@@ -0,0 +1,310 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import threading
import time
from multiprocessing.managers import (AcquirerProxy, BaseManager, ListProxy,
Value, ValueProxy)
from typing import Any, List, Tuple
from fastdeploy.utils import get_logger
logger = get_logger("cache_queue_manager", "cache_queue_manager.log")
class EngineCacheQueue:
    """
    Multiprocessing manager for cache queue between Engine and Worker.
    Manages shared resources using multiprocessing managers for inter-process communication.

    Read protocol: each client (tensor-parallel rank) owns one bit of a shared
    sync integer. A transfer task stays at the head of the queue until every
    client has set its bit; the last reader pops the task and resets the mask.
    """

    def __init__(
            self,
            address: Tuple[str, int] = ('127.0.0.1', 56666),
            authkey: bytes = b'cache_queue_service',
            is_server: bool = False,
            num_client: int = 1,  # tensor parallel size
            client_id: int = -1,  # tensor parallel id
            local_data_parallel_size: int = 1,  # data parallel size
            local_data_parallel_id: int = 0,  # local data parallel id
    ) -> None:
        """
        Initialize the cache communication queue.

        Args:
            address: Network address (IP, port) for the queue server
            authkey: Authentication key for secure connection
            is_server: Whether this instance acts as a server
            num_client: Total number of expected clients
            client_id: Unique identifier for client instances
            local_data_parallel_size: data parallel size
            local_data_parallel_id: local data parallel id
        """
        self.address: Tuple[str, int] = address
        self.authkey: bytes = authkey
        self.num_client: int = num_client
        self.client_id: int = client_id
        self.local_data_parallel_size = local_data_parallel_size
        self.local_data_parallel_id = local_data_parallel_id

        class QueueManager(BaseManager):
            """
            Custom QueueManager for proxy object registration
            """
            pass

        if is_server:
            # Server-side initialization for shared resources: one
            # queue/lock/sync-value/barrier set per data-parallel group.
            self.transfer_task_queue_init: List[List[Any]] = [
                list() for _ in range(self.local_data_parallel_size)
            ]
            # NOTE(review): "tansfer" is a typo for "transfer", but the name
            # is part of the registered cross-process API, so it is kept.
            self.tansfer_done_queue_init: List[List[Any]] = [
                list() for _ in range(self.local_data_parallel_size)
            ]
            # Per-group bitmask recording which clients read the head task.
            self.cache_sync_value_init: List[Value] = [
                Value("i", 0) for _ in range(self.local_data_parallel_size)
            ]
            self.transfer_task_lock_init: List[threading.Lock] = [
                threading.Lock() for _ in range(self.local_data_parallel_size)
            ]
            self.transfer_task_done_lock_init: List[threading.Lock] = [
                threading.Lock() for _ in range(self.local_data_parallel_size)
            ]
            # Initialize barriers; each waits for all num_client ranks.
            self.barrier1_init = [
                threading.Barrier(self.num_client)
                for _ in range(self.local_data_parallel_size)
            ]
            self.barrier2_init = [
                threading.Barrier(self.num_client)
                for _ in range(self.local_data_parallel_size)
            ]
            self.barrier3_init = [
                threading.Barrier(self.num_client)
                for _ in range(self.local_data_parallel_size)
            ]
            self.swap_to_cpu_barrier1_init = [
                threading.Barrier(self.num_client)
                for _ in range(self.local_data_parallel_size)
            ]
            self.swap_to_cpu_barrier2_init = [
                threading.Barrier(self.num_client)
                for _ in range(self.local_data_parallel_size)
            ]
            self.swap_to_gpu_barrier1_init = [
                threading.Barrier(self.num_client)
                for _ in range(self.local_data_parallel_size)
            ]
            self.swap_to_gpu_barrier2_init = [
                threading.Barrier(self.num_client)
                for _ in range(self.local_data_parallel_size)
            ]
            # Register shared objects with proxy types; the idx argument
            # selects the data-parallel group.
            QueueManager.register(
                "get_transfer_task_queue",
                callable=lambda idx: self.transfer_task_queue_init[idx],
                proxytype=ListProxy)
            QueueManager.register(
                "get_tansfer_done_queue",
                callable=lambda idx: self.tansfer_done_queue_init[idx],
                proxytype=ListProxy)
            QueueManager.register(
                "get_cache_sync_value",
                callable=lambda idx: self.cache_sync_value_init[idx],
                proxytype=ValueProxy)
            QueueManager.register(
                "get_transfer_task_lock",
                callable=lambda idx: self.transfer_task_lock_init[idx],
                proxytype=AcquirerProxy)
            QueueManager.register(
                "get_transfer_task_done_lock",
                callable=lambda idx: self.transfer_task_done_lock_init[idx],
                proxytype=AcquirerProxy)
            QueueManager.register("get_barrier1",
                                  callable=lambda idx: self.barrier1_init[idx])
            QueueManager.register("get_barrier2",
                                  callable=lambda idx: self.barrier2_init[idx])
            QueueManager.register("get_barrier3",
                                  callable=lambda idx: self.barrier3_init[idx])
            QueueManager.register(
                "get_swap_to_cpu_barrier1",
                callable=lambda idx: self.swap_to_cpu_barrier1_init[idx])
            QueueManager.register(
                "get_swap_to_cpu_barrier2",
                callable=lambda idx: self.swap_to_cpu_barrier2_init[idx])
            QueueManager.register(
                "get_swap_to_gpu_barrier1",
                callable=lambda idx: self.swap_to_gpu_barrier1_init[idx])
            QueueManager.register(
                "get_swap_to_gpu_barrier2",
                callable=lambda idx: self.swap_to_gpu_barrier2_init[idx])
            self.manager: BaseManager = QueueManager(address=self.address,
                                                     authkey=self.authkey)
            self.manager.start()
            logger.info(f"EngineCacheQueue server started at {self.address}")
        else:
            # Client-side connection setup
            assert 0 <= self.client_id < self.num_client, (
                f"client_id must be between 0 and {self.num_client-1}, got {self.client_id}"
            )
            # Clients register names only; the callables live on the server.
            QueueManager.register("get_transfer_task_queue")
            QueueManager.register("get_tansfer_done_queue")
            QueueManager.register("get_cache_sync_value")
            QueueManager.register("get_transfer_task_lock")
            QueueManager.register("get_transfer_task_done_lock")
            QueueManager.register("get_barrier1")
            QueueManager.register("get_barrier2")
            QueueManager.register("get_barrier3")
            QueueManager.register("get_swap_to_cpu_barrier1")
            QueueManager.register("get_swap_to_cpu_barrier2")
            QueueManager.register("get_swap_to_gpu_barrier1")
            QueueManager.register("get_swap_to_gpu_barrier2")
            self.manager = QueueManager(address=self.address,
                                        authkey=self.authkey)
            self._connect_with_retry()

        # Get proxy objects for shared resources (both server and client go
        # through the manager so all sides share the same proxies).
        self.transfer_task_queue = self.manager.get_transfer_task_queue(
            self.local_data_parallel_id)
        self.tansfer_done_queue = self.manager.get_tansfer_done_queue(
            self.local_data_parallel_id)
        self.task_sync_value = self.manager.get_cache_sync_value(
            self.local_data_parallel_id)
        self.task_lock = self.manager.get_transfer_task_lock(
            self.local_data_parallel_id)
        self.task_done_lock = self.manager.get_transfer_task_done_lock(
            self.local_data_parallel_id)
        # Get barrier proxies
        self.barrier1 = self.manager.get_barrier1(self.local_data_parallel_id)
        self.barrier2 = self.manager.get_barrier2(self.local_data_parallel_id)
        self.barrier3 = self.manager.get_barrier3(self.local_data_parallel_id)
        self.swap_to_cpu_barrier1 = self.manager.get_swap_to_cpu_barrier1(
            self.local_data_parallel_id)
        self.swap_to_cpu_barrier2 = self.manager.get_swap_to_cpu_barrier2(
            self.local_data_parallel_id)
        self.swap_to_gpu_barrier1 = self.manager.get_swap_to_gpu_barrier1(
            self.local_data_parallel_id)
        self.swap_to_gpu_barrier2 = self.manager.get_swap_to_gpu_barrier2(
            self.local_data_parallel_id)

        # Mask value meaning "all clients have read": num_client low bits set.
        self.total_num: int = (1 << self.num_client) - 1
        if not is_server:
            # Setup position and total_num for sync operations;
            # position is this client's private bit in the sync mask.
            self.position: int = 1 << self.client_id
            logger.info(
                f"Connected EngineCacheQueue client_id: {self.client_id}")

    def _connect_with_retry(self,
                            max_retries: int = 5,
                            interval: int = 3) -> None:
        """
        Connect to the server with retry mechanism.

        Args:
            max_retries: Maximum connection attempts
            interval: Retry interval in seconds

        Raises:
            ConnectionError: If all connection attempts fail
        """
        for _ in range(max_retries):
            try:
                self.manager.connect()
                return
            except ConnectionRefusedError:
                time.sleep(interval)
        raise ConnectionError(
            f"EngineCacheQueue cannot connect to {self.address}")

    def put_transfer_task(self, item):
        """
        put swap task
        """
        self.task_lock.acquire()
        # A sync value strictly between 0 and total_num means the previous
        # task has been read by some, but not all, clients: wait (without
        # holding the lock) until they finish before enqueuing a new task.
        if 0 < self.task_sync_value.get() < self.total_num:
            self.task_lock.release()
            while 0 < self.task_sync_value.get() < self.total_num:
                time.sleep(0.001)
            self.task_lock.acquire()
        # Reset the read mask for the new task.
        self.task_sync_value.set(0)
        self.transfer_task_queue.append(item)
        logger.info(
            f"put_transfer_task: put swap task {item[-1]} to queue successful")
        self.task_lock.release()

    def get_transfer_task(self):
        """
        get swap task

        Returns:
            tuple: (head task or None, bool set True by the last client to
            read the task, i.e. when the task was popped from the queue)
        """
        data = None
        read_finish = False
        self.task_lock.acquire()
        # Only read if this client's bit is not yet set and there is a task.
        if (self.task_sync_value.get() & self.position == 0
                and len(self.transfer_task_queue) > 0):
            data = self.transfer_task_queue[0]
            logger.debug(
                f"get_transfer_task: Get {data} by {self.client_id} from queue successful"
            )
            set_value = self.task_sync_value.get() | self.position
            logger.info("get_transfer_task: rank: {0} set_value: {1}".format(
                self.client_id, set_value))
            if set_value >= self.total_num:
                # Every client has read the head task: remove it and reset.
                self.transfer_task_queue.pop(0)
                set_value = 0
                read_finish = True
            self.task_sync_value.set(set_value)
        self.task_lock.release()
        return data, read_finish

    def put_transfer_done_signal(self, item):
        """
        put swap result
        """
        self.task_done_lock.acquire()
        self.tansfer_done_queue.append(item)
        self.task_done_lock.release()
        logger.info(
            f"put_transfer_done_signal: put swap task {item[-1]} finished signal to queue successful"
        )

    def get_transfer_done_signal(self):
        """
        get swap result

        Returns:
            The oldest completion signal, or None if the queue is empty.
        """
        data = None
        self.task_done_lock.acquire()
        if len(self.tansfer_done_queue) > 0:
            data = self.tansfer_done_queue.pop(0)
            logger.info(
                f"get_transfer_done_signal: Get swap task {data[-1]} finished signal from queue successful"
            )
        self.task_done_lock.release()
        return data

    def empty(self):
        """
        check if queue is empty
        """
        try:
            return len(self.transfer_task_queue) == 0
        except Exception as e:
            logger.error(f"empty function meets error: {e}")
            raise e

View File

@@ -0,0 +1,416 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import threading
import time
from multiprocessing.managers import (AcquirerProxy, BaseManager, ListProxy,
Value, ValueProxy)
from queue import Queue
from typing import Any, List, Tuple
import numpy as np
from fastdeploy.utils import llm_logger
class EngineWorkerQueue:
    """
    Cross-machine and cross-process communication queue between Engine and Worker.
    Manages shared resources using multiprocessing managers for inter-process communication.

    Tasks and cache infos are broadcast to all tensor-parallel clients: each
    client has a per-client read flag, and the shared list is cleared only
    after every client has read it.
    """

    def __init__(
            self,
            address: Tuple[str, int] = ('0.0.0.0', 5000),
            authkey: bytes = b'secret_key',
            is_server: bool = False,
            num_client: int = 1,  # tensor parallel size
            client_id: int = -1,  # tensor parallel id
            local_data_parallel_size: int = 1,  # data parallel size
            local_data_parallel_id: int = 0,  # local data parallel id
    ) -> None:
        """
        Initialize the communication queue.

        Args:
            address: Network address (IP, port) for the queue server
            authkey: Authentication key for secure connection
            is_server: Whether this instance acts as a server
            num_client: Total number of expected clients
            client_id: Unique identifier for client instances
            local_data_parallel_size: data parallel size
            local_data_parallel_id: local data parallel id
        """
        self.address: Tuple[str, int] = address
        self.authkey: bytes = authkey
        self.is_server: bool = is_server
        self.num_client: int = num_client
        self.client_id: int = client_id
        self.local_data_parallel_size = local_data_parallel_size
        self.local_data_parallel_id = local_data_parallel_id

        class QueueManager(BaseManager):
            """
            Custom QueueManager for proxy object registration.
            """
            pass

        if is_server:
            # Server-side initialization: one resource set per data-parallel
            # group. Flags start at 1 ("already read") so the first put does
            # not block.
            self.tasks_init: List[List[Any]] = [
                list() for _ in range(self.local_data_parallel_size)
            ]
            self.client_read_flag_init: List[List[int]] = [
                [1] * self.num_client
                for _ in range(self.local_data_parallel_size)
            ]
            self.lock_init: List[threading.Lock] = [
                threading.Lock() for _ in range(self.local_data_parallel_size)
            ]
            self.read_finish_flag_init: List[Value] = [
                Value("i", 0) for _ in range(self.local_data_parallel_size)
            ]
            # Counts clients that have connected (incremented in __init__).
            self.connected_client_counter_init: List[Value] = [
                Value("i", 0) for _ in range(self.local_data_parallel_size)
            ]
            self.finished_req_queue = [
                Queue() for _ in range(self.local_data_parallel_size)
            ]
            self.cache_infos_init: List[List[Any]] = [
                list() for _ in range(self.local_data_parallel_size)
            ]
            self.client_read_info_flag_init: List[List[int]] = [
                [1] * self.num_client
                for _ in range(self.local_data_parallel_size)
            ]
            self.lock_info_init: List[threading.Lock] = [
                threading.Lock() for _ in range(self.local_data_parallel_size)
            ]
            self.finish_request_barrier = [
                threading.Barrier(self.num_client)
                for _ in range(self.local_data_parallel_size)
            ]
            # Register shared objects with proxy types; idx selects the
            # data-parallel group.
            QueueManager.register("get_tasks",
                                  callable=lambda idx: self.tasks_init[idx],
                                  proxytype=ListProxy)
            QueueManager.register(
                "get_client_read_flag",
                callable=lambda idx: self.client_read_flag_init[idx],
                proxytype=ListProxy)
            QueueManager.register("get_lock",
                                  callable=lambda idx: self.lock_init[idx],
                                  proxytype=AcquirerProxy)
            QueueManager.register(
                "get_read_finish_flag",
                callable=lambda idx: self.read_finish_flag_init[idx],
                proxytype=ValueProxy)
            QueueManager.register(
                "get_connected_client_counter",
                callable=lambda idx: self.connected_client_counter_init[idx],
                proxytype=ValueProxy)
            QueueManager.register(
                'get_finish_request_queue',
                callable=lambda idx: self.finished_req_queue[idx])
            QueueManager.register(
                "get_cache_infos",
                callable=lambda idx: self.cache_infos_init[idx],
                proxytype=ListProxy)
            QueueManager.register(
                "get_client_read_info_flag",
                callable=lambda idx: self.client_read_info_flag_init[idx],
                proxytype=ListProxy)
            QueueManager.register(
                "get_lock_info",
                callable=lambda idx: self.lock_info_init[idx],
                proxytype=AcquirerProxy)
            self.disaggregate_requests = [
                Queue() for _ in range(self.local_data_parallel_size)
            ]
            QueueManager.register(
                "get_disaggregate_requests",
                callable=lambda idx: self.disaggregate_requests[idx])
            # Single queue shared across all data-parallel groups.
            self.available_prefill_instances = Queue()
            QueueManager.register(
                "get_available_prefill_instances",
                callable=lambda: self.available_prefill_instances)
            QueueManager.register(
                "get_finish_request_barrier",
                callable=lambda idx: self.finish_request_barrier[idx])
            self.manager: BaseManager = QueueManager(address=self.address,
                                                     authkey=self.authkey)
            self.manager.start()
        else:
            # Client-side connection setup
            assert self.client_id >= 0 and self.client_id < self.num_client, (
                f"self.client_id={self.client_id}, self.num_client={self.num_client}"
            )
            # Clients register names only; the callables live on the server.
            QueueManager.register("get_tasks")
            QueueManager.register("get_client_read_flag")
            QueueManager.register("get_lock")
            QueueManager.register("get_read_finish_flag")
            QueueManager.register("get_connected_client_counter")
            QueueManager.register("get_finish_request_queue")
            QueueManager.register("get_cache_infos")
            QueueManager.register("get_client_read_info_flag")
            QueueManager.register("get_lock_info")
            QueueManager.register("get_disaggregate_requests")
            QueueManager.register("get_available_prefill_instances")
            QueueManager.register("get_finish_request_barrier")
            self.manager = QueueManager(address=self.address,
                                        authkey=self.authkey)
            self._connect_with_retry()

        # Get proxy objects for shared resources (both server and client).
        self.tasks: ListProxy = self.manager.get_tasks(
            self.local_data_parallel_id)
        self.client_read_flag: ListProxy = self.manager.get_client_read_flag(
            self.local_data_parallel_id)
        self.lock: AcquirerProxy = self.manager.get_lock(
            self.local_data_parallel_id)
        self.read_finish_flag: ValueProxy = self.manager.get_read_finish_flag(
            self.local_data_parallel_id)
        self.connected_client_counter: ValueProxy = \
            self.manager.get_connected_client_counter(self.local_data_parallel_id)
        self.cache_infos: ListProxy = self.manager.get_cache_infos(
            self.local_data_parallel_id)
        self.client_read_info_flag: ListProxy = self.manager.get_client_read_info_flag(
            self.local_data_parallel_id)
        self.lock_info: AcquirerProxy = self.manager.get_lock_info(
            self.local_data_parallel_id)
        # Proxies for prefill/decode (P/D) disaggregated serving.
        self.disaggregate_requests = self.manager.get_disaggregate_requests(
            self.local_data_parallel_id)
        self.available_prefill_instances = self.manager.get_available_prefill_instances()
        self.finish_request_barrier = self.manager.get_finish_request_barrier(
            self.local_data_parallel_id
        )
        self.finished_req_queue = self.manager.get_finish_request_queue(
            self.local_data_parallel_id)
        assert self.num_client == len(self.client_read_flag)

        if is_server:
            llm_logger.info("EngineWorkerQueue server started.")
        else:
            # Update client connection counter
            self.lock.acquire()
            self.connected_client_counter.set(
                self.connected_client_counter.get() + 1)
            self.lock.release()
            llm_logger.info((
                f"Connected EngineWorkerQueue client_id: {self.client_id}, number "
                f"of connected clients: {self.connected_client_counter.get()}"
            ))

    def _connect_with_retry(self,
                            max_retries: int = 5,
                            interval: int = 3) -> None:
        """
        Connect to the server with retry mechanism.

        Args:
            max_retries: Maximum connection attempts
            interval: Retry interval in seconds

        Raises:
            ConnectionError: If all connection attempts fail
        """
        for _ in range(max_retries):
            try:
                self.manager.connect()
                return
            except ConnectionRefusedError:
                time.sleep(interval)
        raise ConnectionError(f"TaskQueue cannot connect {self.address}")

    def put_tasks(self, tasks: List[Any]) -> None:
        """
        Add tasks to the shared queue in a thread-safe manner.
        Waits until all clients have read previous tasks before adding new ones.

        Args:
            tasks: Tasks to be added to the queue
        """
        self.lock.acquire()
        # Spin (lock released while sleeping) until every client has read
        # the previous batch.
        while sum(self.client_read_flag) < self.num_client:
            self.lock.release()
            time.sleep(0.001)
            self.lock.acquire()
        self.tasks[:] = list()
        self.client_read_flag[:] = [0] * self.num_client
        # NOTE: the batch is appended as a single element of the list.
        self.tasks.append(tasks)
        self.lock.release()

    def get_tasks(self) -> Tuple[List[Any], bool]:
        """
        Retrieve tasks from the shared queue and update read status.

        Returns:
            tuple: (list of tasks, bool indicating if all clients have read)
        """
        tasks: List[Any] = list()
        self.lock.acquire()
        tasks.extend(self.tasks)
        self.client_read_flag[self.client_id] = 1
        all_client_read: bool = np.sum(
            self.client_read_flag) == self.num_client
        if all_client_read:
            # Last reader clears the shared list for the next batch.
            self.tasks[:] = list()
        self.lock.release()
        return tasks, all_client_read

    def num_tasks(self) -> int:
        """
        Get current number of tasks in the queue.

        Returns:
            int: Total number of tasks
        """
        self.lock.acquire()
        total_num: int = len(self.tasks)
        self.lock.release()
        return total_num

    def get_prefill_instances(self):
        """
        check if the prefill queue is empty

        Returns 0 when no prefill instance is available, otherwise the next
        queued instance.
        NOTE(review): qsize()-then-get() is not atomic; assumes a single
        consumer — confirm against callers.
        """
        if self.available_prefill_instances.qsize() == 0:
            return 0
        else:
            return self.available_prefill_instances.get()

    def put_cache_info(self, cache_info) -> None:
        """
        Add cache info entries to the shared list, waiting until all clients
        have read the previous batch.

        Args:
            cache_info: Iterable of cache-info items to broadcast
        """
        self.lock_info.acquire()
        # Wait until every client has consumed the previous cache infos.
        while sum(self.client_read_info_flag) < self.num_client:
            self.lock_info.release()
            time.sleep(0.001)
            self.lock_info.acquire()
        self.cache_infos[:] = list()
        self.client_read_info_flag[:] = [0] * self.num_client
        self.cache_infos.extend(cache_info)
        llm_logger.debug(
            f"cache_infos: {self.cache_infos} local_data_parallel_id:{self.local_data_parallel_id}"
        )
        self.lock_info.release()

    def get_cache_info(self) -> List[Any]:
        """
        Retrieve cache infos from the shared list and update read status.

        Returns:
            List of cache-info items; empty if this client already read the
            current batch.
        """
        cache_infos: List[Any] = list()
        self.lock_info.acquire()
        # Each client reads a given batch at most once.
        if self.client_read_info_flag[self.client_id] == 1:
            self.lock_info.release()
            return cache_infos
        cache_infos.extend(self.cache_infos)
        self.client_read_info_flag[self.client_id] = 1
        all_client_read: bool = np.sum(
            self.client_read_info_flag) == self.num_client
        if all_client_read:
            # Last reader clears the shared list for the next batch.
            self.cache_infos[:] = list()
        self.lock_info.release()
        if len(cache_infos) != 0:
            llm_logger.debug(
                f"get cache infos: {cache_infos} local_data_parallel_id:{self.local_data_parallel_id}"
            )
        return cache_infos

    def num_cache_infos(self) -> int:
        """
        Get current number of cache infos in the shared list.

        Returns:
            int: Total number of cache-info items
        """
        self.lock_info.acquire()
        total_num: int = len(self.cache_infos)
        self.lock_info.release()
        return total_num

    def put_finished_req(self, req_ids) -> None:
        """
        Put finished request ID into the queue.

        Args:
            req_ids: Request ID to be added to the queue
        """
        self.finished_req_queue.put(req_ids)

    def get_finished_req(self) -> str:
        """
        Get finished request ID from the queue.

        Returns:
            The next queued finished-request item, or [] when empty.
            NOTE(review): annotated -> str but returns a list when empty.
        """
        ans = []
        if self.finished_req_queue.empty():
            return ans
        ans = self.finished_req_queue.get()
        llm_logger.debug(f"get finished req: {ans}")
        return ans

    def disaggregate_queue_empty(self):
        """
        Check if the disaggregated task queue is empty.
        """
        return self.disaggregate_requests.qsize() == 0

    def put_disaggregated_tasks(self, item):
        """
        put disaggregated tasks to the queue
        """
        llm_logger.debug("put item to queue")
        self.disaggregate_requests.put(item)
        llm_logger.debug("put item to queue success")

    def get_disaggregated_tasks(self):
        """
        get disaggregated tasks from the queue

        Returns:
            None if the queue is empty, otherwise a list draining every
            currently queued item.
        """
        llm_logger.debug("get tasks from queue")
        if self.disaggregate_requests.qsize() == 0:
            return None
        item = []
        while not self.disaggregate_requests.empty():
            item.append(self.disaggregate_requests.get())
        llm_logger.debug("get tasks from queue success")
        return item

    def cleanup(self):
        """
        Exit the worker queue gracefully.
        Only the server instance shuts the manager down.
        """
        if self.manager is not None and self.is_server:
            self.manager.shutdown()

View File

@@ -0,0 +1,96 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import numpy as np
from multiprocessing.shared_memory import SharedMemory
def shared_memory_exists(name: str) -> bool:
    """Check if a shared memory block with the given name exists.

    Args:
        name: The unique identifier of the shared memory block.

    Returns:
        True if the shared memory exists, False otherwise.
    """
    exists = False
    try:
        # Attaching (create=False) succeeds only when the block exists;
        # close immediately since we only probe for existence.
        handle = SharedMemory(name=name, create=False)
        handle.close()
        exists = True
    except FileNotFoundError:
        exists = False
    except Exception as e:
        print(f"Unexpected error: {e}")
    return exists
class IPCSignal:
    """A shared memory wrapper for inter-process communication using numpy arrays.

    Allows creating or connecting to existing shared memory blocks and synchronizing
    numpy array data between processes.

    Attributes:
        shm: The underlying SharedMemory object.
        value: Numpy array interface to the shared memory buffer.
    """

    def __init__(self,
                 name: str,
                 array: np.ndarray,
                 dtype: np.dtype,
                 suffix: int = None,
                 create: bool = True) -> None:
        """Initialize or connect to a shared memory block.

        Args:
            name: Unique identifier for the shared memory block.
            array: Numpy array template defining shape and data type.
            dtype: Data type of the array (must match array.dtype).
            suffix: Suffix number that will be appended to the name.
            create: If True, creates new memory block; otherwise connects to existing.

        Raises:
            AssertionError: If create=True but memory already exists, or dtype mismatch.
        """
        assert isinstance(array, np.ndarray), "Input must be a numpy array"
        assert dtype == array.dtype, "Specified dtype must match array dtype"

        # Set a suffix for name to avoid name conflict while there are
        # multiple engines launched.
        if suffix is not None:
            name = name + f".{suffix}"

        if create:
            assert not shared_memory_exists(
                name), f"ShareMemory: {name} already exists"
            self.shm = SharedMemory(create=True, size=array.nbytes, name=name)
            self.value: np.ndarray = np.ndarray(array.shape,
                                                dtype=array.dtype,
                                                buffer=self.shm.buf)
            self.value[:] = array  # Initialize with input array data
        else:
            # Attach to an existing block; the template array must have the
            # same shape/dtype as the one used at creation time.
            self.shm = SharedMemory(name=name)
            self.value: np.ndarray = np.ndarray(array.shape,
                                                dtype=array.dtype,
                                                buffer=self.shm.buf)

    def clear(self) -> None:
        """Release system resources and unlink the shared memory block.

        Fix: always close this process's mapping. The original guarded both
        close() and unlink() behind an existence check, so when the block had
        already been unlinked by another process the local mapping (and its
        file descriptor) leaked. Unlink failures from a concurrent unlink are
        tolerated instead.
        """
        self.shm.close()
        try:
            self.shm.unlink()
        except FileNotFoundError:
            # Already unlinked by another process; nothing left to remove.
            pass

View File

@@ -0,0 +1,196 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
import os
import threading
import time
import zmq
from fastdeploy import envs
from fastdeploy.utils import llm_logger
class ZmqClient:
    """
    ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ.

    Uses IPC endpoints under /dev/shm (Linux-specific paths): one plain
    socket file and one ROUTER socket for request/response routing.
    """

    def __init__(self, name, mode):
        # mode is a zmq socket type constant (e.g. zmq.PUSH / zmq.PULL).
        self.context = zmq.Context()
        self.socket = self.context.socket(mode)
        self.file_name = f"/dev/shm/{name}.socket"
        self.router_path = f"/dev/shm/router_{name}.ipc"

        # Send high-water mark, configurable via environment.
        self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)

        self.mutex = threading.Lock()
        # Maps request_id -> router client identity frame.
        self.req_dict = dict()
        self.router = None
        self.poller = None
        self.running = True

    def connect(self):
        """
        Connect to the server using the file name specified in the constructor.
        """
        self.socket.connect(f"ipc://{self.file_name}")

    def start_server(self):
        """
        Start the server using the file name specified in the constructor.
        """
        self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
        # SNDTIMEO=-1: block indefinitely on send instead of dropping.
        self.socket.setsockopt(zmq.SNDTIMEO, -1)
        self.socket.bind(f"ipc://{self.file_name}")
        self.poller = zmq.Poller()
        self.poller.register(self.socket, zmq.POLLIN)

    def create_router(self):
        """
        Create a ROUTER socket and bind it to the specified router path.
        """
        self.router = self.context.socket(zmq.ROUTER)
        self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
        self.router.setsockopt(zmq.SNDTIMEO, -1)
        self.router.bind(f"ipc://{self.router_path}")

    def send_json(self, data):
        """
        Send a JSON-serializable object over the socket.
        """
        self.socket.send_json(data)

    def recv_json(self):
        """
        Receive a JSON-serializable object from the socket.
        """
        return self.socket.recv_json()

    def send_pyobj(self, data):
        """
        Send a Pickle-serializable object over the socket.
        """
        self.socket.send_pyobj(data)

    def recv_pyobj(self):
        """
        Receive a Pickle-serializable object from the socket.
        """
        return self.socket.recv_pyobj()

    def send_multipart(self, req_id, data):
        """
        Send a multipart message to the router socket.

        Blocks (polling in 1 ms steps) until the client identity for req_id
        has been registered via an incoming router message, then sends
        data.to_dict() as JSON to that client. When data.finished is set, the
        req_id -> identity mapping is dropped.

        Raises:
            RuntimeError: If create_router() has not been called yet.
        """
        if self.router is None:
            raise RuntimeError(
                "Router socket not created. Call create_router() first.")

        # Wait until the requesting client has announced itself; exits early
        # if close() flips self.running.
        while self.running:
            with self.mutex:
                if req_id not in self.req_dict:
                    try:
                        client, _, request_id = self.router.recv_multipart(
                            flags=zmq.NOBLOCK)
                        req_id_str = request_id.decode('utf-8')
                        self.req_dict[req_id_str] = client
                    except zmq.Again:
                        # No registration yet; back off briefly.
                        time.sleep(0.001)
                        continue
                else:
                    break

        try:
            result = json.dumps(data.to_dict()).encode('utf-8')
            self.router.send_multipart([self.req_dict[req_id], b'', result])
        except Exception as e:
            llm_logger.error(f"Send result to zmq client failed: {e}")

        if data.finished:
            with self.mutex:
                self.req_dict.pop(data.request_id, None)

    def receive_json_once(self, block=False):
        """
        Receive a single message from the socket.

        Returns:
            (error message or None, received object or None).
        """
        # NOTE(review): "zmp" is a typo for "zmq" in this runtime string;
        # kept as-is to preserve behavior.
        if self.socket is None or self.socket.closed:
            return "zmp socket has closed", None
        try:
            flags = zmq.NOBLOCK if not block else 0
            return None, self.socket.recv_json(flags=flags)
        except zmq.Again:
            # Nothing available in non-blocking mode.
            return None, None
        except Exception as e:
            self.close()
            llm_logger.warning(f"{e}")
            return str(e), None

    def receive_pyobj_once(self, block=False):
        """
        Receive a single message from the socket.

        Returns:
            (error message or None, received object or None).
        """
        if self.socket is None or self.socket.closed:
            return "zmp socket has closed", None
        try:
            flags = zmq.NOBLOCK if not block else 0
            return None, self.socket.recv_pyobj(flags=flags)
        except zmq.Again:
            return None, None
        except Exception as e:
            self.close()
            llm_logger.warning(f"{e}")
            return str(e), None

    def _clear_ipc(self, name):
        """
        Remove the IPC file with the given name.
        """
        if os.path.exists(name):
            try:
                os.remove(name)
            except OSError as e:
                llm_logger.warning(f"Failed to remove IPC file {name} - {e}")

    def close(self):
        """
        Close the socket and context, and remove the IPC files.
        Idempotent: subsequent calls are no-ops via the running flag.
        """
        if not self.running:
            return
        self.running = False

        llm_logger.info("Closing ZMQ connection...")
        try:
            if hasattr(self, 'socket') and not self.socket.closed:
                self.socket.close()

            if self.router is not None and not self.router.closed:
                self.router.close()

            if not self.context.closed:
                self.context.term()

            self._clear_ipc(self.file_name)
            self._clear_ipc(self.router_path)
        except Exception as e:
            llm_logger.warning(f"Failed to close ZMQ connection - {e}")
            return

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()