diff --git a/.github/workflows/_pre_ce_test.yml b/.github/workflows/_pre_ce_test.yml index ba36f6834..06359938e 100644 --- a/.github/workflows/_pre_ce_test.yml +++ b/.github/workflows/_pre_ce_test.yml @@ -82,6 +82,9 @@ jobs: FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100)) FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100)) FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100)) + FD_ZMQ_RECV_REQUEST_SERVER_PORT=$((42048 + DEVICE_PORT * 100)) + FD_ZMQ_SEND_RESPONSE_SERVER_PORT=$((42038 + DEVICE_PORT * 100)) + FD_ZMQ_CONTROL_CMD_SERVER_PORTS=$((42028 + DEVICE_PORT * 100)) echo "Test ENV Parameter:" echo "=========================================================" echo "FLASK_PORT=${FLASK_PORT}" diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 7b9b0bdc6..2f8864a19 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -37,12 +37,14 @@ from fastdeploy.inter_communicator import ( EngineCacheQueue, EngineWorkerQueue, IPCSignal, - ZmqClient, + ZmqIpcServer, + ZmqTcpServer, ) from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.metrics.trace_util import start_span, start_span_request from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.plugins.token_processor import load_token_processor_plugins +from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.utils import EngineError, envs, llm_logger @@ -576,9 +578,19 @@ class EngineService: if api_server_pid is None: return self.api_server_pid = api_server_pid - self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL) - self.zmq_server.start_server() - self.zmq_server.create_router() + if envs.FD_ENABLE_INTERNAL_ADAPTER: + self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL) + self.send_response_server = 
ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER) + self.internal_adapter = InternalAdapter( + cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node + ) + else: + self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL) + self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER) + self.recv_result_handle_thread = threading.Thread( + target=self.send_response_server.recv_result_handle, daemon=True + ) + self.recv_result_handle_thread.start() time.sleep(3) self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True) self.insert_task_to_scheduler_thread.start() @@ -592,9 +604,9 @@ class EngineService: try: block = True if len(added_requests) == 0 else False if not self.cfg.model_config.enable_mm: - err, data = self.zmq_server.receive_json_once(block) + err, data = self.recv_request_server.receive_json_once(block) else: - err, data = self.zmq_server.receive_pyobj_once(block) + err, data = self.recv_request_server.receive_pyobj_once(block) if err is not None: llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}") break @@ -648,7 +660,7 @@ class EngineService: ) # Since the request is not in scheduler # Send result by zmq directly - self.zmq_server.send_multipart(request_id, [error_result]) + self.send_response_server.send_response(request_id, [error_result]) except Exception as e: llm_logger.error( f"Error happened while receiving new request from zmq, details={e}, " @@ -666,7 +678,7 @@ class EngineService: time.sleep(0.005) continue for request_id, contents in results.items(): - self.zmq_server.send_multipart(request_id, contents) + self.send_response_server.send_response(request_id, contents) except Exception as e: llm_logger.error(f"Unexcepted error happened: {e}, {traceback.format_exc()!s}") @@ -766,5 +778,9 @@ class EngineService: self.worker_healthy_live_signal.clear() 
self.exist_prefill_task_signal.clear() self.model_weights_status_signal.clear() - if hasattr(self, "zmq_server") and self.zmq_server is not None: - self.zmq_server.close() + if hasattr(self, "send_response_server") and self.send_response_server is not None: + self.send_response_server.close() + if hasattr(self, "recv_request_server") and self.recv_request_server is not None: + self.recv_request_server.close() + if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None: + self.recv_control_cmd_server.close() diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index b6c0008c3..de6eb9e91 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -27,7 +27,7 @@ from fastdeploy.config import ModelConfig from fastdeploy.entrypoints.openai.utils import DealerConnectionManager from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS from fastdeploy.input.preprocess import InputPreprocessor -from fastdeploy.inter_communicator import IPCSignal, ZmqClient +from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient from fastdeploy.metrics.work_metrics import work_process_metrics from fastdeploy.multimodal.registry import MultimodalRegistry from fastdeploy.platforms import current_platform @@ -115,7 +115,7 @@ class EngineClient: """ Create a ZMQ client. 
""" - self.zmq_client = ZmqClient(model, mode) + self.zmq_client = ZmqIpcClient(model, mode) self.zmq_client.connect() async def format_and_add_data(self, prompts: dict): diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index eaac558ee..06a919ab6 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -98,6 +98,15 @@ environment_variables: dict[str, Callable[[], Any]] = { # Whether to use new get_output and save_output method (0 or 1) "FD_USE_GET_SAVE_OUTPUT_V1": lambda: bool(int(os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0"))), # Whether to enable model cache feature + "FD_ENABLE_MODEL_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_CACHE", "0"))), + # enable internal module to access LLMEngine. + "FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")), + # LLMEngine recieve requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1 + "FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"), + # LLMEngine send response port, used when FD_ENABLE_INTERNAL_ADAPTER=1 + "FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"), + # LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1 + "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"), "FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))), } diff --git a/fastdeploy/inter_communicator/__init__.py b/fastdeploy/inter_communicator/__init__.py index 41eb1ccc2..373702edb 100644 --- a/fastdeploy/inter_communicator/__init__.py +++ b/fastdeploy/inter_communicator/__init__.py @@ -17,6 +17,15 @@ from .engine_cache_queue import EngineCacheQueue from .engine_worker_queue import EngineWorkerQueue from .ipc_signal import IPCSignal, shared_memory_exists -from .zmq_client import ZmqClient +from .zmq_client import ZmqIpcClient +from .zmq_server import ZmqIpcServer, ZmqTcpServer -__all__ = ["ZmqClient", 
"IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "shared_memory_exists"] +__all__ = [ + "ZmqIpcClient", + "IPCSignal", + "EngineWorkerQueue", + "EngineCacheQueue", + "ZmqTcpServer", + "ZmqIpcServer", + "shared_memory_exists", +] diff --git a/fastdeploy/inter_communicator/zmq_client.py b/fastdeploy/inter_communicator/zmq_client.py index 7ef78c37e..ac9ba4bfe 100644 --- a/fastdeploy/inter_communicator/zmq_client.py +++ b/fastdeploy/inter_communicator/zmq_client.py @@ -14,209 +14,100 @@ # limitations under the License. """ -import os -import threading -import time -import traceback +from abc import ABC, abstractmethod -import msgpack import zmq -from fastdeploy import envs -from fastdeploy.utils import zmq_client_logger +from fastdeploy.utils import llm_logger -class ZmqClient: +class ZmqClientBase(ABC): """ - ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ. + ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ. """ - def __init__(self, name, mode): - self.context = zmq.Context(4) - self.socket = self.context.socket(mode) - self.file_name = f"/dev/shm/{name}.socket" - self.router_path = f"/dev/shm/router_{name}.ipc" + def __init__(self): + pass - self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) - self.aggregate_send = envs.FD_USE_AGGREGATE_SEND + @abstractmethod + def _create_socket(self): + """Abstract method to create and return a ZeroMQ socket.""" + pass - self.mutex = threading.Lock() - self.req_dict = dict() - self.router = None - self.poller = None - self.running = True + def _ensure_socket(self): + """Ensure the socket is created before use.""" + if self.socket is None: + self.socket = self._create_socket() + @abstractmethod def connect(self): """ Connect to the server using the file name specified in the constructor. 
""" - self.socket.connect(f"ipc://{self.file_name}") - - def start_server(self): - """ - Start the server using the file name specified in the constructor. - """ - self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) - self.socket.setsockopt(zmq.SNDTIMEO, -1) - self.socket.bind(f"ipc://{self.file_name}") - self.poller = zmq.Poller() - self.poller.register(self.socket, zmq.POLLIN) - - def create_router(self): - """ - Create a ROUTER socket and bind it to the specified router path. - """ - self.router = self.context.socket(zmq.ROUTER) - self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) - self.router.setsockopt(zmq.ROUTER_MANDATORY, 1) - self.router.setsockopt(zmq.SNDTIMEO, -1) - self.router.bind(f"ipc://{self.router_path}") - zmq_client_logger.info(f"router path: {self.router_path}") + pass def send_json(self, data): """ Send a JSON-serializable object over the socket. """ + self._ensure_socket() self.socket.send_json(data) def recv_json(self): """ Receive a JSON-serializable object from the socket. """ + self._ensure_socket() return self.socket.recv_json() def send_pyobj(self, data): """ Send a Pickle-serializable object over the socket. """ + self._ensure_socket() self.socket.send_pyobj(data) def recv_pyobj(self): """ Receive a Pickle-serializable object from the socket. """ + self._ensure_socket() return self.socket.recv_pyobj() - def pack_aggregated_data(self, data): - """ - Aggregate multiple responses into one and send them to the client. - """ - result = data[0] - if len(data) > 1: - for response in data[1:]: - result.add(response) - result = msgpack.packb([result.to_dict()]) - return result + @abstractmethod + def close(self): + pass - def send_multipart(self, req_id, data): - """ - Send a multipart message to the router socket. - """ - if self.router is None: - raise RuntimeError("Router socket not created. 
Call create_router() first.") - while self.running: - with self.mutex: - if req_id not in self.req_dict: - try: - client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK) - req_id_str = request_id.decode("utf-8") - self.req_dict[req_id_str] = client - except zmq.Again: - time.sleep(0.001) - continue - else: - break - if self.req_dict[req_id] == -1: - if data[-1].finished: - with self.mutex: - self.req_dict.pop(req_id, None) - return - try: - start_send = time.time() - if self.aggregate_send: - result = self.pack_aggregated_data(data) - else: - result = msgpack.packb([response.to_dict() for response in data]) - self.router.send_multipart([self.req_dict[req_id], b"", result]) - zmq_client_logger.info(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}") - except zmq.ZMQError as e: - zmq_client_logger.error(f"[{req_id}] zmq error: {e}") - self.req_dict[req_id] = -1 - except Exception as e: - zmq_client_logger.error(f"Send result to zmq client failed: {e}, {str(traceback.format_exc())}") +class ZmqIpcClient(ZmqClientBase): + def __init__(self, name, mode): + self.name = name + self.mode = mode + self.file_name = f"/dev/shm/{name}.socket" + self.context = zmq.Context() + self.socket = self.context.socket(self.mode) - if data[-1].finished: - with self.mutex: - self.req_dict.pop(req_id, None) - zmq_client_logger.info(f"send_multipart finished, req_id: {req_id}") + def _create_socket(self): + """create and return a ZeroMQ socket.""" + self.context = zmq.Context() + return self.context.socket(self.mode) - def receive_json_once(self, block=False): - """ - Receive a single message from the socket. 
- """ - if self.socket is None or self.socket.closed: - return "zmp socket has closed", None - try: - flags = zmq.NOBLOCK if not block else 0 - return None, self.socket.recv_json(flags=flags) - except zmq.Again: - return None, None - except Exception as e: - self.close() - zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}") - return str(e), None - - def receive_pyobj_once(self, block=False): - """ - Receive a single message from the socket. - """ - if self.socket is None or self.socket.closed: - return "zmp socket has closed", None - try: - flags = zmq.NOBLOCK if not block else 0 - return None, self.socket.recv_pyobj(flags=flags) - except zmq.Again: - return None, None - except Exception as e: - self.close() - zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}") - return str(e), None - - def _clear_ipc(self, name): - """ - Remove the IPC file with the given name. - """ - if os.path.exists(name): - try: - os.remove(name) - except OSError as e: - zmq_client_logger.warning(f"Failed to remove IPC file {name} - {e}") + def connect(self): + self._ensure_socket() + self.socket.connect(f"ipc://{self.file_name}") def close(self): """ - Close the socket and context, and remove the IPC files. + Close the socket and context. 
""" - if not self.running: - return - - self.running = False - zmq_client_logger.info("Closing ZMQ connection...") + llm_logger.info("ZMQ client is closing connection...") try: - if hasattr(self, "socket") and not self.socket.closed: + if self.socket is not None and not self.socket.closed: + self.socket.setsockopt(zmq.LINGER, 0) self.socket.close() - - if self.router is not None and not self.router.closed: - self.router.close() - - if not self.context.closed: + if self.context is not None: self.context.term() - self._clear_ipc(self.file_name) - self._clear_ipc(self.router_path) except Exception as e: - zmq_client_logger.warning(f"Failed to close ZMQ connection - {e}, {str(traceback.format_exc())}") + llm_logger.warning(f"ZMQ client failed to close connection - {e}") return - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py new file mode 100644 index 000000000..72eb734c6 --- /dev/null +++ b/fastdeploy/inter_communicator/zmq_server.py @@ -0,0 +1,335 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import os +import threading +import time +from abc import ABC, abstractmethod +from collections import defaultdict + +import msgpack +import zmq + +from fastdeploy import envs +from fastdeploy.utils import llm_logger + + +class ZmqServerBase(ABC): + """ + ZmqServerBase + """ + + def __init__(self): + self.cached_results = defaultdict(list) + self.response_token_lock = threading.Lock() + + @abstractmethod + def _create_socket(self): + """Abstract method to create and return a ZeroMQ socket.""" + pass + + def _ensure_socket(self): + """Ensure the socket is created before use.""" + if self.socket is None: + self.socket = self._create_socket() + + def send_json(self, data): + """ + Send a JSON-serializable object over the socket. + """ + self._ensure_socket() + self.socket.send_json(data) + + def recv_json(self): + """ + Receive a JSON-serializable object from the socket. + """ + self._ensure_socket() + return self.socket.recv_json() + + def send_pyobj(self, data): + """ + Send a Pickle-serializable object over the socket. + """ + self._ensure_socket() + self.socket.send_pyobj(data) + + def recv_pyobj(self): + """ + Receive a Pickle-serializable object from the socket. + """ + self._ensure_socket() + return self.socket.recv_pyobj() + + def pack_aggregated_data(self, data): + """ + Aggregate multiple responses into one and send them to the client. + """ + result = data[0] + if len(data) > 1: + for response in data[1:]: + result.add(response) + result = msgpack.packb([result.to_dict()]) + return result + + def receive_json_once(self, block=False): + """ + Receive a single message from the socket. 
+ """ + self._ensure_socket() + if self.socket is None or self.socket.closed: + return "zmp socket has closed", None + try: + flags = zmq.NOBLOCK if not block else 0 + return None, self.socket.recv_json(flags=flags) + except zmq.Again: + return None, None + except Exception as e: + self.close() + llm_logger.warning(f"{e}") + return str(e), None + + def receive_pyobj_once(self, block=False): + """ + Receive a single message from the socket. + """ + self._ensure_socket() + if self.socket is None or self.socket.closed: + return "zmp socket has closed", None + try: + flags = zmq.NOBLOCK if not block else 0 + return None, self.socket.recv_pyobj(flags=flags) + except zmq.Again: + return None, None + except Exception as e: + self.close() + llm_logger.warning(f"{e}") + return str(e), None + + def recv_result_handle(self): + while True: + try: + with self.response_token_lock: + client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK) + req_id_str = request_id.decode("utf-8") + need_send_after_finished_inference = False + with self.mutex: + self.req_dict[req_id_str] = client + if req_id_str in self.cached_results: + if self.cached_results[req_id_str][-1][-1].finished: + need_send_after_finished_inference = True + if need_send_after_finished_inference: + self.send_response(req_id_str, []) + llm_logger.info(f"send_multipart finished, req_id: {req_id_str}") + self.req_dict.pop(req_id_str, None) + + except zmq.Again: + time.sleep(0.001) + continue + except Exception as e: + llm_logger.error(f"recv_result_handle get unknown exception: {e}") + continue + + def send_response(self, req_id, data): + """ + Send generated token result to client. + """ + self._ensure_socket() + if self.socket is None: + raise RuntimeError("Router socket not created. 
Call create_router() first.") + new_data = [] + has_result_handle = False + with self.mutex: + if req_id not in self.req_dict: + self.cached_results[req_id].append(data) + else: + has_result_handle = True + if req_id in self.cached_results: + for history_data in self.cached_results[req_id]: + new_data.extend(history_data) + llm_logger.info( + f"get request {req_id} result handle after cached result, total cached length {len(self.cached_results[req_id])}" + ) + del self.cached_results[req_id] + if has_result_handle: + try: + new_data.extend(data) + start_send = time.time() + if self.aggregate_send: + result = self.pack_aggregated_data(new_data) + else: + result = msgpack.packb([response.to_dict() for response in new_data]) + with self.response_token_lock: + self.socket.send_multipart([self.req_dict[req_id], b"", result]) + llm_logger.debug( + f"send_multipart result: {req_id} len {len(new_data)} elapse: {time.time()-start_send}" + ) + + except Exception as e: + llm_logger.error(f"Send result to zmq client failed: {e}") + + if data and data[-1].finished: + with self.mutex: + if req_id in self.req_dict: + llm_logger.info(f"send_multipart finished, req_id: {req_id}") + self.req_dict.pop(req_id, None) + + @abstractmethod + def close(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +class ZmqIpcServer(ZmqServerBase): + """ + ZmqIpcServer, used when FD_ENABLE_INTERNAL_ADAPTER=0 + """ + + def __init__(self, name, mode): + self.name = name + self.mode = mode + self.cached_results = defaultdict(list) + if mode == zmq.PULL: + self.file_name = f"/dev/shm/{name}.socket" + elif mode == zmq.ROUTER: + self.file_name = f"/dev/shm/router_{name}.ipc" + self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) + self.aggregate_send = envs.FD_USE_AGGREGATE_SEND + self.mutex = threading.Lock() + self.response_token_lock = threading.Lock() + self.req_dict = dict() + self.running = True + self.context = zmq.Context() + self._create_socket() + + def _create_socket(self): 
+ """create and return a ZeroMQ socket.""" + self.socket = self.context.socket(self.mode) + self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) + self.socket.setsockopt(zmq.SNDTIMEO, -1) + self.socket.bind(f"ipc://{self.file_name}") + return self.socket + + def _clear_ipc(self, name): + """ + Remove the IPC file with the given name. + """ + if os.path.exists(name): + try: + os.remove(name) + except OSError as e: + llm_logger.warning(f"Failed to remove IPC file {name} - {e}") + + def close(self): + """ + Close the socket and context, and remove the IPC files. + """ + if not self.running: + return + + self.running = False + llm_logger.info("ZMQ server is closing connection...") + try: + if self.socket is not None and not self.socket.closed: + self.socket.close() + if not self.context.closed: + self.context.term() + self._clear_ipc(self.file_name) + except Exception as e: + llm_logger.warning(f"ZMQ server failed to close connection - {e}") + return + + +class ZmqTcpServer(ZmqServerBase): + """ + ZmqTcpServer, used when FD_ENABLE_INTERNAL_ADAPTER=1 + """ + + def __init__(self, port, mode): + self.mode = mode + self.port = port + self.cached_results = defaultdict(list) + self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) + self.aggregate_send = envs.FD_USE_AGGREGATE_SEND + + self.mutex = threading.Lock() + self.req_dict = dict() + self.running = True + self.context = zmq.Context() + self._create_socket() + self.response_token_lock = threading.Lock() + + def _create_socket(self): + """create and return a ZeroMQ socket.""" + self.socket = self.context.socket(self.mode) + self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) + self.socket.setsockopt(zmq.SNDTIMEO, -1) + self.socket.bind(f"tcp://*:{self.port}") + return self.socket + + def recv_control_cmd(self): + """ + Recieve control command from client + """ + self._ensure_socket() + try: + client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK) + task = msgpack.unpackb(task_data) + task_id_str = task["task_id"] + 
except zmq.Again: + return None + with self.mutex: + self.req_dict[task_id_str] = client + return task + + def response_for_control_cmd(self, task_id, result): + """ + Send command result back to client. + """ + self._ensure_socket() + if self.socket is None: + raise RuntimeError("Router socket not created.") + try: + result = msgpack.packb(result) + self.socket.send_multipart([self.req_dict[task_id], b"", result]) + + except Exception as e: + llm_logger.error(f"Send result to zmq client failed: {e}") + + with self.mutex: + self.req_dict.pop(task_id, None) + llm_logger.debug(f"response control cmd finished, task_id: {task_id}") + + def close(self): + """ + Close the socket and context. + """ + if not self.running: + return + + self.running = False + llm_logger.info("ZMQ server is closing connection...") + try: + if self.socket is not None and not self.socket.closed: + self.socket.close() + if not self.context.closed: + self.context.term() + + except Exception as e: + llm_logger.warning(f"ZMQ server failed to close connection - {e}") + return diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 8ed26680a..7b3bef5de 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -78,7 +78,8 @@ else: update_inputs_v1, ) -from fastdeploy.inter_communicator import ZmqClient + +from fastdeploy.inter_communicator import ZmqIpcClient from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput @@ -160,7 +161,7 @@ def pre_process( ) -def _zmq_send_text_outputs(zmq_client: ZmqClient, output_tokens: np.ndarray, save_each_rank: bool, mp_rank: int): +def _zmq_send_text_outputs(zmq_client: ZmqIpcClient, output_tokens: np.ndarray, save_each_rank: bool, mp_rank: int): """Split output_tokens and output""" assert zmq_client is not None, "zmq_client 
should not be None" output_tokens = output_tokens.reshape([-1]).numpy() @@ -187,7 +188,7 @@ def post_process_normal( block_size: int = 64, save_each_rank: bool = False, skip_save_output: bool = False, - zmq_client: ZmqClient = None, + zmq_client: ZmqIpcClient = None, ) -> ModelRunnerOutput: """Post-processing steps after completing a single token generation.""" # handle vl: @@ -389,7 +390,7 @@ def post_process( save_each_rank: bool = False, speculative_decoding: bool = False, skip_save_output: bool = False, - zmq_client: ZmqClient = None, + zmq_client: ZmqIpcClient = None, ) -> None: """Post-processing steps after completing a single token generation.""" if speculative_decoding: diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 5173e3398..57375941e 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -29,7 +29,7 @@ import zmq from fastdeploy import envs from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput -from fastdeploy.inter_communicator import IPCSignal, ZmqClient +from fastdeploy.inter_communicator import IPCSignal, ZmqIpcServer from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.platforms import current_platform from fastdeploy.utils import llm_logger, spec_logger @@ -58,12 +58,11 @@ class TokenProcessor: self.split_connector = split_connector if envs.FD_USE_GET_SAVE_OUTPUT_V1: + llm_logger.debug(f"create zmq get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}") - self.zmq_server = ZmqClient( + self.zmq_server = ZmqIpcServer( name=f"get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}", mode=zmq.PULL ) - self.zmq_server.start_server() - self.zmq_server.create_router() self.speculative_decoding = self.cfg.speculative_config.method is not None self.use_logprobs = self.cfg.model_config.enable_logprob @@ -498,6 +497,7 @@ class TokenProcessor: metrics = RequestMetrics( 
arrival_time=task.arrival_time, inference_start_time=task.inference_start_time, + model_execute_time=time.time() - task.inference_start_time, first_token_time=time.time() - task.inference_start_time, time_in_queue=task.schedule_start_time - task.preprocess_end_time, preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time, @@ -510,6 +510,7 @@ class TokenProcessor: metrics = RequestMetrics( arrival_time=time.time(), request_start_time=task.arrival_time, + model_execute_time=time.time() - task.inference_start_time, ) self.number_of_output_tokens += len(token_ids) self._record_metrics(task, current_time, token_ids) diff --git a/fastdeploy/scheduler/local_scheduler.py b/fastdeploy/scheduler/local_scheduler.py index 5d79e5009..159dd447d 100644 --- a/fastdeploy/scheduler/local_scheduler.py +++ b/fastdeploy/scheduler/local_scheduler.py @@ -208,6 +208,9 @@ class LocalScheduler: """ return (token_num + block_size - 1) // block_size + def get_unhandled_request_num(self): + return len(self.ids) - self.ids_read_cursor + def get_requests( self, available_blocks, diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py new file mode 100644 index 000000000..7908a7517 --- /dev/null +++ b/fastdeploy/splitwise/internal_adapter_utils.py @@ -0,0 +1,118 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import threading +import time +import traceback + +# **Note**: Just for internal use +import zmq + +from fastdeploy.inter_communicator import ZmqTcpServer +from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics +from fastdeploy.utils import envs, get_logger + +logger = get_logger("internal_adapter_utils", "internal_adapter_utils.log") + + +class InternalAdapter: + def __init__(self, cfg, engine, dp_rank): + self.cfg = cfg + self.engine = engine + self.dp_rank = dp_rank + recv_control_cmd_ports = envs.FD_ZMQ_CONTROL_CMD_SERVER_PORTS.split(",") + self.response_lock = threading.Lock() # prevent to call send_multipart in zmq concurrently + self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[dp_rank], mode=zmq.ROUTER) + self.recv_external_instruct_thread = threading.Thread( + target=self._recv_external_module_control_instruct, daemon=True + ) + self.recv_external_instruct_thread.start() + if cfg.splitwise_role != "mixed": + self.response_external_instruct_thread = threading.Thread( + target=self._response_external_module_control_instruct, daemon=True + ) + self.response_external_instruct_thread.start() + + def _get_current_server_info(self): + """ + Get resources information + """ + available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch()) + + available_block_num = self.engine.resource_manager.available_block_num() + server_info = { + "splitwise_role": self.cfg.splitwise_role, + "block_size": int(self.cfg.cache_config.block_size), + "block_num": int(available_block_num), + "max_block_num": int(self.cfg.cache_config.total_block_num), + "dec_token_num": int(self.cfg.cache_config.dec_token_num), + "available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num), + "max_batch_size": int(available_batch_size), + "max_input_token_num": self.cfg.max_model_len, + "unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(), + 
"available_batch": int(self.engine.resource_manager.available_batch()), + } + return server_info + + def _recv_external_module_control_instruct(self): + """ + Receive a multipart message from the control cmd socket. + """ + while True: + try: + with self.response_lock: + task = self.recv_control_cmd_server.recv_control_cmd() + if task is None: + time.sleep(0.001) + continue + logger.info(f"Recieve control task: {task}") + task_id_str = task["task_id"] + if task["cmd"] == "get_payload": + payload_info = self._get_current_server_info() + result = {"task_id": task_id_str, "result": payload_info} + logger.debug(f"Response for task: {task_id_str}") + with self.response_lock: + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + + elif task["cmd"] == "get_metrics": + metrics_text = get_filtered_metrics( + [], + extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1), + ) + result = {"task_id": task_id_str, "result": metrics_text} + logger.debug(f"Response for task: {task_id_str}") + with self.response_lock: + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + elif task["cmd"] == "connect_rdma": + self.engine.engine_worker_queue.put_connect_rdma_task(task) + + except Exception as e: + logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}") + + def _response_external_module_control_instruct(self): + while True: + try: + result_data = self.engine.engine_worker_queue.get_connect_rdma_task_response() + if result_data: + task_id_str = result_data["task_id"] + result = {"task_id": task_id_str, "result": result_data} + logger.info(f"Response for task: {task_id_str}") + with self.response_lock: + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + else: + time.sleep(0.001) + except Exception as e: + logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}") diff --git a/fastdeploy/worker/gpu_model_runner.py 
b/fastdeploy/worker/gpu_model_runner.py index 43de576de..806c8cb75 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -75,7 +75,7 @@ import zmq from fastdeploy import envs from fastdeploy.input.ernie4_5_vl_processor import DataProcessor -from fastdeploy.inter_communicator import ZmqClient +from fastdeploy.inter_communicator import ZmqIpcClient from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp from fastdeploy.worker.model_runner_base import ModelRunnerBase @@ -171,7 +171,7 @@ class GPUModelRunner(ModelRunnerBase): self.zmq_client = None if envs.FD_USE_GET_SAVE_OUTPUT_V1: logger.info(f"zmq client get_save_output_rank{local_rank}") - self.zmq_client = ZmqClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH) + self.zmq_client = ZmqIpcClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH) self.zmq_client.connect() self.zmq_client.socket.SNDTIMEO = 3000 diff --git a/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py b/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py new file mode 100644 index 000000000..6202f4bae --- /dev/null +++ b/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py @@ -0,0 +1,251 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import queue +import shutil +import signal +import socket +import subprocess +import sys +import time + +import pytest + +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(current_dir, "..", "..")) +print("project_root", project_root) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +from ci_use.EB_Lite_with_adapter.zmq_client import LLMControlClient, LLMReqClient + +env = os.environ.copy() + +# Read ports from environment variables; use default values if not set +FD_API_PORT = int(os.getenv("FD_API_PORT", 8188)) +FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133)) +FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233)) +FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8234)) + +FD_ENABLE_INTERNAL_ADAPTER = int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "1")) +FD_ZMQ_RECV_REQUEST_SERVER_PORT = int(os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8204")) +FD_ZMQ_SEND_RESPONSE_SERVER_PORT = int(os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8205")) +FD_ZMQ_CONTROL_CMD_SERVER_PORTS = int(os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8206")) +FD_ZMQ_CONTROL_CMD_SERVER_PORT = FD_ZMQ_CONTROL_CMD_SERVER_PORTS + +env["FD_ENABLE_INTERNAL_ADAPTER"] = str(FD_ENABLE_INTERNAL_ADAPTER) +env["FD_ZMQ_RECV_REQUEST_SERVER_PORT"] = str(FD_ZMQ_RECV_REQUEST_SERVER_PORT) +env["FD_ZMQ_SEND_RESPONSE_SERVER_PORT"] = str(FD_ZMQ_SEND_RESPONSE_SERVER_PORT) +env["FD_ZMQ_CONTROL_CMD_SERVER_PORTS"] = str(FD_ZMQ_CONTROL_CMD_SERVER_PORTS) +env["FD_ZMQ_CONTROL_CMD_SERVER_PORT"] = str(FD_ZMQ_CONTROL_CMD_SERVER_PORT) + +# List of ports to clean before and after tests +PORTS_TO_CLEAN = [ + FD_API_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + FD_CACHE_QUEUE_PORT, + FD_ZMQ_RECV_REQUEST_SERVER_PORT, + FD_ZMQ_SEND_RESPONSE_SERVER_PORT, + FD_ZMQ_CONTROL_CMD_SERVER_PORT, +] + + +@pytest.fixture +def zmq_req_client(): + client = LLMReqClient("0.0.0.0", FD_ZMQ_RECV_REQUEST_SERVER_PORT, 
FD_ZMQ_SEND_RESPONSE_SERVER_PORT) + return client + + +@pytest.fixture +def zmq_control_client(): + client = LLMControlClient("0.0.0.0", FD_ZMQ_CONTROL_CMD_SERVER_PORT) + return client + + +def is_port_open(host: str, port: int, timeout=1.0): + """ + Check if a TCP port is open on the given host. + Returns True if connection succeeds, False otherwise. + """ + try: + with socket.create_connection((host, port), timeout): + return True + except Exception: + return False + + +def kill_process_on_port(port: int): + """ + Kill processes that are listening on the given port. + Uses `lsof` to find process ids and sends SIGKILL. + """ + try: + output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip() + for pid in output.splitlines(): + os.kill(int(pid), signal.SIGKILL) + print(f"Killed process on port {port}, pid={pid}") + except subprocess.CalledProcessError: + pass + + try: + result = subprocess.run( + f"ps -ef -ww| grep {FD_CACHE_QUEUE_PORT} | grep -v grep", shell=True, capture_output=True, text=True + ) + for line in result.stdout.strip().split("\n"): + if not line: + continue + parts = line.split() + pid = int(parts[1]) # ps -ef 的第二列是 PID + print(f"Killing PID: {pid}") + os.kill(pid, signal.SIGKILL) + except Exception as e: + print(f"Failed to kill cache manager process: {e}") + + +def clean_ports(): + """ + Kill all processes occupying the ports listed in PORTS_TO_CLEAN. 
+ """ + for port in PORTS_TO_CLEAN: + kill_process_on_port(port) + time.sleep(2) + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_run_server(): + """ + Pytest fixture that runs once per test session: + - Cleans ports before tests + - Starts the API server as a subprocess + - Waits for server port to open (up to 30 seconds) + - Tears down server after all tests finish + """ + print("Pre-test port cleanup...") + clean_ports() + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "ernie-4_5-21b-a3b-bf16-paddle") + else: + model_path = "./ernie-4_5-21b-a3b-bf16-paddle" + + log_path = "server.log" + cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT), + "--tensor-parallel-size", + "1", + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT), + "--metrics-port", + str(FD_METRICS_PORT), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT), + "--max-model-len", + "32768", + "--max-num-seqs", + "128", + "--quantization", + "wint4", + ] + + # Start subprocess in new process group + # 清除log目录 + if os.path.exists("log"): + shutil.rmtree("log") + with open(log_path, "w") as logfile: + process = subprocess.Popen( + cmd, + env=env, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + ) + + # Wait up to 300 seconds for API server to be ready + for _ in range(300): + if is_port_open("127.0.0.1", FD_API_PORT): + print(f"API server is up on port {FD_API_PORT}") + break + time.sleep(1) + else: + print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...") + try: + os.killpg(process.pid, signal.SIGTERM) + except Exception as e: + print(f"Failed to kill process group: {e}") + raise RuntimeError(f"API server did not start on port {FD_API_PORT}") + + yield # Run tests + + print("\n===== Post-test server cleanup... 
=====") + try: + os.killpg(process.pid, signal.SIGTERM) + clean_ports() + print(f"API server (pid={process.pid}) terminated") + except Exception as e: + print(f"Failed to terminate API server: {e}") + + +def test_request_and_response(zmq_req_client): + prompt_token_ids = [5300, 93956, 55791] + req_id = "test" + request = { + "req_id": req_id, + "request_id": req_id, + "min_tokens": 1, + "dp_rank": 0, # P实例 DP rank, 从当前环境变量里读取 + "prompt_token_ids": prompt_token_ids, + "prompt_token_ids_len": len(prompt_token_ids), + "eos_token_ids": [2], + "stop_token_ids": [2], + "max_dec_len": 32 * 1024, + "max_tokens": 32 * 1024, + "min_dec_len": 1, + "arrival_time": time.time(), + "preprocess_start_time": time.time(), + "preprocess_end_time": time.time(), + "messages": [], + "temperature": 0.8, + "penalty_score": 1.0, + "repetition_penalty": 1.0, + "presence_penalty": 0, + "top_p": 0.8, + "frequency_penalty": 0.0, + } + result_queue = queue.Queue() + zmq_req_client.start(result_queue) + zmq_req_client.send_request(request) + zmq_req_client.request_result(req_id) + has_is_end_result = False + while True: + result = result_queue.get() + if result[-1]["finished"]: + has_is_end_result = True + break + assert has_is_end_result is True + + +def test_control_cmd(zmq_control_client): + result = zmq_control_client.get_payload() + assert "unhandled_request_num" in result + result = zmq_control_client.get_metrics() + assert result is not None diff --git a/tests/ci_use/EB_Lite_with_adapter/zmq_client.py b/tests/ci_use/EB_Lite_with_adapter/zmq_client.py new file mode 100644 index 000000000..db811d04a --- /dev/null +++ b/tests/ci_use/EB_Lite_with_adapter/zmq_client.py @@ -0,0 +1,121 @@ +import threading +import time +import uuid +from threading import Event + +import msgpack +import zmq + + +class LLMReqClient: + """ + LLM request client + """ + + def __init__(self, ip, send_req_server_port, recv_res_server_port): + self.ZMQ_SNDHWM = 64 * 1024 + self.context = zmq.Context() + 
self.send_req_client = self.context.socket(zmq.PUSH) + self.recv_res_client = self.context.socket(zmq.DEALER) + self.send_req_client.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) + self.send_req_client.setsockopt(zmq.SNDTIMEO, -1) + self.recv_res_client.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) + self.recv_res_client.setsockopt(zmq.SNDTIMEO, -1) + self.send_req_client.connect(f"tcp://{ip}:{send_req_server_port}") + self.recv_res_client.connect(f"tcp://{ip}:{recv_res_server_port}") + self.need_exit = False + self.response_socket_lock = threading.Lock() + + def send_request(self, req_data): + self.send_req_client.send_json(req_data) + + def request_result(self, req_id): + with self.response_socket_lock: + print(f"request result data for {req_id}") + self.recv_res_client.send_multipart([b"", req_id.encode("utf-8")]) + + def consume_results(self, result_queue): + while True: + try: + try: + with self.response_socket_lock: + frames = self.recv_res_client.recv_multipart(flags=zmq.NOBLOCK) + except zmq.Again: + time.sleep(0.001) + continue + data = frames[-1] + response = msgpack.unpackb(data) + print(f"get result data {response}") + result_queue.put(response) + if self.need_exit: + break + except Exception as e: + print(f"zmq client occurred error {e} type: {type(e)} frames: {frames}") + + def start(self, result_queue): + threading.Thread(target=self.consume_results, args=(result_queue,), daemon=True).start() + + def exit(self): + print("exit") + self.need_exit = True + + +class LLMControlClient: + """ + LLM control client + """ + + def __init__(self, ip, port): + self.ZMQ_SNDHWM = 64 * 1024 + self.context = zmq.Context() + self.control_client = self.context.socket(zmq.DEALER) + self.control_client.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) + self.control_client.setsockopt(zmq.SNDTIMEO, -1) + self.control_client.connect(f"tcp://{ip}:{port}") + self.task_event = {} + self.result = {} + self.response_socket_lock = threading.Lock() + threading.Thread(target=self.recv_results, 
daemon=True).start() + + def get_payload(self): + task_id = f"get_payload_{uuid.uuid4()}" + task = {"task_id": task_id, "cmd": "get_payload"} + self.task_event[task_id] = Event() + payload = msgpack.packb(task) + with self.response_socket_lock: + self.control_client.send_multipart([b"", payload]) + self.task_event[task_id].wait() + result = self.result[task_id] + del self.result[task_id] + del self.task_event[task_id] + return result + + def get_metrics(self): + task_id = f"get_metrics_{uuid.uuid4()}" + task = {"task_id": task_id, "cmd": "get_metrics"} + self.task_event[task_id] = Event() + payload = msgpack.packb(task) + with self.response_socket_lock: + self.control_client.send_multipart([b"", payload]) + self.task_event[task_id].wait() + result = self.result[task_id] + del self.result[task_id] + del self.task_event[task_id] + return result + + def recv_results(self): + while True: + try: + try: + with self.response_socket_lock: + frames = self.control_client.recv_multipart(flags=zmq.NOBLOCK) + except zmq.Again: + time.sleep(0.001) + continue + data = frames[-1] + result = msgpack.unpackb(data) + task_id = result["task_id"] + self.result[task_id] = result["result"] + self.task_event[task_id].set() + except Exception as e: + print(f"zmq client occurred error {e} type: {type(e)} frames: {frames}")