[Feature] Support mixed deployment with yiyan adapter in develop (#3976)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* [Feature] Support mixed deployment with yiyan adapter in release2.2

* fix metrics

* add unit test

* add unit test

* add unit test

* fix ci

* fix for eb5

* fix ci

* fix ci

* fix ci

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
chenjian
2025-09-18 01:52:20 +08:00
committed by GitHub
parent 2745f37017
commit 618ccdbfba
14 changed files with 934 additions and 176 deletions

View File

@@ -82,6 +82,9 @@ jobs:
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
FD_ZMQ_RECV_REQUEST_SERVER_PORT=$((42048 + DEVICE_PORT * 100))
FD_ZMQ_SEND_RESPONSE_SERVER_PORT=$((42038 + DEVICE_PORT * 100))
FD_ZMQ_CONTROL_CMD_SERVER_PORTS=$((42028 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"

View File

@@ -37,12 +37,14 @@ from fastdeploy.inter_communicator import (
EngineCacheQueue,
EngineWorkerQueue,
IPCSignal,
ZmqClient,
ZmqIpcServer,
ZmqTcpServer,
)
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.metrics.trace_util import start_span, start_span_request
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.plugins.token_processor import load_token_processor_plugins
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, envs, llm_logger
@@ -576,9 +578,19 @@ class EngineService:
if api_server_pid is None:
return
self.api_server_pid = api_server_pid
self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL)
self.zmq_server.start_server()
self.zmq_server.create_router()
if envs.FD_ENABLE_INTERNAL_ADAPTER:
self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
self.internal_adapter = InternalAdapter(
cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node
)
else:
self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
self.recv_result_handle_thread = threading.Thread(
target=self.send_response_server.recv_result_handle, daemon=True
)
self.recv_result_handle_thread.start()
time.sleep(3)
self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True)
self.insert_task_to_scheduler_thread.start()
@@ -592,9 +604,9 @@ class EngineService:
try:
block = True if len(added_requests) == 0 else False
if not self.cfg.model_config.enable_mm:
err, data = self.zmq_server.receive_json_once(block)
err, data = self.recv_request_server.receive_json_once(block)
else:
err, data = self.zmq_server.receive_pyobj_once(block)
err, data = self.recv_request_server.receive_pyobj_once(block)
if err is not None:
llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
break
@@ -648,7 +660,7 @@ class EngineService:
)
# Since the request is not in scheduler
# Send result by zmq directly
self.zmq_server.send_multipart(request_id, [error_result])
self.send_response_server.send_response(request_id, [error_result])
except Exception as e:
llm_logger.error(
f"Error happened while receiving new request from zmq, details={e}, "
@@ -666,7 +678,7 @@ class EngineService:
time.sleep(0.005)
continue
for request_id, contents in results.items():
self.zmq_server.send_multipart(request_id, contents)
self.send_response_server.send_response(request_id, contents)
except Exception as e:
llm_logger.error(f"Unexcepted error happened: {e}, {traceback.format_exc()!s}")
@@ -766,5 +778,9 @@ class EngineService:
self.worker_healthy_live_signal.clear()
self.exist_prefill_task_signal.clear()
self.model_weights_status_signal.clear()
if hasattr(self, "zmq_server") and self.zmq_server is not None:
self.zmq_server.close()
if hasattr(self, "send_response_server") and self.send_response_server is not None:
self.send_response_server.close()
if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
self.recv_request_server.close()
if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
self.recv_control_cmd_server.close()

View File

@@ -27,7 +27,7 @@ from fastdeploy.config import ModelConfig
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.multimodal.registry import MultimodalRegistry
from fastdeploy.platforms import current_platform
@@ -115,7 +115,7 @@ class EngineClient:
"""
Create a ZMQ client.
"""
self.zmq_client = ZmqClient(model, mode)
self.zmq_client = ZmqIpcClient(model, mode)
self.zmq_client.connect()
async def format_and_add_data(self, prompts: dict):

View File

@@ -98,6 +98,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Whether to use new get_output and save_output method (0 or 1)
"FD_USE_GET_SAVE_OUTPUT_V1": lambda: bool(int(os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0"))),
# Whether to enable model cache feature
"FD_ENABLE_MODEL_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_CACHE", "0"))),
# enable internal module to access LLMEngine.
"FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")),
# LLMEngine recieve requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"),
# LLMEngine send response port, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
# LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
"FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))),
}

View File

