mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature] support adapter (#4180)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [Feature] support adapter * fix * fix * fix * fix * fix * fix
This commit is contained in:
@@ -37,12 +37,14 @@ from fastdeploy.inter_communicator import (
|
||||
EngineCacheQueue,
|
||||
EngineWorkerQueue,
|
||||
IPCSignal,
|
||||
ZmqClient,
|
||||
ZmqIpcServer,
|
||||
ZmqTcpServer,
|
||||
)
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.metrics.trace_util import start_span, start_span_request
|
||||
from fastdeploy.model_executor.guided_decoding import schema_checker
|
||||
from fastdeploy.output.token_processor import TokenProcessor
|
||||
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
|
||||
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
||||
from fastdeploy.utils import EngineError, envs, llm_logger
|
||||
|
||||
@@ -571,10 +573,21 @@ class EngineSevice:
|
||||
def start_zmq_service(self, api_server_pid=None):
|
||||
if api_server_pid is None:
|
||||
return
|
||||
self.api_server_pid = api_server_pid
|
||||
self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL)
|
||||
self.zmq_server.start_server()
|
||||
self.zmq_server.create_router()
|
||||
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
|
||||
self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
|
||||
self.external_adapter = InternalAdapter(
|
||||
cfg=self.cfg, engine=self, dp_rank=self.cfg.parallel_config.local_data_parallel_id
|
||||
)
|
||||
else:
|
||||
self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
|
||||
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
|
||||
self.recv_result_handle_thread = threading.Thread(
|
||||
target=self.send_response_server.recv_result_handle, daemon=True
|
||||
)
|
||||
self.recv_result_handle_thread.start()
|
||||
|
||||
time.sleep(3)
|
||||
self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True)
|
||||
self.insert_task_to_scheduler_thread.start()
|
||||
@@ -588,9 +601,9 @@ class EngineSevice:
|
||||
try:
|
||||
block = True if len(added_requests) == 0 else False
|
||||
if not self.cfg.model_config.enable_mm:
|
||||
err, data = self.zmq_server.receive_json_once(block)
|
||||
err, data = self.recv_request_server.receive_json_once(block)
|
||||
else:
|
||||
err, data = self.zmq_server.receive_pyobj_once(block)
|
||||
err, data = self.recv_request_server.receive_pyobj_once(block)
|
||||
if err is not None:
|
||||
llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}")
|
||||
break
|
||||
@@ -644,7 +657,7 @@ class EngineSevice:
|
||||
)
|
||||
# Since the request is not in scheduler
|
||||
# Send result by zmq directly
|
||||
self.zmq_server.send_multipart(request_id, [error_result])
|
||||
self.send_response_server.send_response(request_id, [error_result])
|
||||
except Exception as e:
|
||||
llm_logger.error(
|
||||
f"Error happend while receving new request from zmq, details={e}, "
|
||||
@@ -662,7 +675,7 @@ class EngineSevice:
|
||||
time.sleep(0.005)
|
||||
continue
|
||||
for request_id, contents in results.items():
|
||||
self.zmq_server.send_multipart(request_id, contents)
|
||||
self.send_response_server.send_response(request_id, contents)
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
|
||||
|
@@ -27,7 +27,7 @@ from fastdeploy.config import ModelConfig
|
||||
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
|
||||
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
|
||||
from fastdeploy.input.preprocess import InputPreprocessor
|
||||
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
|
||||
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
|
||||
from fastdeploy.metrics.work_metrics import work_process_metrics
|
||||
from fastdeploy.multimodal.registry import MultimodalRegistry
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -110,7 +110,7 @@ class EngineClient:
|
||||
"""
|
||||
Create a ZMQ client.
|
||||
"""
|
||||
self.zmq_client = ZmqClient(model, mode)
|
||||
self.zmq_client = ZmqIpcClient(model, mode)
|
||||
self.zmq_client.connect()
|
||||
|
||||
async def format_and_add_data(self, prompts: dict):
|
||||
|
@@ -95,6 +95,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))),
|
||||
# force disable default chunked prefill
|
||||
"FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))),
|
||||
"FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")),
|
||||
# LLMEngine recieve requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||
"FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"),
|
||||
# LLMEngine send response port, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||
"FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
|
||||
# LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
|
||||
}
|
||||
|
||||
|
||||
|
@@ -17,6 +17,15 @@
|
||||
from .engine_cache_queue import EngineCacheQueue
|
||||
from .engine_worker_queue import EngineWorkerQueue
|
||||
from .ipc_signal import IPCSignal, shared_memory_exists
|
||||
from .zmq_client import ZmqClient
|
||||
from .zmq_client import ZmqIpcClient
|
||||
from .zmq_server import ZmqIpcServer, ZmqTcpServer
|
||||
|
||||
__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "shared_memory_exists"]
|
||||
__all__ = [
|
||||
"ZmqIpcClient",
|
||||
"IPCSignal",
|
||||
"EngineWorkerQueue",
|
||||
"EngineCacheQueue",
|
||||
"ZmqTcpServer",
|
||||
"ZmqIpcServer",
|
||||
"shared_memory_exists",
|
||||
]
|
||||
|
@@ -14,209 +14,78 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import msgpack
|
||||
import zmq
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.utils import zmq_client_logger
|
||||
|
||||
|
||||
class ZmqClient:
|
||||
class ZmqClientBase(ABC):
|
||||
"""
|
||||
ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ.
|
||||
ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ.
|
||||
"""
|
||||
|
||||
def __init__(self, name, mode):
|
||||
self.context = zmq.Context(4)
|
||||
self.socket = self.context.socket(mode)
|
||||
self.file_name = f"/dev/shm/{name}.socket"
|
||||
self.router_path = f"/dev/shm/router_{name}.ipc"
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
||||
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
||||
@abstractmethod
|
||||
def _create_socket(self):
|
||||
"""Abstract method to create and return a ZeroMQ socket."""
|
||||
pass
|
||||
|
||||
self.mutex = threading.Lock()
|
||||
self.req_dict = dict()
|
||||
self.router = None
|
||||
self.poller = None
|
||||
self.running = True
|
||||
def _ensure_socket(self):
|
||||
"""Ensure the socket is created before use."""
|
||||
if self.socket is None:
|
||||
self.socket = self._create_socket()
|
||||
|
||||
@abstractmethod
|
||||
def connect(self):
|
||||
"""
|
||||
Connect to the server using the file name specified in the constructor.
|
||||
"""
|
||||
self.socket.connect(f"ipc://{self.file_name}")
|
||||
|
||||
def start_server(self):
|
||||
"""
|
||||
Start the server using the file name specified in the constructor.
|
||||
"""
|
||||
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
||||
self.socket.setsockopt(zmq.SNDTIMEO, -1)
|
||||
self.socket.bind(f"ipc://{self.file_name}")
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.socket, zmq.POLLIN)
|
||||
|
||||
def create_router(self):
|
||||
"""
|
||||
Create a ROUTER socket and bind it to the specified router path.
|
||||
"""
|
||||
self.router = self.context.socket(zmq.ROUTER)
|
||||
self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
||||
self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
|
||||
self.router.setsockopt(zmq.SNDTIMEO, -1)
|
||||
self.router.bind(f"ipc://{self.router_path}")
|
||||
zmq_client_logger.info(f"router path: {self.router_path}")
|
||||
pass
|
||||
|
||||
def send_json(self, data):
|
||||
"""
|
||||
Send a JSON-serializable object over the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
self.socket.send_json(data)
|
||||
|
||||
def recv_json(self):
|
||||
"""
|
||||
Receive a JSON-serializable object from the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
return self.socket.recv_json()
|
||||
|
||||
def send_pyobj(self, data):
|
||||
"""
|
||||
Send a Pickle-serializable object over the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
self.socket.send_pyobj(data)
|
||||
|
||||
def recv_pyobj(self):
|
||||
"""
|
||||
Receive a Pickle-serializable object from the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
return self.socket.recv_pyobj()
|
||||
|
||||
def pack_aggregated_data(self, data):
|
||||
"""
|
||||
Aggregate multiple responses into one and send them to the client.
|
||||
"""
|
||||
result = data[0]
|
||||
if len(data) > 1:
|
||||
for response in data[1:]:
|
||||
result.add(response)
|
||||
result = msgpack.packb([result.to_dict()])
|
||||
return result
|
||||
|
||||
def send_multipart(self, req_id, data):
|
||||
"""
|
||||
Send a multipart message to the router socket.
|
||||
"""
|
||||
if self.router is None:
|
||||
raise RuntimeError("Router socket not created. Call create_router() first.")
|
||||
class ZmqIpcClient(ZmqClientBase):
|
||||
def __init__(self, name, mode):
|
||||
self.name = name
|
||||
self.mode = mode
|
||||
self.file_name = f"/dev/shm/{name}.socket"
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(self.mode)
|
||||
|
||||
while self.running:
|
||||
with self.mutex:
|
||||
if req_id not in self.req_dict:
|
||||
try:
|
||||
client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK)
|
||||
req_id_str = request_id.decode("utf-8")
|
||||
self.req_dict[req_id_str] = client
|
||||
except zmq.Again:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
else:
|
||||
break
|
||||
if self.req_dict[req_id] == -1:
|
||||
if data[-1].finished:
|
||||
with self.mutex:
|
||||
self.req_dict.pop(req_id, None)
|
||||
return
|
||||
try:
|
||||
start_send = time.time()
|
||||
if self.aggregate_send:
|
||||
result = self.pack_aggregated_data(data)
|
||||
else:
|
||||
result = msgpack.packb([response.to_dict() for response in data])
|
||||
self.router.send_multipart([self.req_dict[req_id], b"", result])
|
||||
zmq_client_logger.info(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
|
||||
except zmq.ZMQError as e:
|
||||
zmq_client_logger.error(f"[{req_id}] zmq error: {e}")
|
||||
self.req_dict[req_id] = -1
|
||||
except Exception as e:
|
||||
zmq_client_logger.error(f"Send result to zmq client failed: {e}, {str(traceback.format_exc())}")
|
||||
def _create_socket(self):
|
||||
"""create and return a ZeroMQ socket."""
|
||||
self.context = zmq.Context()
|
||||
return self.context.socket(self.mode)
|
||||
|
||||
if data[-1].finished:
|
||||
with self.mutex:
|
||||
self.req_dict.pop(req_id, None)
|
||||
zmq_client_logger.info(f"send_multipart finished, req_id: {req_id}")
|
||||
|
||||
def receive_json_once(self, block=False):
|
||||
"""
|
||||
Receive a single message from the socket.
|
||||
"""
|
||||
if self.socket is None or self.socket.closed:
|
||||
return "zmp socket has closed", None
|
||||
try:
|
||||
flags = zmq.NOBLOCK if not block else 0
|
||||
return None, self.socket.recv_json(flags=flags)
|
||||
except zmq.Again:
|
||||
return None, None
|
||||
except Exception as e:
|
||||
self.close()
|
||||
zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}")
|
||||
return str(e), None
|
||||
|
||||
def receive_pyobj_once(self, block=False):
|
||||
"""
|
||||
Receive a single message from the socket.
|
||||
"""
|
||||
if self.socket is None or self.socket.closed:
|
||||
return "zmp socket has closed", None
|
||||
try:
|
||||
flags = zmq.NOBLOCK if not block else 0
|
||||
return None, self.socket.recv_pyobj(flags=flags)
|
||||
except zmq.Again:
|
||||
return None, None
|
||||
except Exception as e:
|
||||
self.close()
|
||||
zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}")
|
||||
return str(e), None
|
||||
|
||||
def _clear_ipc(self, name):
|
||||
"""
|
||||
Remove the IPC file with the given name.
|
||||
"""
|
||||
if os.path.exists(name):
|
||||
try:
|
||||
os.remove(name)
|
||||
except OSError as e:
|
||||
zmq_client_logger.warning(f"Failed to remove IPC file {name} - {e}")
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the socket and context, and remove the IPC files.
|
||||
"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
zmq_client_logger.info("Closing ZMQ connection...")
|
||||
try:
|
||||
if hasattr(self, "socket") and not self.socket.closed:
|
||||
self.socket.close()
|
||||
|
||||
if self.router is not None and not self.router.closed:
|
||||
self.router.close()
|
||||
|
||||
if not self.context.closed:
|
||||
self.context.term()
|
||||
|
||||
self._clear_ipc(self.file_name)
|
||||
self._clear_ipc(self.router_path)
|
||||
except Exception as e:
|
||||
zmq_client_logger.warning(f"Failed to close ZMQ connection - {e}, {str(traceback.format_exc())}")
|
||||
return
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
def connect(self):
|
||||
self._ensure_socket()
|
||||
self.socket.connect(f"ipc://{self.file_name}")
|
||||
|
304
fastdeploy/inter_communicator/zmq_server.py
Normal file
304
fastdeploy/inter_communicator/zmq_server.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
import msgpack
|
||||
import zmq
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
|
||||
class ZmqServerBase(ABC):
|
||||
"""
|
||||
ZmqServerBase
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.cached_results = defaultdict(list)
|
||||
self.response_token_lock = threading.Lock()
|
||||
|
||||
@abstractmethod
|
||||
def _create_socket(self):
|
||||
"""Abstract method to create and return a ZeroMQ socket."""
|
||||
pass
|
||||
|
||||
def _ensure_socket(self):
|
||||
"""Ensure the socket is created before use."""
|
||||
if self.socket is None:
|
||||
self.socket = self._create_socket()
|
||||
|
||||
def pack_aggregated_data(self, data):
|
||||
"""
|
||||
Aggregate multiple responses into one and send them to the client.
|
||||
"""
|
||||
result = data[0]
|
||||
if len(data) > 1:
|
||||
for response in data[1:]:
|
||||
result.add(response)
|
||||
result = msgpack.packb([result.to_dict()])
|
||||
return result
|
||||
|
||||
def receive_json_once(self, block=False):
|
||||
"""
|
||||
Receive a single message from the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
if self.socket is None or self.socket.closed:
|
||||
return "zmp socket has closed", None
|
||||
try:
|
||||
flags = zmq.NOBLOCK if not block else 0
|
||||
return None, self.socket.recv_json(flags=flags)
|
||||
except zmq.Again:
|
||||
return None, None
|
||||
except Exception as e:
|
||||
self.close()
|
||||
llm_logger.warning(f"{e}")
|
||||
return str(e), None
|
||||
|
||||
def receive_pyobj_once(self, block=False):
|
||||
"""
|
||||
Receive a single message from the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
if self.socket is None or self.socket.closed:
|
||||
return "zmp socket has closed", None
|
||||
try:
|
||||
flags = zmq.NOBLOCK if not block else 0
|
||||
result = self.socket.recv_pyobj(flags=flags)
|
||||
llm_logger.info(f"receive one pyobj {result}")
|
||||
return None, result
|
||||
except zmq.Again:
|
||||
return None, None
|
||||
except Exception as e:
|
||||
self.close()
|
||||
llm_logger.warning(f"{e}")
|
||||
return str(e), None
|
||||
|
||||
def recv_result_handle(self):
|
||||
while True:
|
||||
try:
|
||||
with self.response_token_lock:
|
||||
client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK)
|
||||
req_id_str = request_id.decode("utf-8")
|
||||
with self.mutex:
|
||||
self.req_dict[req_id_str] = client
|
||||
except zmq.Again:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
except Exception as e:
|
||||
llm_logger.error(f"recv_result_handle get unknown exception: {e}")
|
||||
continue
|
||||
|
||||
def send_response(self, req_id, data):
|
||||
"""
|
||||
Send generated token result to client.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
if self.socket is None:
|
||||
raise RuntimeError("Router socket not created. Call create_router() first.")
|
||||
new_data = []
|
||||
has_result_handle = False
|
||||
with self.mutex:
|
||||
if req_id not in self.req_dict:
|
||||
self.cached_results[req_id].append(data)
|
||||
else:
|
||||
has_result_handle = True
|
||||
if req_id in self.cached_results:
|
||||
for history_data in self.cached_results[req_id]:
|
||||
new_data.extend(history_data)
|
||||
llm_logger.info(
|
||||
f"get request {req_id} result handle after cached result, total cached length {len(self.cached_results[req_id])}"
|
||||
)
|
||||
del self.cached_results[req_id]
|
||||
if has_result_handle:
|
||||
try:
|
||||
new_data.extend(data)
|
||||
start_send = time.time()
|
||||
if self.aggregate_send:
|
||||
result = self.pack_aggregated_data(new_data)
|
||||
else:
|
||||
result = msgpack.packb([response.to_dict() for response in new_data])
|
||||
with self.response_token_lock:
|
||||
self.socket.send_multipart([self.req_dict[req_id], b"", result])
|
||||
llm_logger.debug(
|
||||
f"send_multipart result: {req_id} len {len(new_data)} elapse: {time.time()-start_send}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Send result to zmq client failed: {e}")
|
||||
|
||||
if data[-1].finished:
|
||||
with self.mutex:
|
||||
if req_id not in self.req_dict:
|
||||
llm_logger.warning(f"req_id {req_id} finished but no result handle, drop it")
|
||||
if req_id in self.cached_results:
|
||||
del self.cached_results[req_id]
|
||||
else:
|
||||
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
|
||||
self.req_dict.pop(req_id, None)
|
||||
|
||||
@abstractmethod
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
|
||||
class ZmqIpcServer(ZmqServerBase):
|
||||
"""
|
||||
ZmqIpcServer, used when FD_ENABLE_INTERNAL_ADAPTER=0
|
||||
"""
|
||||
|
||||
def __init__(self, name, mode):
|
||||
self.name = name
|
||||
self.mode = mode
|
||||
self.cached_results = defaultdict(list)
|
||||
if mode == zmq.PULL:
|
||||
self.file_name = f"/dev/shm/{name}.socket"
|
||||
elif mode == zmq.ROUTER:
|
||||
self.file_name = f"/dev/shm/router_{name}.ipc"
|
||||
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
||||
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
||||
self.mutex = threading.Lock()
|
||||
self.response_token_lock = threading.Lock()
|
||||
self.req_dict = dict()
|
||||
self.running = True
|
||||
self.context = zmq.Context()
|
||||
self._create_socket()
|
||||
|
||||
def _create_socket(self):
|
||||
"""create and return a ZeroMQ socket."""
|
||||
self.socket = self.context.socket(self.mode)
|
||||
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
||||
self.socket.setsockopt(zmq.SNDTIMEO, -1)
|
||||
self.socket.bind(f"ipc://{self.file_name}")
|
||||
return self.socket
|
||||
|
||||
def _clear_ipc(self, name):
|
||||
"""
|
||||
Remove the IPC file with the given name.
|
||||
"""
|
||||
if os.path.exists(name):
|
||||
try:
|
||||
os.remove(name)
|
||||
except OSError as e:
|
||||
llm_logger.warning(f"Failed to remove IPC file {name} - {e}")
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the socket and context, and remove the IPC files.
|
||||
"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
llm_logger.info("Closing ZMQ connection...")
|
||||
try:
|
||||
if self.socket is not None and not self.socket.closed:
|
||||
self.socket.close()
|
||||
if not self.context.closed:
|
||||
self.context.term()
|
||||
self._clear_ipc(self.file_name)
|
||||
except Exception as e:
|
||||
llm_logger.warning(f"Failed to close ZMQ connection - {e}")
|
||||
return
|
||||
|
||||
|
||||
class ZmqTcpServer(ZmqServerBase):
|
||||
"""
|
||||
ZmqTcpServer, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||
"""
|
||||
|
||||
def __init__(self, port, mode):
|
||||
self.mode = mode
|
||||
self.port = port
|
||||
self.cached_results = defaultdict(list)
|
||||
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
||||
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
||||
|
||||
self.mutex = threading.Lock()
|
||||
self.req_dict = dict()
|
||||
self.running = True
|
||||
self.context = zmq.Context()
|
||||
self._create_socket()
|
||||
self.response_token_lock = threading.Lock()
|
||||
|
||||
def _create_socket(self):
|
||||
"""create and return a ZeroMQ socket."""
|
||||
self.socket = self.context.socket(self.mode)
|
||||
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
||||
self.socket.setsockopt(zmq.SNDTIMEO, -1)
|
||||
self.socket.bind(f"tcp://*:{self.port}")
|
||||
return self.socket
|
||||
|
||||
def recv_control_cmd(self):
|
||||
"""
|
||||
Recieve control command from client
|
||||
"""
|
||||
self._ensure_socket()
|
||||
try:
|
||||
client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK)
|
||||
task = msgpack.unpackb(task_data)
|
||||
task_id_str = task["task_id"]
|
||||
except zmq.Again:
|
||||
return None
|
||||
with self.mutex:
|
||||
self.req_dict[task_id_str] = client
|
||||
return task
|
||||
|
||||
def response_for_control_cmd(self, task_id, result):
|
||||
"""
|
||||
Send command result back to client.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
if self.socket is None:
|
||||
raise RuntimeError("Router socket not created.")
|
||||
try:
|
||||
result = msgpack.packb(result)
|
||||
self.socket.send_multipart([self.req_dict[task_id], b"", result])
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Send result to zmq client failed: {e}")
|
||||
|
||||
with self.mutex:
|
||||
self.req_dict.pop(task_id, None)
|
||||
llm_logger.debug(f"response control cmd finished, task_id: {task_id}")
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the socket and context.
|
||||
"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
llm_logger.info("Closing ZMQ connection...")
|
||||
try:
|
||||
if self.socket is not None and not self.socket.closed:
|
||||
self.socket.close()
|
||||
if not self.context.closed:
|
||||
self.context.term()
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.warning(f"Failed to close ZMQ connection - {e}")
|
||||
return
|
@@ -281,6 +281,9 @@ class LocalScheduler:
|
||||
|
||||
return requests
|
||||
|
||||
def get_unhandled_request_num(self):
|
||||
return len(self.ids) - self.ids_read_cursor
|
||||
|
||||
def put_results(self, results: List[RequestOutput]):
|
||||
"""
|
||||
Add processing results back to the scheduler.
|
||||
|
117
fastdeploy/splitwise/internal_adapter_utils.py
Normal file
117
fastdeploy/splitwise/internal_adapter_utils.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
|
||||
# **Note**: Just for internal use
|
||||
import zmq
|
||||
|
||||
from fastdeploy.inter_communicator import ZmqTcpServer
|
||||
from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics
|
||||
from fastdeploy.utils import envs, get_logger
|
||||
|
||||
logger = get_logger("internal_adapter_utils", "internal_adapter_utils.log")
|
||||
|
||||
|
||||
class InternalAdapter:
|
||||
def __init__(self, cfg, engine, dp_rank):
|
||||
self.cfg = cfg
|
||||
self.engine = engine
|
||||
self.dp_rank = dp_rank
|
||||
recv_control_cmd_ports = envs.FD_ZMQ_CONTROL_CMD_SERVER_PORTS.split(",")
|
||||
self.response_lock = threading.Lock() # prevent to call send_multipart in zmq concurrently
|
||||
self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[dp_rank], mode=zmq.ROUTER)
|
||||
self.recv_external_instruct_thread = threading.Thread(
|
||||
target=self._recv_external_module_control_instruct, daemon=True
|
||||
)
|
||||
self.recv_external_instruct_thread.start()
|
||||
# self.response_external_instruct_thread = threading.Thread(
|
||||
# target=self._response_external_module_control_instruct, daemon=True
|
||||
# )
|
||||
# self.response_external_instruct_thread.start()
|
||||
|
||||
def _get_current_server_info(self):
|
||||
"""
|
||||
Get resources information
|
||||
"""
|
||||
available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch())
|
||||
|
||||
available_block_num = self.engine.resource_manager.available_block_num()
|
||||
server_info = {
|
||||
"splitwise_role": self.cfg.splitwise_role,
|
||||
"block_size": int(self.cfg.cache_config.block_size),
|
||||
"block_num": int(available_block_num),
|
||||
"max_block_num": int(self.cfg.cache_config.total_block_num),
|
||||
"dec_token_num": int(self.cfg.cache_config.dec_token_num),
|
||||
"available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num),
|
||||
"max_batch_size": int(available_batch_size),
|
||||
"max_input_token_num": self.cfg.max_num_batched_tokens,
|
||||
"unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(),
|
||||
"available_batch": int(self.engine.resource_manager.available_batch()),
|
||||
}
|
||||
return server_info
|
||||
|
||||
def _recv_external_module_control_instruct(self):
|
||||
"""
|
||||
Receive a multipart message from the control cmd socket.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
with self.response_lock:
|
||||
task = self.recv_control_cmd_server.recv_control_cmd()
|
||||
if task is None:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
logger.info(f"Recieve control task: {task}")
|
||||
task_id_str = task["task_id"]
|
||||
if task["cmd"] == "get_payload":
|
||||
payload_info = self._get_current_server_info()
|
||||
result = {"task_id": task_id_str, "result": payload_info}
|
||||
logger.debug(f"Response for task: {task_id_str}")
|
||||
with self.response_lock:
|
||||
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
|
||||
|
||||
elif task["cmd"] == "get_metrics":
|
||||
metrics_text = get_filtered_metrics(
|
||||
[],
|
||||
extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1),
|
||||
)
|
||||
result = {"task_id": task_id_str, "result": metrics_text}
|
||||
logger.debug(f"Response for task: {task_id_str}")
|
||||
with self.response_lock:
|
||||
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
|
||||
elif task["cmd"] == "connect_rdma":
|
||||
self.engine.engine_worker_queue.put_connect_rdma_task(task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}")
|
||||
|
||||
def _response_external_module_control_instruct(self):
|
||||
while True:
|
||||
try:
|
||||
result_data = self.engine.engine_worker_queue.get_connect_rdma_task_response()
|
||||
if result_data:
|
||||
task_id_str = result_data["task_id"]
|
||||
result = {"task_id": task_id_str, "result": result_data}
|
||||
logger.info(f"Response for task: {task_id_str}")
|
||||
with self.response_lock:
|
||||
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}")
|
Reference in New Issue
Block a user