From e8af92aab78e921c9469fae6c7366f9347f9e97b Mon Sep 17 00:00:00 2001 From: chenjian <1435317881@qq.com> Date: Sat, 23 Aug 2025 09:56:47 +0800 Subject: [PATCH] [Feature] Support mixed deployment with yiyan adapter (#3533) * [Feature] Support mixed deployment with yiyan adapter * [Feature] Support mixed deployment with yiyan adapter * fix merge --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> --- fastdeploy/engine/engine.py | 36 ++- fastdeploy/entrypoints/engine_client.py | 4 +- fastdeploy/envs.py | 8 + fastdeploy/inter_communicator/__init__.py | 5 +- fastdeploy/inter_communicator/zmq_client.py | 194 ++--------- fastdeploy/inter_communicator/zmq_server.py | 302 ++++++++++++++++++ fastdeploy/scheduler/local_scheduler.py | 3 + .../splitwise/internal_adapter_utils.py | 117 +++++++ 8 files changed, 494 insertions(+), 175 deletions(-) create mode 100644 fastdeploy/inter_communicator/zmq_server.py create mode 100644 fastdeploy/splitwise/internal_adapter_utils.py diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 56161e30c..0c52cbfc5 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -47,12 +47,14 @@ from fastdeploy.inter_communicator import ( EngineCacheQueue, EngineWorkerQueue, IPCSignal, - ZmqClient, + ZmqIpcServer, + ZmqTcpServer, ) from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.metrics.trace_util import start_span, start_span_request from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor +from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.utils import EngineError, console_logger, envs, llm_logger @@ -181,9 +183,19 @@ class LLMEngine: self.data_processor = 
self.input_processor.create_processor() if api_server_pid is not None: - self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL) - self.zmq_server.start_server() - self.zmq_server.create_router() + if envs.FD_ENABLE_INTERNAL_ADAPTER: + self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL) + self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER) + self.external_adapter = InternalAdapter( + cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node + ) + else: + self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL) + self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER) + self.recv_result_handle_thread = threading.Thread( + target=self.send_response_server.recv_result_handle, daemon=True + ) + self.recv_result_handle_thread.start() time.sleep(3) if self.do_profile == 0 and ( @@ -293,7 +305,7 @@ class LLMEngine: time.sleep(0.005) continue for request_id, contents in results.items(): - self.zmq_server.send_multipart(request_id, contents) + self.send_response_server.send_response(request_id, contents) except Exception as e: llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}") @@ -422,9 +434,9 @@ class LLMEngine: try: block = True if len(added_requests) == 0 else False if not self.cfg.enable_mm: - err, data = self.zmq_server.receive_json_once(block) + err, data = self.recv_request_server.receive_json_once(block) else: - err, data = self.zmq_server.receive_pyobj_once(block) + err, data = self.recv_request_server.receive_pyobj_once(block) if err is not None: llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}") break @@ -472,7 +484,7 @@ class LLMEngine: ) # Since the request is not in scheduler # Send result by zmq directly - self.zmq_server.send_multipart(request_id, error_result) + self.send_response_server.send_response(request_id, [error_result]) 
except Exception as e: llm_logger.error( f"Error happend while receving new request from zmq, details={e}, " @@ -1009,8 +1021,12 @@ class LLMEngine: print(f"Error extracting sub services: {e}") self.engine_worker_queue.cleanup() - if hasattr(self, "zmq_server") and self.zmq_server is not None: - self.zmq_server.close() + if hasattr(self, "send_response_server") and self.send_response_server is not None: + self.send_response_server.close() + if hasattr(self, "recv_request_server") and self.recv_request_server is not None: + self.recv_request_server.close() + if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None: + self.recv_control_cmd_server.close() if hasattr(self, "dp_processed"): for p in self.dp_processed: p.join() diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 45b11f914..2d4b61b1f 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -21,7 +21,7 @@ import numpy as np from fastdeploy import envs from fastdeploy.input.preprocess import InputPreprocessor -from fastdeploy.inter_communicator import IPCSignal, ZmqClient +from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient from fastdeploy.metrics.work_metrics import work_process_metrics from fastdeploy.platforms import current_platform from fastdeploy.utils import EngineError, StatefulSemaphore, api_server_logger @@ -85,7 +85,7 @@ class EngineClient: """ Create a ZMQ client. 
""" - self.zmq_client = ZmqClient(model, mode) + self.zmq_client = ZmqIpcClient(model, mode) self.zmq_client.connect() def format_and_add_data(self, prompts: dict): diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index f5aa5dc7e..5551c69f9 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -84,6 +84,14 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"), # support max connections "FD_SUPPORT_MAX_CONNECTIONS": lambda: 768, + # enable internal module to access LLMEngine. + "FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")), + # LLMEngine receive requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1 + "FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"), + # LLMEngine send response port, used when FD_ENABLE_INTERNAL_ADAPTER=1 + "FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"), + # LLMEngine receive control command ports (comma-separated, indexed by dp rank), used when FD_ENABLE_INTERNAL_ADAPTER=1 + "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"), } diff --git a/fastdeploy/inter_communicator/__init__.py b/fastdeploy/inter_communicator/__init__.py index 0c1cc0d9f..ea08af31a 100644 --- a/fastdeploy/inter_communicator/__init__.py +++ b/fastdeploy/inter_communicator/__init__.py @@ -17,6 +17,7 @@ from .engine_cache_queue import EngineCacheQueue from .engine_worker_queue import EngineWorkerQueue from .ipc_signal import IPCSignal -from .zmq_client import ZmqClient +from .zmq_client import ZmqIpcClient +from .zmq_server import ZmqIpcServer, ZmqTcpServer -__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue"] +__all__ = ["ZmqIpcClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "ZmqTcpServer", "ZmqIpcServer"] diff --git a/fastdeploy/inter_communicator/zmq_client.py b/fastdeploy/inter_communicator/zmq_client.py index 5a9b6418d..13242f2a2 100644 
--- a/fastdeploy/inter_communicator/zmq_client.py +++ b/fastdeploy/inter_communicator/zmq_client.py @@ -14,206 +14,78 @@ # limitations under the License. """ -import os -import threading -import time +from abc import ABC, abstractmethod -import msgpack import zmq -from fastdeploy import envs -from fastdeploy.utils import llm_logger - -class ZmqClient: +class ZmqClientBase(ABC): """ - ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ. + ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ. """ - def __init__(self, name, mode): - self.context = zmq.Context() - self.socket = self.context.socket(mode) - self.file_name = f"/dev/shm/{name}.socket" - self.router_path = f"/dev/shm/router_{name}.ipc" + def __init__(self): + pass - self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) - self.aggregate_send = envs.FD_USE_AGGREGATE_SEND + @abstractmethod + def _create_socket(self): + """Abstract method to create and return a ZeroMQ socket.""" + pass - self.mutex = threading.Lock() - self.req_dict = dict() - self.router = None - self.poller = None - self.running = True + def _ensure_socket(self): + """Ensure the socket is created before use.""" + if self.socket is None: + self.socket = self._create_socket() + @abstractmethod def connect(self): """ Connect to the server using the file name specified in the constructor. """ - self.socket.connect(f"ipc://{self.file_name}") - - def start_server(self): - """ - Start the server using the file name specified in the constructor. - """ - self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) - self.socket.setsockopt(zmq.SNDTIMEO, -1) - self.socket.bind(f"ipc://{self.file_name}") - self.poller = zmq.Poller() - self.poller.register(self.socket, zmq.POLLIN) - - def create_router(self): - """ - Create a ROUTER socket and bind it to the specified router path. 
- """ - self.router = self.context.socket(zmq.ROUTER) - self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) - self.router.setsockopt(zmq.ROUTER_MANDATORY, 1) - self.router.setsockopt(zmq.SNDTIMEO, -1) - self.router.bind(f"ipc://{self.router_path}") + pass def send_json(self, data): """ Send a JSON-serializable object over the socket. """ + self._ensure_socket() self.socket.send_json(data) def recv_json(self): """ Receive a JSON-serializable object from the socket. """ + self._ensure_socket() return self.socket.recv_json() def send_pyobj(self, data): """ Send a Pickle-serializable object over the socket. """ + self._ensure_socket() self.socket.send_pyobj(data) def recv_pyobj(self): """ Receive a Pickle-serializable object from the socket. """ + self._ensure_socket() return self.socket.recv_pyobj() - def pack_aggregated_data(self, data): - """ - Aggregate multiple responses into one and send them to the client. - """ - result = data[0] - if len(data) > 1: - for response in data[1:]: - result.add(response) - result = msgpack.packb([result.to_dict()]) - return result - def send_multipart(self, req_id, data): - """ - Send a multipart message to the router socket. - """ - if self.router is None: - raise RuntimeError("Router socket not created. 
Call create_router() first.") - while self.running: - with self.mutex: - if req_id not in self.req_dict: - try: - client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK) - req_id_str = request_id.decode("utf-8") - self.req_dict[req_id_str] = client - except zmq.Again: - time.sleep(0.001) - continue - else: - break - if self.req_dict[req_id] == -1: - if data[-1].finished: - with self.mutex: - self.req_dict.pop(req_id, None) - return - try: - start_send = time.time() - if self.aggregate_send: - result = self.pack_aggregated_data(data) - else: - result = msgpack.packb([response.to_dict() for response in data]) - self.router.send_multipart([self.req_dict[req_id], b"", result]) - llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}") - except zmq.ZMQError as e: - llm_logger.error(f"[{req_id}] zmq error: {e}") - self.req_dict[req_id] = -1 - except Exception as e: - llm_logger.error(f"Send result to zmq client failed: {e}") +class ZmqIpcClient(ZmqClientBase): + def __init__(self, name, mode): + self.name = name + self.mode = mode + self.file_name = f"/dev/shm/{name}.socket" + self.context = zmq.Context() + self.socket = self.context.socket(self.mode) - if data[-1].finished: - with self.mutex: - self.req_dict.pop(req_id, None) - llm_logger.info(f"send_multipart finished, req_id: {req_id}") + def _create_socket(self): + """create and return a ZeroMQ socket.""" + self.context = zmq.Context() + return self.context.socket(self.mode) - def receive_json_once(self, block=False): - """ - Receive a single message from the socket. 
- """ - if self.socket is None or self.socket.closed: - return "zmp socket has closed", None - try: - flags = zmq.NOBLOCK if not block else 0 - return None, self.socket.recv_json(flags=flags) - except zmq.Again: - return None, None - except Exception as e: - self.close() - llm_logger.warning(f"{e}") - return str(e), None - - def receive_pyobj_once(self, block=False): - """ - Receive a single message from the socket. - """ - if self.socket is None or self.socket.closed: - return "zmp socket has closed", None - try: - flags = zmq.NOBLOCK if not block else 0 - return None, self.socket.recv_pyobj(flags=flags) - except zmq.Again: - return None, None - except Exception as e: - self.close() - llm_logger.warning(f"{e}") - return str(e), None - - def _clear_ipc(self, name): - """ - Remove the IPC file with the given name. - """ - if os.path.exists(name): - try: - os.remove(name) - except OSError as e: - llm_logger.warning(f"Failed to remove IPC file {name} - {e}") - - def close(self): - """ - Close the socket and context, and remove the IPC files. - """ - if not self.running: - return - - self.running = False - llm_logger.info("Closing ZMQ connection...") - try: - if hasattr(self, "socket") and not self.socket.closed: - self.socket.close() - - if self.router is not None and not self.router.closed: - self.router.close() - - if not self.context.closed: - self.context.term() - - self._clear_ipc(self.file_name) - self._clear_ipc(self.router_path) - except Exception as e: - llm_logger.warning(f"Failed to close ZMQ connection - {e}") - return - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() + def connect(self): + self._ensure_socket() + self.socket.connect(f"ipc://{self.file_name}") diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py new file mode 100644 index 000000000..56488d53a --- /dev/null +++ b/fastdeploy/inter_communicator/zmq_server.py @@ -0,0 +1,302 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os +import threading +import time +from abc import ABC, abstractmethod +from collections import defaultdict + +import msgpack +import zmq + +from fastdeploy import envs +from fastdeploy.utils import llm_logger + + +class ZmqServerBase(ABC): + """ + ZmqServerBase + """ + + def __init__(self): + self.cached_results = defaultdict(list) + self.response_token_lock = threading.Lock() + + @abstractmethod + def _create_socket(self): + """Abstract method to create and return a ZeroMQ socket.""" + pass + + def _ensure_socket(self): + """Ensure the socket is created before use.""" + if self.socket is None: + self.socket = self._create_socket() + + def pack_aggregated_data(self, data): + """ + Merge multiple responses into the first one and return it msgpack-encoded. + """ + result = data[0] + if len(data) > 1: + for response in data[1:]: + result.add(response) + result = msgpack.packb([result.to_dict()]) + return result + + def receive_json_once(self, block=False): + """ + Receive a single message from the socket. 
+ """ + self._ensure_socket() + if self.socket is None or self.socket.closed: + return "zmp socket has closed", None + try: + flags = zmq.NOBLOCK if not block else 0 + return None, self.socket.recv_json(flags=flags) + except zmq.Again: + return None, None + except Exception as e: + self.close() + llm_logger.warning(f"{e}") + return str(e), None + + def receive_pyobj_once(self, block=False): + """ + Receive a single message from the socket. + """ + self._ensure_socket() + if self.socket is None or self.socket.closed: + return "zmp socket has closed", None + try: + flags = zmq.NOBLOCK if not block else 0 + return None, self.socket.recv_pyobj(flags=flags) + except zmq.Again: + return None, None + except Exception as e: + self.close() + llm_logger.warning(f"{e}") + return str(e), None + + def recv_result_handle(self): + while True: + try: + with self.response_token_lock: + client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK) + req_id_str = request_id.decode("utf-8") + with self.mutex: + self.req_dict[req_id_str] = client + except zmq.Again: + time.sleep(0.001) + continue + except Exception as e: + llm_logger.error(f"recv_result_handle get unknown exception: {e}") + continue + + def send_response(self, req_id, data): + """ + Send generated token result to client. + """ + self._ensure_socket() + if self.socket is None: + raise RuntimeError("Router socket not created. 
Call create_router() first.") + new_data = [] + has_result_handle = False + with self.mutex: + if req_id not in self.req_dict: + self.cached_results[req_id].append(data) + else: + has_result_handle = True + if req_id in self.cached_results: + for history_data in self.cached_results[req_id]: + new_data.extend(history_data) + llm_logger.info( + f"get request {req_id} result handle after cached result, total cached length {len(self.cached_results[req_id])}" + ) + del self.cached_results[req_id] + if has_result_handle: + try: + new_data.extend(data) + start_send = time.time() + if self.aggregate_send: + result = self.pack_aggregated_data(new_data) + else: + result = msgpack.packb([response.to_dict() for response in new_data]) + with self.response_token_lock: + self.socket.send_multipart([self.req_dict[req_id], b"", result]) + llm_logger.debug( + f"send_multipart result: {req_id} len {len(new_data)} elapse: {time.time()-start_send}" + ) + + except Exception as e: + llm_logger.error(f"Send result to zmq client failed: {e}") + + if data[-1].finished: + with self.mutex: + if req_id not in self.req_dict: + llm_logger.warning(f"req_id {req_id} finished but no result handle, drop it") + if req_id in self.cached_results: + del self.cached_results[req_id] + else: + llm_logger.info(f"send_multipart finished, req_id: {req_id}") + self.req_dict.pop(req_id, None) + + @abstractmethod + def close(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +class ZmqIpcServer(ZmqServerBase): + """ + ZmqIpcServer, used when FD_ENABLE_INTERNAL_ADAPTER=0 + """ + + def __init__(self, name, mode): + self.name = name + self.mode = mode + self.cached_results = defaultdict(list) + if mode == zmq.PULL: + self.file_name = f"/dev/shm/{name}.socket" + elif mode == zmq.ROUTER: + self.file_name = f"/dev/shm/router_{name}.ipc" + self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) + self.aggregate_send = envs.FD_USE_AGGREGATE_SEND + self.mutex = threading.Lock() + 
self.response_token_lock = threading.Lock() + self.req_dict = dict() + self.running = True + self.context = zmq.Context() + self._create_socket() + + def _create_socket(self): + """create and return a ZeroMQ socket.""" + self.socket = self.context.socket(self.mode) + self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) + self.socket.setsockopt(zmq.SNDTIMEO, -1) + self.socket.bind(f"ipc://{self.file_name}") + return self.socket + + def _clear_ipc(self, name): + """ + Remove the IPC file with the given name. + """ + if os.path.exists(name): + try: + os.remove(name) + except OSError as e: + llm_logger.warning(f"Failed to remove IPC file {name} - {e}") + + def close(self): + """ + Close the socket and context, and remove the IPC files. + """ + if not self.running: + return + + self.running = False + llm_logger.info("Closing ZMQ connection...") + try: + if self.socket is not None and not self.socket.closed: + self.socket.close() + if not self.context.closed: + self.context.term() + self._clear_ipc(self.file_name) + except Exception as e: + llm_logger.warning(f"Failed to close ZMQ connection - {e}") + return + + +class ZmqTcpServer(ZmqServerBase): + """ + ZmqTcpServer, used when FD_ENABLE_INTERNAL_ADAPTER=1 + """ + + def __init__(self, port, mode): + self.mode = mode + self.port = port + self.cached_results = defaultdict(list) + self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) + self.aggregate_send = envs.FD_USE_AGGREGATE_SEND + + self.mutex = threading.Lock() + self.req_dict = dict() + self.running = True + self.context = zmq.Context() + self._create_socket() + self.response_token_lock = threading.Lock() + + def _create_socket(self): + """create and return a ZeroMQ socket.""" + self.socket = self.context.socket(self.mode) + self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) + self.socket.setsockopt(zmq.SNDTIMEO, -1) + self.socket.bind(f"tcp://*:{self.port}") + return self.socket + + def recv_control_cmd(self): + """ + Receive control command from client + """ + 
self._ensure_socket() + try: + client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK) + task = msgpack.unpackb(task_data) + task_id_str = task["task_id"] + except zmq.Again: + return None + with self.mutex: + self.req_dict[task_id_str] = client + return task + + def response_for_control_cmd(self, task_id, result): + """ + Send command result back to client. + """ + self._ensure_socket() + if self.socket is None: + raise RuntimeError("Router socket not created.") + try: + result = msgpack.packb(result) + self.socket.send_multipart([self.req_dict[task_id], b"", result]) + + except Exception as e: + llm_logger.error(f"Send result to zmq client failed: {e}") + + with self.mutex: + self.req_dict.pop(task_id, None) + llm_logger.debug(f"response control cmd finished, task_id: {task_id}") + + def close(self): + """ + Close the socket and context. + """ + if not self.running: + return + + self.running = False + llm_logger.info("Closing ZMQ connection...") + try: + if self.socket is not None and not self.socket.closed: + self.socket.close() + if not self.context.closed: + self.context.term() + + except Exception as e: + llm_logger.warning(f"Failed to close ZMQ connection - {e}") + return diff --git a/fastdeploy/scheduler/local_scheduler.py b/fastdeploy/scheduler/local_scheduler.py index 5d79e5009..20e53317b 100644 --- a/fastdeploy/scheduler/local_scheduler.py +++ b/fastdeploy/scheduler/local_scheduler.py @@ -208,6 +208,9 @@ class LocalScheduler: """ return (token_num + block_size - 1) // block_size + def get_unhandled_request_num(self): + return len(self.requests) + def get_requests( self, available_blocks, diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py new file mode 100644 index 000000000..d52edf897 --- /dev/null +++ b/fastdeploy/splitwise/internal_adapter_utils.py @@ -0,0 +1,117 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import threading +import time +import traceback + +# **Note**: Just for internal use +import zmq + +from fastdeploy.inter_communicator import ZmqTcpServer +from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics +from fastdeploy.utils import envs, get_logger + +logger = get_logger("internal_adapter_utils", "internal_adapter_utils.log") + + +class InternalAdapter: + def __init__(self, cfg, engine, dp_rank): + self.cfg = cfg + self.engine = engine + self.dp_rank = dp_rank + recv_control_cmd_ports = envs.FD_ZMQ_CONTROL_CMD_SERVER_PORTS.split(",") + self.response_lock = threading.Lock() # serializes recv/send on the control cmd zmq socket across threads + self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[dp_rank], mode=zmq.ROUTER) + self.recv_external_instruct_thread = threading.Thread( + target=self._recv_external_module_control_instruct, daemon=True + ) + self.recv_external_instruct_thread.start() + self.response_external_instruct_thread = threading.Thread( + target=self._response_external_module_control_instruct, daemon=True + ) + self.response_external_instruct_thread.start() + + def _get_current_server_info(self): + """ + Get resources information + """ + available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch()) + + available_block_num = self.engine.resource_manager.available_block_num() + server_info = { + "splitwise_role": 
self.cfg.splitwise_role, + "block_size": int(self.cfg.cache_config.block_size), + "block_num": int(available_block_num), + "max_block_num": int(self.cfg.cache_config.total_block_num), + "dec_token_num": int(self.cfg.cache_config.dec_token_num), + "available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num), + "max_batch_size": int(available_batch_size), + "max_input_token_num": self.cfg.max_num_batched_tokens, + "unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(), + "available_batch": int(self.engine.resource_manager.available_batch()), + } + return server_info + + def _recv_external_module_control_instruct(self): + """ + Receive a multipart message from the control cmd socket. + """ + while True: + try: + with self.response_lock: + task = self.recv_control_cmd_server.recv_control_cmd() + if task is None: + time.sleep(0.001) + continue + logger.info(f"Recieve control task: {task}") + task_id_str = task["task_id"] + if task["cmd"] == "get_payload": + payload_info = self._get_current_server_info() + result = {"task_id": task_id_str, "result": payload_info} + logger.debug(f"Response for task: {task_id_str}") + with self.response_lock: + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + + elif task["cmd"] == "get_metrics": + metrics_text = get_filtered_metrics( + [], + extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1), + ) + result = {"task_id": task_id_str, "result": metrics_text} + logger.debug(f"Response for task: {task_id_str}") + with self.response_lock: + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + elif task["cmd"] == "connect_rdma": + self.engine.engine_worker_queue.put_connect_rdma_task(task) + + except Exception as e: + logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}") + + def _response_external_module_control_instruct(self): + while True: + try: + result_data = 
self.engine.engine_worker_queue.get_connect_rdma_task_response() + if result_data: + task_id_str = result_data["task_id"] + result = {"task_id": task_id_str, "result": result_data} + logger.info(f"Response for task: {task_id_str}") + with self.response_lock: + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + else: + time.sleep(0.001) + except Exception as e: + logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}")