@@ -17,6 +17,15 @@
from .engine_cache_queue import EngineCacheQueue
from .engine_worker_queue import EngineWorkerQueue
from .ipc_signal import IPCSignal, shared_memory_exists
from .zmq_client import ZmqClient
from .zmq_client import ZmqIpcClient
from .zmq_server import ZmqIpcServer, ZmqTcpServer
__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "shared_memory_exists"]
__all__ = [
"ZmqIpcClient",
"IPCSignal",
"EngineWorkerQueue",
"EngineCacheQueue",
"ZmqTcpServer",
"ZmqIpcServer",
"shared_memory_exists",
]

View File

@@ -14,209 +14,100 @@
# limitations under the License.
"""
import os
import threading
import time
import traceback
from abc import ABC, abstractmethod
import msgpack
import zmq
from fastdeploy import envs
from fastdeploy.utils import zmq_client_logger
from fastdeploy.utils import llm_logger
class ZmqClient:
class ZmqClientBase(ABC):
"""
ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ.
ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ.
"""
def __init__(self, name, mode):
self.context = zmq.Context(4)
self.socket = self.context.socket(mode)
self.file_name = f"/dev/shm/{name}.socket"
self.router_path = f"/dev/shm/router_{name}.ipc"
def __init__(self):
pass
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
@abstractmethod
def _create_socket(self):
"""Abstract method to create and return a ZeroMQ socket."""
pass
self.mutex = threading.Lock()
self.req_dict = dict()
self.router = None
self.poller = None
self.running = True
def _ensure_socket(self):
"""Ensure the socket is created before use."""
if self.socket is None:
self.socket = self._create_socket()
@abstractmethod
def connect(self):
"""
Connect to the server using the file name specified in the constructor.
"""
self.socket.connect(f"ipc://{self.file_name}")
def start_server(self):
"""
Start the server using the file name specified in the constructor.
"""
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
self.socket.setsockopt(zmq.SNDTIMEO, -1)
self.socket.bind(f"ipc://{self.file_name}")
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
def create_router(self):
"""
Create a ROUTER socket and bind it to the specified router path.
"""
self.router = self.context.socket(zmq.ROUTER)
self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router.setsockopt(zmq.SNDTIMEO, -1)
self.router.bind(f"ipc://{self.router_path}")
zmq_client_logger.info(f"router path: {self.router_path}")
pass
def send_json(self, data):
"""
Send a JSON-serializable object over the socket.
"""
self._ensure_socket()
self.socket.send_json(data)
def recv_json(self):
"""
Receive a JSON-serializable object from the socket.
"""
self._ensure_socket()
return self.socket.recv_json()
def send_pyobj(self, data):
"""
Send a Pickle-serializable object over the socket.
"""
self._ensure_socket()
self.socket.send_pyobj(data)
def recv_pyobj(self):
"""
Receive a Pickle-serializable object from the socket.
"""
self._ensure_socket()
return self.socket.recv_pyobj()
def pack_aggregated_data(self, data):
"""
Aggregate multiple responses into one and send them to the client.
"""
result = data[0]
if len(data) > 1:
for response in data[1:]:
result.add(response)
result = msgpack.packb([result.to_dict()])
return result
@abstractmethod
def close(self):
pass
def send_multipart(self, req_id, data):
"""
Send a multipart message to the router socket.
"""
if self.router is None:
raise RuntimeError("Router socket not created. Call create_router() first.")
while self.running:
with self.mutex:
if req_id not in self.req_dict:
try:
client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK)
req_id_str = request_id.decode("utf-8")
self.req_dict[req_id_str] = client
except zmq.Again:
time.sleep(0.001)
continue
else:
break
if self.req_dict[req_id] == -1:
if data[-1].finished:
with self.mutex:
self.req_dict.pop(req_id, None)
return
try:
start_send = time.time()
if self.aggregate_send:
result = self.pack_aggregated_data(data)
else:
result = msgpack.packb([response.to_dict() for response in data])
self.router.send_multipart([self.req_dict[req_id], b"", result])
zmq_client_logger.info(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
except zmq.ZMQError as e:
zmq_client_logger.error(f"[{req_id}] zmq error: {e}")
self.req_dict[req_id] = -1
except Exception as e:
zmq_client_logger.error(f"Send result to zmq client failed: {e}, {str(traceback.format_exc())}")
class ZmqIpcClient(ZmqClientBase):
def __init__(self, name, mode):
self.name = name
self.mode = mode
self.file_name = f"/dev/shm/{name}.socket"
self.context = zmq.Context()
self.socket = self.context.socket(self.mode)
if data[-1].finished:
with self.mutex:
self.req_dict.pop(req_id, None)
zmq_client_logger.info(f"send_multipart finished, req_id: {req_id}")
def _create_socket(self):
"""create and return a ZeroMQ socket."""
self.context = zmq.Context()
return self.context.socket(self.mode)
def receive_json_once(self, block=False):
"""
Receive a single message from the socket.
"""
if self.socket is None or self.socket.closed:
return "zmp socket has closed", None
try:
flags = zmq.NOBLOCK if not block else 0
return None, self.socket.recv_json(flags=flags)
except zmq.Again:
return None, None
except Exception as e:
self.close()
zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}")
return str(e), None
def receive_pyobj_once(self, block=False):
"""
Receive a single message from the socket.
"""
if self.socket is None or self.socket.closed:
return "zmp socket has closed", None
try:
flags = zmq.NOBLOCK if not block else 0
return None, self.socket.recv_pyobj(flags=flags)
except zmq.Again:
return None, None
except Exception as e:
self.close()
zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}")
return str(e), None
def _clear_ipc(self, name):
"""
Remove the IPC file with the given name.
"""
if os.path.exists(name):
try:
os.remove(name)
except OSError as e:
zmq_client_logger.warning(f"Failed to remove IPC file {name} - {e}")
def connect(self):
self._ensure_socket()
self.socket.connect(f"ipc://{self.file_name}")
def close(self):
"""
Close the socket and context, and remove the IPC files.
Close the socket and context.
"""
if not self.running:
return
self.running = False
zmq_client_logger.info("Closing ZMQ connection...")
llm_logger.info("ZMQ client is closing connection...")
try:
if hasattr(self, "socket") and not self.socket.closed:
if self.socket is not None and not self.socket.closed:
self.socket.setsockopt(zmq.LINGER, 0)
self.socket.close()
if self.router is not None and not self.router.closed:
self.router.close()
if not self.context.closed:
if self.context is not None:
self.context.term()
self._clear_ipc(self.file_name)
self._clear_ipc(self.router_path)
except Exception as e:
zmq_client_logger.warning(f"Failed to close ZMQ connection - {e}, {str(traceback.format_exc())}")
llm_logger.warning(f"ZMQ client failed to close connection - {e}")
return
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

View File

@@ -0,0 +1,335 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import threading
import time
from abc import ABC, abstractmethod
from collections import defaultdict
import msgpack
import zmq
from fastdeploy import envs
from fastdeploy.utils import llm_logger
class ZmqServerBase(ABC):
"""
ZmqServerBase
"""
def __init__(self):
self.cached_results = defaultdict(list)
self.response_token_lock = threading.Lock()
@abstractmethod
def _create_socket(self):
"""Abstract method to create and return a ZeroMQ socket."""
pass
def _ensure_socket(self):
"""Ensure the socket is created before use."""
if self.socket is None:
self.socket = self._create_socket()
def send_json(self, data):
"""
Send a JSON-serializable object over the socket.
"""
self._ensure_socket()
self.socket.send_json(data)
def recv_json(self):
"""
Receive a JSON-serializable object from the socket.
"""
self._ensure_socket()
return self.socket.recv_json()
def send_pyobj(self, data):
"""
Send a Pickle-serializable object over the socket.
"""
self._ensure_socket()
self.socket.send_pyobj(data)
def recv_pyobj(self):
"""
Receive a Pickle-serializable object from the socket.
"""
self._ensure_socket()
return self.socket.recv_pyobj()
def pack_aggregated_data(self, data):
"""
Aggregate multiple responses into one and send them to the client.
"""
result = data[0]
if len(data) > 1:
for response in data[1:]:
result.add(response)
result = msgpack.packb([result.to_dict()])
return result
def receive_json_once(self, block=False):
"""
Receive a single message from the socket.
"""
self._ensure_socket()
if self.socket is None or self.socket.closed:
return "zmp socket has closed", None
try:
flags = zmq.NOBLOCK if not block else 0
return None, self.socket.recv_json(flags=flags)
except zmq.Again:
return None, None
except Exception as e:
self.close()
llm_logger.warning(f"{e}")
return str(e), None
def receive_pyobj_once(self, block=False):
"""
Receive a single message from the socket.
"""
self._ensure_socket()
if self.socket is None or self.socket.closed:
return "zmp socket has closed", None
try:
flags = zmq.NOBLOCK if not block else 0
return None, self.socket.recv_pyobj(flags=flags)
except zmq.Again:
return None, None
except Exception as e:
self.close()
llm_logger.warning(f"{e}")
return str(e), None
def recv_result_handle(self):
while True:
try:
with self.response_token_lock:
client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK)
req_id_str = request_id.decode("utf-8")
need_send_after_finished_inference = False
with self.mutex:
self.req_dict[req_id_str] = client
if req_id_str in self.cached_results:
if self.cached_results[req_id_str][-1][-1].finished:
need_send_after_finished_inference = True
if need_send_after_finished_inference:
self.send_response(req_id_str, [])
llm_logger.info(f"send_multipart finished, req_id: {req_id_str}")
self.req_dict.pop(req_id_str, None)
except zmq.Again:
time.sleep(0.001)
continue
except Exception as e:
llm_logger.error(f"recv_result_handle get unknown exception: {e}")
continue
def send_response(self, req_id, data):
"""
Send generated token result to client.
"""
self._ensure_socket()
if self.socket is None:
raise RuntimeError("Router socket not created. Call create_router() first.")
new_data = []
has_result_handle = False
with self.mutex:
if req_id not in self.req_dict:
self.cached_results[req_id].append(data)
else:
has_result_handle = True
if req_id in self.cached_results:
for history_data in self.cached_results[req_id]:
new_data.extend(history_data)
llm_logger.info(
f"get request {req_id} result handle after cached result, total cached length {len(self.cached_results[req_id])}"
)
del self.cached_results[req_id]
if has_result_handle:
try:
new_data.extend(data)
start_send = time.time()
if self.aggregate_send:
result = self.pack_aggregated_data(new_data)
else:
result = msgpack.packb([response.to_dict() for response in new_data])
with self.response_token_lock:
self.socket.send_multipart([self.req_dict[req_id], b"", result])
llm_logger.debug(
f"send_multipart result: {req_id} len {len(new_data)} elapse: {time.time()-start_send}"
)
except Exception as e:
llm_logger.error(f"Send result to zmq client failed: {e}")
if data and data[-1].finished:
with self.mutex:
if req_id in self.req_dict:
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
self.req_dict.pop(req_id, None)
@abstractmethod
def close(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
class ZmqIpcServer(ZmqServerBase):
"""
ZmqIpcServer, used when FD_ENABLE_INTERNAL_ADAPTER=0
"""
def __init__(self, name, mode):
self.name = name
self.mode = mode
self.cached_results = defaultdict(list)
if mode == zmq.PULL:
self.file_name = f"/dev/shm/{name}.socket"
elif mode == zmq.ROUTER:
self.file_name = f"/dev/shm/router_{name}.ipc"
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
self.mutex = threading.Lock()
self.response_token_lock = threading.Lock()
self.req_dict = dict()
self.running = True
self.context = zmq.Context()
self._create_socket()
def _create_socket(self):
"""create and return a ZeroMQ socket."""
self.socket = self.context.socket(self.mode)
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
self.socket.setsockopt(zmq.SNDTIMEO, -1)
self.socket.bind(f"ipc://{self.file_name}")
return self.socket
def _clear_ipc(self, name):
"""
Remove the IPC file with the given name.
"""
if os.path.exists(name):
try:
os.remove(name)
except OSError as e:
llm_logger.warning(f"Failed to remove IPC file {name} - {e}")
def close(self):
"""
Close the socket and context, and remove the IPC files.
"""
if not self.running:
return
self.running = False
llm_logger.info("ZMQ server is closing connection...")
try:
if self.socket is not None and not self.socket.closed:
self.socket.close()
if not self.context.closed:
self.context.term()
self._clear_ipc(self.file_name)
except Exception as e:
llm_logger.warning(f"ZMQ server failed to close connection - {e}")
return
class ZmqTcpServer(ZmqServerBase):
"""
ZmqTcpServer, used when FD_ENABLE_INTERNAL_ADAPTER=1
"""
def __init__(self, port, mode):
self.mode = mode
self.port = port
self.cached_results = defaultdict(list)
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
self.mutex = threading.Lock()
self.req_dict = dict()
self.running = True
self.context = zmq.Context()
self._create_socket()
self.response_token_lock = threading.Lock()
def _create_socket(self):
"""create and return a ZeroMQ socket."""
self.socket = self.context.socket(self.mode)
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
self.socket.setsockopt(zmq.SNDTIMEO, -1)
self.socket.bind(f"tcp://*:{self.port}")
return self.socket
def recv_control_cmd(self):
"""
Recieve control command from client
"""
self._ensure_socket()
try:
client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK)
task = msgpack.unpackb(task_data)
task_id_str = task["task_id"]
except zmq.Again:
return None
with self.mutex:
self.req_dict[task_id_str] = client
return task
def response_for_control_cmd(self, task_id, result):
"""
Send command result back to client.
"""
self._ensure_socket()
if self.socket is None:
raise RuntimeError("Router socket not created.")
try:
result = msgpack.packb(result)
self.socket.send_multipart([self.req_dict[task_id], b"", result])
except Exception as e:
llm_logger.error(f"Send result to zmq client failed: {e}")
with self.mutex:
self.req_dict.pop(task_id, None)
llm_logger.debug(f"response control cmd finished, task_id: {task_id}")
def close(self):
"""
Close the socket and context.
"""
if not self.running:
return
self.running = False
llm_logger.info("ZMQ server is closing connection...")
try:
if self.socket is not None and not self.socket.closed:
self.socket.close()
if not self.context.closed:
self.context.term()
except Exception as e:
llm_logger.warning(f"ZMQ server failed to close connection - {e}")
return

View File

@@ -78,7 +78,8 @@ else:
update_inputs_v1,
)
from fastdeploy.inter_communicator import ZmqClient
from fastdeploy.inter_communicator import ZmqIpcClient
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
@@ -160,7 +161,7 @@ def pre_process(
)
def _zmq_send_text_outputs(zmq_client: ZmqClient, output_tokens: np.ndarray, save_each_rank: bool, mp_rank: int):
def _zmq_send_text_outputs(zmq_client: ZmqIpcClient, output_tokens: np.ndarray, save_each_rank: bool, mp_rank: int):
"""Split output_tokens and output"""
assert zmq_client is not None, "zmq_client should not be None"
output_tokens = output_tokens.reshape([-1]).numpy()
@@ -187,7 +188,7 @@ def post_process_normal(
block_size: int = 64,
save_each_rank: bool = False,
skip_save_output: bool = False,
zmq_client: ZmqClient = None,
zmq_client: ZmqIpcClient = None,
) -> ModelRunnerOutput:
"""Post-processing steps after completing a single token generation."""
# handle vl:
@@ -389,7 +390,7 @@ def post_process(
save_each_rank: bool = False,
speculative_decoding: bool = False,
skip_save_output: bool = False,
zmq_client: ZmqClient = None,
zmq_client: ZmqIpcClient = None,
) -> None:
"""Post-processing steps after completing a single token generation."""
if speculative_decoding:

View File

@@ -29,7 +29,7 @@ import zmq
from fastdeploy import envs
from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcServer
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.platforms import current_platform
from fastdeploy.utils import llm_logger, spec_logger
@@ -58,12 +58,11 @@ class TokenProcessor:
self.split_connector = split_connector
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
llm_logger.debug(f"create zmq get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}")
self.zmq_server = ZmqClient(
self.zmq_server = ZmqIpcServer(
name=f"get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}", mode=zmq.PULL
)
self.zmq_server.start_server()
self.zmq_server.create_router()
self.speculative_decoding = self.cfg.speculative_config.method is not None
self.use_logprobs = self.cfg.model_config.enable_logprob
@@ -498,6 +497,7 @@ class TokenProcessor:
metrics = RequestMetrics(
arrival_time=task.arrival_time,
inference_start_time=task.inference_start_time,
model_execute_time=time.time() - task.inference_start_time,
first_token_time=time.time() - task.inference_start_time,
time_in_queue=task.schedule_start_time - task.preprocess_end_time,
preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
@@ -510,6 +510,7 @@ class TokenProcessor:
metrics = RequestMetrics(
arrival_time=time.time(),
request_start_time=task.arrival_time,
model_execute_time=time.time() - task.inference_start_time,
)
self.number_of_output_tokens += len(token_ids)
self._record_metrics(task, current_time, token_ids)

View File

@@ -208,6 +208,9 @@ class LocalScheduler:
"""
return (token_num + block_size - 1) // block_size
def get_unhandled_request_num(self):
return len(self.ids) - self.ids_read_cursor
def get_requests(
self,
available_blocks,

View File

@@ -0,0 +1,118 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import threading
import time
import traceback
# **Note**: Just for internal use
import zmq
from fastdeploy.inter_communicator import ZmqTcpServer
from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics
from fastdeploy.utils import envs, get_logger
logger = get_logger("internal_adapter_utils", "internal_adapter_utils.log")
class InternalAdapter:
def __init__(self, cfg, engine, dp_rank):
self.cfg = cfg
self.engine = engine
self.dp_rank = dp_rank
recv_control_cmd_ports = envs.FD_ZMQ_CONTROL_CMD_SERVER_PORTS.split(",")
self.response_lock = threading.Lock() # prevent to call send_multipart in zmq concurrently
self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[dp_rank], mode=zmq.ROUTER)
self.recv_external_instruct_thread = threading.Thread(
target=self._recv_external_module_control_instruct, daemon=True
)
self.recv_external_instruct_thread.start()
if cfg.splitwise_role != "mixed":
self.response_external_instruct_thread = threading.Thread(
target=self._response_external_module_control_instruct, daemon=True
)
self.response_external_instruct_thread.start()
def _get_current_server_info(self):
"""
Get resources information
"""
available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch())
available_block_num = self.engine.resource_manager.available_block_num()
server_info = {
"splitwise_role": self.cfg.splitwise_role,
"block_size": int(self.cfg.cache_config.block_size),
"block_num": int(available_block_num),
"max_block_num": int(self.cfg.cache_config.total_block_num),
"dec_token_num": int(self.cfg.cache_config.dec_token_num),
"available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num),
"max_batch_size": int(available_batch_size),
"max_input_token_num": self.cfg.max_model_len,
"unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(),
"available_batch": int(self.engine.resource_manager.available_batch()),
}
return server_info
def _recv_external_module_control_instruct(self):
"""
Receive a multipart message from the control cmd socket.
"""
while True:
try:
with self.response_lock:
task = self.recv_control_cmd_server.recv_control_cmd()
if task is None:
time.sleep(0.001)
continue
logger.info(f"Recieve control task: {task}")
task_id_str = task["task_id"]
if task["cmd"] == "get_payload":
payload_info = self._get_current_server_info()
result = {"task_id": task_id_str, "result": payload_info}
logger.debug(f"Response for task: {task_id_str}")
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
elif task["cmd"] == "get_metrics":
metrics_text = get_filtered_metrics(
[],
extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1),
)
result = {"task_id": task_id_str, "result": metrics_text}
logger.debug(f"Response for task: {task_id_str}")
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
elif task["cmd"] == "connect_rdma":
self.engine.engine_worker_queue.put_connect_rdma_task(task)
except Exception as e:
logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}")
def _response_external_module_control_instruct(self):
while True:
try:
result_data = self.engine.engine_worker_queue.get_connect_rdma_task_response()
if result_data:
task_id_str = result_data["task_id"]
result = {"task_id": task_id_str, "result": result_data}
logger.info(f"Response for task: {task_id_str}")
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
else:
time.sleep(0.001)
except Exception as e:
logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}")

View File

@@ -75,7 +75,7 @@ import zmq
from fastdeploy import envs
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import ZmqClient
from fastdeploy.inter_communicator import ZmqIpcClient
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.worker.model_runner_base import ModelRunnerBase
@@ -171,7 +171,7 @@ class GPUModelRunner(ModelRunnerBase):
self.zmq_client = None
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
logger.info(f"zmq client get_save_output_rank{local_rank}")
self.zmq_client = ZmqClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
self.zmq_client = ZmqIpcClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
self.zmq_client.connect()
self.zmq_client.socket.SNDTIMEO = 3000

View File

@@ -0,0 +1,251 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import queue
import shutil
import signal
import socket
import subprocess
import sys
import time
import pytest
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, "..", ".."))
print("project_root", project_root)
if project_root not in sys.path:
sys.path.insert(0, project_root)
from ci_use.EB_Lite_with_adapter.zmq_client import LLMControlClient, LLMReqClient
env = os.environ.copy()
# Read ports from environment variables; use default values if not set
FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8234))
FD_ENABLE_INTERNAL_ADAPTER = int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "1"))
FD_ZMQ_RECV_REQUEST_SERVER_PORT = int(os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8204"))
FD_ZMQ_SEND_RESPONSE_SERVER_PORT = int(os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8205"))
FD_ZMQ_CONTROL_CMD_SERVER_PORTS = int(os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8206"))
FD_ZMQ_CONTROL_CMD_SERVER_PORT = FD_ZMQ_CONTROL_CMD_SERVER_PORTS
env["FD_ENABLE_INTERNAL_ADAPTER"] = str(FD_ENABLE_INTERNAL_ADAPTER)
env["FD_ZMQ_RECV_REQUEST_SERVER_PORT"] = str(FD_ZMQ_RECV_REQUEST_SERVER_PORT)
env["FD_ZMQ_SEND_RESPONSE_SERVER_PORT"] = str(FD_ZMQ_SEND_RESPONSE_SERVER_PORT)
env["FD_ZMQ_CONTROL_CMD_SERVER_PORTS"] = str(FD_ZMQ_CONTROL_CMD_SERVER_PORTS)
env["FD_ZMQ_CONTROL_CMD_SERVER_PORT"] = str(FD_ZMQ_CONTROL_CMD_SERVER_PORT)
# List of ports to clean before and after tests
PORTS_TO_CLEAN = [
FD_API_PORT,
FD_ENGINE_QUEUE_PORT,
FD_METRICS_PORT,
FD_CACHE_QUEUE_PORT,
FD_ZMQ_RECV_REQUEST_SERVER_PORT,
FD_ZMQ_SEND_RESPONSE_SERVER_PORT,
FD_ZMQ_CONTROL_CMD_SERVER_PORT,
]
@pytest.fixture
def zmq_req_client():
client = LLMReqClient("0.0.0.0", FD_ZMQ_RECV_REQUEST_SERVER_PORT, FD_ZMQ_SEND_RESPONSE_SERVER_PORT)
return client
@pytest.fixture
def zmq_control_client():
client = LLMControlClient("0.0.0.0", FD_ZMQ_CONTROL_CMD_SERVER_PORT)
return client
def is_port_open(host: str, port: int, timeout=1.0):
"""
Check if a TCP port is open on the given host.
Returns True if connection succeeds, False otherwise.
"""
try:
with socket.create_connection((host, port), timeout):
return True
except Exception:
return False
def kill_process_on_port(port: int):
"""
Kill processes that are listening on the given port.
Uses `lsof` to find process ids and sends SIGKILL.
"""
try:
output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip()
for pid in output.splitlines():
os.kill(int(pid), signal.SIGKILL)
print(f"Killed process on port {port}, pid={pid}")
except subprocess.CalledProcessError:
pass
try:
result = subprocess.run(
f"ps -ef -ww| grep {FD_CACHE_QUEUE_PORT} | grep -v grep", shell=True, capture_output=True, text=True
)
for line in result.stdout.strip().split("\n"):
if not line:
continue
parts = line.split()
pid = int(parts[1]) # ps -ef 的第二列是 PID
print(f"Killing PID: {pid}")
os.kill(pid, signal.SIGKILL)
except Exception as e:
print(f"Failed to kill cache manager process: {e}")
def clean_ports():
"""
Kill all processes occupying the ports listed in PORTS_TO_CLEAN.
"""
for port in PORTS_TO_CLEAN:
kill_process_on_port(port)
time.sleep(2)
@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server():
"""
Pytest fixture that runs once per test session:
- Cleans ports before tests
- Starts the API server as a subprocess
- Waits for server port to open (up to 30 seconds)
- Tears down server after all tests finish
"""
print("Pre-test port cleanup...")
clean_ports()
base_path = os.getenv("MODEL_PATH")
if base_path:
model_path = os.path.join(base_path, "ernie-4_5-21b-a3b-bf16-paddle")
else:
model_path = "./ernie-4_5-21b-a3b-bf16-paddle"
log_path = "server.log"
cmd = [
sys.executable,
"-m",
"fastdeploy.entrypoints.openai.api_server",
"--model",
model_path,
"--port",
str(FD_API_PORT),
"--tensor-parallel-size",
"1",
"--engine-worker-queue-port",
str(FD_ENGINE_QUEUE_PORT),
"--metrics-port",
str(FD_METRICS_PORT),
"--cache-queue-port",
str(FD_CACHE_QUEUE_PORT),
"--max-model-len",
"32768",
"--max-num-seqs",
"128",
"--quantization",
"wint4",
]
# Start subprocess in new process group
# 清除log目录
if os.path.exists("log"):
shutil.rmtree("log")
with open(log_path, "w") as logfile:
process = subprocess.Popen(
cmd,
env=env,
stdout=logfile,
stderr=subprocess.STDOUT,
start_new_session=True, # Enables killing full group via os.killpg
)
# Wait up to 300 seconds for API server to be ready
for _ in range(300):
if is_port_open("127.0.0.1", FD_API_PORT):
print(f"API server is up on port {FD_API_PORT}")
break
time.sleep(1)
else:
print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...")
try:
os.killpg(process.pid, signal.SIGTERM)
except Exception as e:
print(f"Failed to kill process group: {e}")
raise RuntimeError(f"API server did not start on port {FD_API_PORT}")
yield # Run tests
print("\n===== Post-test server cleanup... =====")
try:
os.killpg(process.pid, signal.SIGTERM)
clean_ports()
print(f"API server (pid={process.pid}) terminated")
except Exception as e:
print(f"Failed to terminate API server: {e}")
def test_request_and_response(zmq_req_client):
prompt_token_ids = [5300, 93956, 55791]
req_id = "test"
request = {
"req_id": req_id,
"request_id": req_id,
"min_tokens": 1,
"dp_rank": 0, # P实例 DP rank, 从当前环境变量里读取
"prompt_token_ids": prompt_token_ids,
"prompt_token_ids_len": len(prompt_token_ids),
"eos_token_ids": [2],
"stop_token_ids": [2],
"max_dec_len": 32 * 1024,
"max_tokens": 32 * 1024,
"min_dec_len": 1,
"arrival_time": time.time(),
"preprocess_start_time": time.time(),
"preprocess_end_time": time.time(),
"messages": [],
"temperature": 0.8,
"penalty_score": 1.0,
"repetition_penalty": 1.0,
"presence_penalty": 0,
"top_p": 0.8,
"frequency_penalty": 0.0,
}
result_queue = queue.Queue()
zmq_req_client.start(result_queue)
zmq_req_client.send_request(request)
zmq_req_client.request_result(req_id)
has_is_end_result = False
while True:
result = result_queue.get()
if result[-1]["finished"]:
has_is_end_result = True
break
assert has_is_end_result is True
def test_control_cmd(zmq_control_client):
result = zmq_control_client.get_payload()
assert "unhandled_request_num" in result
result = zmq_control_client.get_metrics()
assert result is not None

View File

@@ -0,0 +1,121 @@
import threading
import time
import uuid
from threading import Event
import msgpack
import zmq
class LLMReqClient:
"""
LLM request client
"""
def __init__(self, ip, send_req_server_port, recv_res_server_port):
self.ZMQ_SNDHWM = 64 * 1024
self.context = zmq.Context()
self.send_req_client = self.context.socket(zmq.PUSH)
self.recv_res_client = self.context.socket(zmq.DEALER)
self.send_req_client.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
self.send_req_client.setsockopt(zmq.SNDTIMEO, -1)
self.recv_res_client.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
self.recv_res_client.setsockopt(zmq.SNDTIMEO, -1)
self.send_req_client.connect(f"tcp://{ip}:{send_req_server_port}")
self.recv_res_client.connect(f"tcp://{ip}:{recv_res_server_port}")
self.need_exit = False
self.response_socket_lock = threading.Lock()
def send_request(self, req_data):
self.send_req_client.send_json(req_data)
def request_result(self, req_id):
with self.response_socket_lock:
print(f"request result data for {req_id}")
self.recv_res_client.send_multipart([b"", req_id.encode("utf-8")])
def consume_results(self, result_queue):
while True:
try:
try:
with self.response_socket_lock:
frames = self.recv_res_client.recv_multipart(flags=zmq.NOBLOCK)
except zmq.Again:
time.sleep(0.001)
continue
data = frames[-1]
response = msgpack.unpackb(data)
print(f"get result data {response}")
result_queue.put(response)
if self.need_exit:
break
except Exception as e:
print(f"zmq client occured error {e} type: {type(e)} frames: {frames}")
def start(self, result_queue):
threading.Thread(target=self.consume_results, args=(result_queue,), daemon=True).start()
def exit(self):
print("exit")
self.need_exit = True
class LLMControlClient:
"""
LLM control client
"""
def __init__(self, ip, port):
self.ZMQ_SNDHWM = 64 * 1024
self.context = zmq.Context()
self.control_client = self.context.socket(zmq.DEALER)
self.control_client.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
self.control_client.setsockopt(zmq.SNDTIMEO, -1)
self.control_client.connect(f"tcp://{ip}:{port}")
self.task_event = {}
self.result = {}
self.response_socket_lock = threading.Lock()
threading.Thread(target=self.recv_results, daemon=True).start()
def get_payload(self):
task_id = f"get_payload_{uuid.uuid4()}"
task = {"task_id": task_id, "cmd": "get_payload"}
self.task_event[task_id] = Event()
payload = msgpack.packb(task)
with self.response_socket_lock:
self.control_client.send_multipart([b"", payload])
self.task_event[task_id].wait()
result = self.result[task_id]
del self.result[task_id]
del self.task_event[task_id]
return result
def get_metrics(self):
task_id = f"get_metrics_{uuid.uuid4()}"
task = {"task_id": task_id, "cmd": "get_metrics"}
self.task_event[task_id] = Event()
payload = msgpack.packb(task)
with self.response_socket_lock:
self.control_client.send_multipart([b"", payload])
self.task_event[task_id].wait()
result = self.result[task_id]
del self.result[task_id]
del self.task_event[task_id]
return result
def recv_results(self):
while True:
try:
try:
with self.response_socket_lock:
frames = self.control_client.recv_multipart(flags=zmq.NOBLOCK)
except zmq.Again:
time.sleep(0.001)
continue
data = frames[-1]
result = msgpack.unpackb(data)
task_id = result["task_id"]
self.result[task_id] = result["result"]
self.task_event[task_id].set()
except Exception as e:
print(f"zmq client occured error {e} type: {type(e)} frames: {frames}")