mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-27 21:02:24 +08:00
[Feature] Support mixed deployment with yiyan adapter in develop (#3976)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* [Feature] Support mixed deployment with yiyan adapter in release2.2 * fix metrics * add unit test * add unit test * add unit test * fix ci * fix for eb5 * fix ci * fix ci * fix ci --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
3
.github/workflows/_pre_ce_test.yml
vendored
3
.github/workflows/_pre_ce_test.yml
vendored
@@ -82,6 +82,9 @@ jobs:
|
|||||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||||
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
||||||
|
FD_ZMQ_RECV_REQUEST_SERVER_PORT=$((42048 + DEVICE_PORT * 100))
|
||||||
|
FD_ZMQ_SEND_RESPONSE_SERVER_PORT=$((42038 + DEVICE_PORT * 100))
|
||||||
|
FD_ZMQ_CONTROL_CMD_SERVER_PORTS=$((42028 + DEVICE_PORT * 100))
|
||||||
echo "Test ENV Parameter:"
|
echo "Test ENV Parameter:"
|
||||||
echo "========================================================="
|
echo "========================================================="
|
||||||
echo "FLASK_PORT=${FLASK_PORT}"
|
echo "FLASK_PORT=${FLASK_PORT}"
|
||||||
|
@@ -37,12 +37,14 @@ from fastdeploy.inter_communicator import (
|
|||||||
EngineCacheQueue,
|
EngineCacheQueue,
|
||||||
EngineWorkerQueue,
|
EngineWorkerQueue,
|
||||||
IPCSignal,
|
IPCSignal,
|
||||||
ZmqClient,
|
ZmqIpcServer,
|
||||||
|
ZmqTcpServer,
|
||||||
)
|
)
|
||||||
from fastdeploy.metrics.metrics import main_process_metrics
|
from fastdeploy.metrics.metrics import main_process_metrics
|
||||||
from fastdeploy.metrics.trace_util import start_span, start_span_request
|
from fastdeploy.metrics.trace_util import start_span, start_span_request
|
||||||
from fastdeploy.model_executor.guided_decoding import schema_checker
|
from fastdeploy.model_executor.guided_decoding import schema_checker
|
||||||
from fastdeploy.plugins.token_processor import load_token_processor_plugins
|
from fastdeploy.plugins.token_processor import load_token_processor_plugins
|
||||||
|
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
|
||||||
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
||||||
from fastdeploy.utils import EngineError, envs, llm_logger
|
from fastdeploy.utils import EngineError, envs, llm_logger
|
||||||
|
|
||||||
@@ -576,9 +578,19 @@ class EngineService:
|
|||||||
if api_server_pid is None:
|
if api_server_pid is None:
|
||||||
return
|
return
|
||||||
self.api_server_pid = api_server_pid
|
self.api_server_pid = api_server_pid
|
||||||
self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL)
|
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||||
self.zmq_server.start_server()
|
self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
|
||||||
self.zmq_server.create_router()
|
self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
|
||||||
|
self.internal_adapter = InternalAdapter(
|
||||||
|
cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
|
||||||
|
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
|
||||||
|
self.recv_result_handle_thread = threading.Thread(
|
||||||
|
target=self.send_response_server.recv_result_handle, daemon=True
|
||||||
|
)
|
||||||
|
self.recv_result_handle_thread.start()
|
||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True)
|
self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True)
|
||||||
self.insert_task_to_scheduler_thread.start()
|
self.insert_task_to_scheduler_thread.start()
|
||||||
@@ -592,9 +604,9 @@ class EngineService:
|
|||||||
try:
|
try:
|
||||||
block = True if len(added_requests) == 0 else False
|
block = True if len(added_requests) == 0 else False
|
||||||
if not self.cfg.model_config.enable_mm:
|
if not self.cfg.model_config.enable_mm:
|
||||||
err, data = self.zmq_server.receive_json_once(block)
|
err, data = self.recv_request_server.receive_json_once(block)
|
||||||
else:
|
else:
|
||||||
err, data = self.zmq_server.receive_pyobj_once(block)
|
err, data = self.recv_request_server.receive_pyobj_once(block)
|
||||||
if err is not None:
|
if err is not None:
|
||||||
llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
|
llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
|
||||||
break
|
break
|
||||||
@@ -648,7 +660,7 @@ class EngineService:
|
|||||||
)
|
)
|
||||||
# Since the request is not in scheduler
|
# Since the request is not in scheduler
|
||||||
# Send result by zmq directly
|
# Send result by zmq directly
|
||||||
self.zmq_server.send_multipart(request_id, [error_result])
|
self.send_response_server.send_response(request_id, [error_result])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
llm_logger.error(
|
llm_logger.error(
|
||||||
f"Error happened while receiving new request from zmq, details={e}, "
|
f"Error happened while receiving new request from zmq, details={e}, "
|
||||||
@@ -666,7 +678,7 @@ class EngineService:
|
|||||||
time.sleep(0.005)
|
time.sleep(0.005)
|
||||||
continue
|
continue
|
||||||
for request_id, contents in results.items():
|
for request_id, contents in results.items():
|
||||||
self.zmq_server.send_multipart(request_id, contents)
|
self.send_response_server.send_response(request_id, contents)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
llm_logger.error(f"Unexcepted error happened: {e}, {traceback.format_exc()!s}")
|
llm_logger.error(f"Unexcepted error happened: {e}, {traceback.format_exc()!s}")
|
||||||
@@ -766,5 +778,9 @@ class EngineService:
|
|||||||
self.worker_healthy_live_signal.clear()
|
self.worker_healthy_live_signal.clear()
|
||||||
self.exist_prefill_task_signal.clear()
|
self.exist_prefill_task_signal.clear()
|
||||||
self.model_weights_status_signal.clear()
|
self.model_weights_status_signal.clear()
|
||||||
if hasattr(self, "zmq_server") and self.zmq_server is not None:
|
if hasattr(self, "send_response_server") and self.send_response_server is not None:
|
||||||
self.zmq_server.close()
|
self.send_response_server.close()
|
||||||
|
if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
|
||||||
|
self.recv_request_server.close()
|
||||||
|
if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
|
||||||
|
self.recv_control_cmd_server.close()
|
||||||
|
@@ -27,7 +27,7 @@ from fastdeploy.config import ModelConfig
|
|||||||
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
|
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
|
||||||
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
|
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
|
||||||
from fastdeploy.input.preprocess import InputPreprocessor
|
from fastdeploy.input.preprocess import InputPreprocessor
|
||||||
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
|
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
|
||||||
from fastdeploy.metrics.work_metrics import work_process_metrics
|
from fastdeploy.metrics.work_metrics import work_process_metrics
|
||||||
from fastdeploy.multimodal.registry import MultimodalRegistry
|
from fastdeploy.multimodal.registry import MultimodalRegistry
|
||||||
from fastdeploy.platforms import current_platform
|
from fastdeploy.platforms import current_platform
|
||||||
@@ -115,7 +115,7 @@ class EngineClient:
|
|||||||
"""
|
"""
|
||||||
Create a ZMQ client.
|
Create a ZMQ client.
|
||||||
"""
|
"""
|
||||||
self.zmq_client = ZmqClient(model, mode)
|
self.zmq_client = ZmqIpcClient(model, mode)
|
||||||
self.zmq_client.connect()
|
self.zmq_client.connect()
|
||||||
|
|
||||||
async def format_and_add_data(self, prompts: dict):
|
async def format_and_add_data(self, prompts: dict):
|
||||||
|
@@ -98,6 +98,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# Whether to use new get_output and save_output method (0 or 1)
|
# Whether to use new get_output and save_output method (0 or 1)
|
||||||
"FD_USE_GET_SAVE_OUTPUT_V1": lambda: bool(int(os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0"))),
|
"FD_USE_GET_SAVE_OUTPUT_V1": lambda: bool(int(os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0"))),
|
||||||
# Whether to enable model cache feature
|
# Whether to enable model cache feature
|
||||||
|
"FD_ENABLE_MODEL_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_CACHE", "0"))),
|
||||||
|
# enable internal module to access LLMEngine.
|
||||||
|
"FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")),
|
||||||
|
# LLMEngine recieve requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||||
|
"FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"),
|
||||||
|
# LLMEngine send response port, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||||
|
"FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
|
||||||
|
# LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||||
|
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
|
||||||
"FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))),
|
"FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -17,6 +17,15 @@
|
|||||||
from .engine_cache_queue import EngineCacheQueue
|
from .engine_cache_queue import EngineCacheQueue
|
||||||
from .engine_worker_queue import EngineWorkerQueue
|
from .engine_worker_queue import EngineWorkerQueue
|
||||||
from .ipc_signal import IPCSignal, shared_memory_exists
|
from .ipc_signal import IPCSignal, shared_memory_exists
|
||||||
from .zmq_client import ZmqClient
|
from .zmq_client import ZmqIpcClient
|
||||||
|
from .zmq_server import ZmqIpcServer, ZmqTcpServer
|
||||||
|
|
||||||
__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "shared_memory_exists"]
|
__all__ = [
|
||||||
|
"ZmqIpcClient",
|
||||||
|
"IPCSignal",
|
||||||
|
"EngineWorkerQueue",
|
||||||
|
"EngineCacheQueue",
|
||||||
|
"ZmqTcpServer",
|
||||||
|
"ZmqIpcServer",
|
||||||
|
"shared_memory_exists",
|
||||||
|
]
|
||||||
|
@@ -14,209 +14,100 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
from abc import ABC, abstractmethod
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import msgpack
|
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from fastdeploy import envs
|
from fastdeploy.utils import llm_logger
|
||||||
from fastdeploy.utils import zmq_client_logger
|
|
||||||
|
|
||||||
|
|
||||||
class ZmqClient:
|
class ZmqClientBase(ABC):
|
||||||
"""
|
"""
|
||||||
ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ.
|
ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name, mode):
|
def __init__(self):
|
||||||
self.context = zmq.Context(4)
|
pass
|
||||||
self.socket = self.context.socket(mode)
|
|
||||||
self.file_name = f"/dev/shm/{name}.socket"
|
|
||||||
self.router_path = f"/dev/shm/router_{name}.ipc"
|
|
||||||
|
|
||||||
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
@abstractmethod
|
||||||
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
def _create_socket(self):
|
||||||
|
"""Abstract method to create and return a ZeroMQ socket."""
|
||||||
|
pass
|
||||||
|
|
||||||
self.mutex = threading.Lock()
|
def _ensure_socket(self):
|
||||||
self.req_dict = dict()
|
"""Ensure the socket is created before use."""
|
||||||
self.router = None
|
if self.socket is None:
|
||||||
self.poller = None
|
self.socket = self._create_socket()
|
||||||
self.running = True
|
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def connect(self):
|
def connect(self):
|
||||||
"""
|
"""
|
||||||
Connect to the server using the file name specified in the constructor.
|
Connect to the server using the file name specified in the constructor.
|
||||||
"""
|
"""
|
||||||
self.socket.connect(f"ipc://{self.file_name}")
|
pass
|
||||||
|
|
||||||
def start_server(self):
|
|
||||||
"""
|
|
||||||
Start the server using the file name specified in the constructor.
|
|
||||||
"""
|
|
||||||
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
|
||||||
self.socket.setsockopt(zmq.SNDTIMEO, -1)
|
|
||||||
self.socket.bind(f"ipc://{self.file_name}")
|
|
||||||
self.poller = zmq.Poller()
|
|
||||||
self.poller.register(self.socket, zmq.POLLIN)
|
|
||||||
|
|
||||||
def create_router(self):
|
|
||||||
"""
|
|
||||||
Create a ROUTER socket and bind it to the specified router path.
|
|
||||||
"""
|
|
||||||
self.router = self.context.socket(zmq.ROUTER)
|
|
||||||
self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
|
||||||
self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
|
|
||||||
self.router.setsockopt(zmq.SNDTIMEO, -1)
|
|
||||||
self.router.bind(f"ipc://{self.router_path}")
|
|
||||||
zmq_client_logger.info(f"router path: {self.router_path}")
|
|
||||||
|
|
||||||
def send_json(self, data):
|
def send_json(self, data):
|
||||||
"""
|
"""
|
||||||
Send a JSON-serializable object over the socket.
|
Send a JSON-serializable object over the socket.
|
||||||
"""
|
"""
|
||||||
|
self._ensure_socket()
|
||||||
self.socket.send_json(data)
|
self.socket.send_json(data)
|
||||||
|
|
||||||
def recv_json(self):
|
def recv_json(self):
|
||||||
"""
|
"""
|
||||||
Receive a JSON-serializable object from the socket.
|
Receive a JSON-serializable object from the socket.
|
||||||
"""
|
"""
|
||||||
|
self._ensure_socket()
|
||||||
return self.socket.recv_json()
|
return self.socket.recv_json()
|
||||||
|
|
||||||
def send_pyobj(self, data):
|
def send_pyobj(self, data):
|
||||||
"""
|
"""
|
||||||
Send a Pickle-serializable object over the socket.
|
Send a Pickle-serializable object over the socket.
|
||||||
"""
|
"""
|
||||||
|
self._ensure_socket()
|
||||||
self.socket.send_pyobj(data)
|
self.socket.send_pyobj(data)
|
||||||
|
|
||||||
def recv_pyobj(self):
|
def recv_pyobj(self):
|
||||||
"""
|
"""
|
||||||
Receive a Pickle-serializable object from the socket.
|
Receive a Pickle-serializable object from the socket.
|
||||||
"""
|
"""
|
||||||
|
self._ensure_socket()
|
||||||
return self.socket.recv_pyobj()
|
return self.socket.recv_pyobj()
|
||||||
|
|
||||||
def pack_aggregated_data(self, data):
|
@abstractmethod
|
||||||
"""
|
def close(self):
|
||||||
Aggregate multiple responses into one and send them to the client.
|
pass
|
||||||
"""
|
|
||||||
result = data[0]
|
|
||||||
if len(data) > 1:
|
|
||||||
for response in data[1:]:
|
|
||||||
result.add(response)
|
|
||||||
result = msgpack.packb([result.to_dict()])
|
|
||||||
return result
|
|
||||||
|
|
||||||
def send_multipart(self, req_id, data):
|
|
||||||
"""
|
|
||||||
Send a multipart message to the router socket.
|
|
||||||
"""
|
|
||||||
if self.router is None:
|
|
||||||
raise RuntimeError("Router socket not created. Call create_router() first.")
|
|
||||||
|
|
||||||
while self.running:
|
class ZmqIpcClient(ZmqClientBase):
|
||||||
with self.mutex:
|
def __init__(self, name, mode):
|
||||||
if req_id not in self.req_dict:
|
self.name = name
|
||||||
try:
|
self.mode = mode
|
||||||
client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK)
|
self.file_name = f"/dev/shm/{name}.socket"
|
||||||
req_id_str = request_id.decode("utf-8")
|
self.context = zmq.Context()
|
||||||
self.req_dict[req_id_str] = client
|
self.socket = self.context.socket(self.mode)
|
||||||
except zmq.Again:
|
|
||||||
time.sleep(0.001)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
if self.req_dict[req_id] == -1:
|
|
||||||
if data[-1].finished:
|
|
||||||
with self.mutex:
|
|
||||||
self.req_dict.pop(req_id, None)
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
start_send = time.time()
|
|
||||||
if self.aggregate_send:
|
|
||||||
result = self.pack_aggregated_data(data)
|
|
||||||
else:
|
|
||||||
result = msgpack.packb([response.to_dict() for response in data])
|
|
||||||
self.router.send_multipart([self.req_dict[req_id], b"", result])
|
|
||||||
zmq_client_logger.info(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
|
|
||||||
except zmq.ZMQError as e:
|
|
||||||
zmq_client_logger.error(f"[{req_id}] zmq error: {e}")
|
|
||||||
self.req_dict[req_id] = -1
|
|
||||||
except Exception as e:
|
|
||||||
zmq_client_logger.error(f"Send result to zmq client failed: {e}, {str(traceback.format_exc())}")
|
|
||||||
|
|
||||||
if data[-1].finished:
|
def _create_socket(self):
|
||||||
with self.mutex:
|
"""create and return a ZeroMQ socket."""
|
||||||
self.req_dict.pop(req_id, None)
|
self.context = zmq.Context()
|
||||||
zmq_client_logger.info(f"send_multipart finished, req_id: {req_id}")
|
return self.context.socket(self.mode)
|
||||||
|
|
||||||
def receive_json_once(self, block=False):
|
def connect(self):
|
||||||
"""
|
self._ensure_socket()
|
||||||
Receive a single message from the socket.
|
self.socket.connect(f"ipc://{self.file_name}")
|
||||||
"""
|
|
||||||
if self.socket is None or self.socket.closed:
|
|
||||||
return "zmp socket has closed", None
|
|
||||||
try:
|
|
||||||
flags = zmq.NOBLOCK if not block else 0
|
|
||||||
return None, self.socket.recv_json(flags=flags)
|
|
||||||
except zmq.Again:
|
|
||||||
return None, None
|
|
||||||
except Exception as e:
|
|
||||||
self.close()
|
|
||||||
zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}")
|
|
||||||
return str(e), None
|
|
||||||
|
|
||||||
def receive_pyobj_once(self, block=False):
|
|
||||||
"""
|
|
||||||
Receive a single message from the socket.
|
|
||||||
"""
|
|
||||||
if self.socket is None or self.socket.closed:
|
|
||||||
return "zmp socket has closed", None
|
|
||||||
try:
|
|
||||||
flags = zmq.NOBLOCK if not block else 0
|
|
||||||
return None, self.socket.recv_pyobj(flags=flags)
|
|
||||||
except zmq.Again:
|
|
||||||
return None, None
|
|
||||||
except Exception as e:
|
|
||||||
self.close()
|
|
||||||
zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}")
|
|
||||||
return str(e), None
|
|
||||||
|
|
||||||
def _clear_ipc(self, name):
|
|
||||||
"""
|
|
||||||
Remove the IPC file with the given name.
|
|
||||||
"""
|
|
||||||
if os.path.exists(name):
|
|
||||||
try:
|
|
||||||
os.remove(name)
|
|
||||||
except OSError as e:
|
|
||||||
zmq_client_logger.warning(f"Failed to remove IPC file {name} - {e}")
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""
|
"""
|
||||||
Close the socket and context, and remove the IPC files.
|
Close the socket and context.
|
||||||
"""
|
"""
|
||||||
if not self.running:
|
llm_logger.info("ZMQ client is closing connection...")
|
||||||
return
|
|
||||||
|
|
||||||
self.running = False
|
|
||||||
zmq_client_logger.info("Closing ZMQ connection...")
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self, "socket") and not self.socket.closed:
|
if self.socket is not None and not self.socket.closed:
|
||||||
|
self.socket.setsockopt(zmq.LINGER, 0)
|
||||||
self.socket.close()
|
self.socket.close()
|
||||||
|
if self.context is not None:
|
||||||
if self.router is not None and not self.router.closed:
|
|
||||||
self.router.close()
|
|
||||||
|
|
||||||
if not self.context.closed:
|
|
||||||
self.context.term()
|
self.context.term()
|
||||||
|
|
||||||
self._clear_ipc(self.file_name)
|
|
||||||
self._clear_ipc(self.router_path)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
zmq_client_logger.warning(f"Failed to close ZMQ connection - {e}, {str(traceback.format_exc())}")
|
llm_logger.warning(f"ZMQ client failed to close connection - {e}")
|
||||||
return
|
return
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
self.close()
|
|
||||||
|
335
fastdeploy/inter_communicator/zmq_server.py
Normal file
335
fastdeploy/inter_communicator/zmq_server.py
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from fastdeploy import envs
|
||||||
|
from fastdeploy.utils import llm_logger
|
||||||
|
|
||||||
|
|
||||||
|
class ZmqServerBase(ABC):
|
||||||
|
"""
|
||||||
|
ZmqServerBase
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.cached_results = defaultdict(list)
|
||||||
|
self.response_token_lock = threading.Lock()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _create_socket(self):
|
||||||
|
"""Abstract method to create and return a ZeroMQ socket."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _ensure_socket(self):
|
||||||
|
"""Ensure the socket is created before use."""
|
||||||
|
if self.socket is None:
|
||||||
|
self.socket = self._create_socket()
|
||||||
|
|
||||||
|
def send_json(self, data):
|
||||||
|
"""
|
||||||
|
Send a JSON-serializable object over the socket.
|
||||||
|
"""
|
||||||
|
self._ensure_socket()
|
||||||
|
self.socket.send_json(data)
|
||||||
|
|
||||||
|
def recv_json(self):
|
||||||
|
"""
|
||||||
|
Receive a JSON-serializable object from the socket.
|
||||||
|
"""
|
||||||
|
self._ensure_socket()
|
||||||
|
return self.socket.recv_json()
|
||||||
|
|
||||||
|
def send_pyobj(self, data):
|
||||||
|
"""
|
||||||
|
Send a Pickle-serializable object over the socket.
|
||||||
|
"""
|
||||||
|
self._ensure_socket()
|
||||||
|
self.socket.send_pyobj(data)
|
||||||
|
|
||||||
|
def recv_pyobj(self):
|
||||||
|
"""
|
||||||
|
Receive a Pickle-serializable object from the socket.
|
||||||
|
"""
|
||||||
|
self._ensure_socket()
|
||||||
|
return self.socket.recv_pyobj()
|
||||||
|
|
||||||
|
def pack_aggregated_data(self, data):
|
||||||
|
"""
|
||||||
|
Aggregate multiple responses into one and send them to the client.
|
||||||
|
"""
|
||||||
|
result = data[0]
|
||||||
|
if len(data) > 1:
|
||||||
|
for response in data[1:]:
|
||||||
|
result.add(response)
|
||||||
|
result = msgpack.packb([result.to_dict()])
|
||||||
|
return result
|
||||||
|
|
||||||
|
def receive_json_once(self, block=False):
|
||||||
|
"""
|
||||||
|
Receive a single message from the socket.
|
||||||
|
"""
|
||||||
|
self._ensure_socket()
|
||||||
|
if self.socket is None or self.socket.closed:
|
||||||
|
return "zmp socket has closed", None
|
||||||
|
try:
|
||||||
|
flags = zmq.NOBLOCK if not block else 0
|
||||||
|
return None, self.socket.recv_json(flags=flags)
|
||||||
|
except zmq.Again:
|
||||||
|
return None, None
|
||||||
|
except Exception as e:
|
||||||
|
self.close()
|
||||||
|
llm_logger.warning(f"{e}")
|
||||||
|
return str(e), None
|
||||||
|
|
||||||
|
def receive_pyobj_once(self, block=False):
|
||||||
|
"""
|
||||||
|
Receive a single message from the socket.
|
||||||
|
"""
|
||||||
|
self._ensure_socket()
|
||||||
|
if self.socket is None or self.socket.closed:
|
||||||
|
return "zmp socket has closed", None
|
||||||
|
try:
|
||||||
|
flags = zmq.NOBLOCK if not block else 0
|
||||||
|
return None, self.socket.recv_pyobj(flags=flags)
|
||||||
|
except zmq.Again:
|
||||||
|
return None, None
|
||||||
|
except Exception as e:
|
||||||
|
self.close()
|
||||||
|
llm_logger.warning(f"{e}")
|
||||||
|
return str(e), None
|
||||||
|
|
||||||
|
def recv_result_handle(self):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
with self.response_token_lock:
|
||||||
|
client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK)
|
||||||
|
req_id_str = request_id.decode("utf-8")
|
||||||
|
need_send_after_finished_inference = False
|
||||||
|
with self.mutex:
|
||||||
|
self.req_dict[req_id_str] = client
|
||||||
|
if req_id_str in self.cached_results:
|
||||||
|
if self.cached_results[req_id_str][-1][-1].finished:
|
||||||
|
need_send_after_finished_inference = True
|
||||||
|
if need_send_after_finished_inference:
|
||||||
|
self.send_response(req_id_str, [])
|
||||||
|
llm_logger.info(f"send_multipart finished, req_id: {req_id_str}")
|
||||||
|
self.req_dict.pop(req_id_str, None)
|
||||||
|
|
||||||
|
except zmq.Again:
|
||||||
|
time.sleep(0.001)
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
llm_logger.error(f"recv_result_handle get unknown exception: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
def send_response(self, req_id, data):
|
||||||
|
"""
|
||||||
|
Send generated token result to client.
|
||||||
|
"""
|
||||||
|
self._ensure_socket()
|
||||||
|
if self.socket is None:
|
||||||
|
raise RuntimeError("Router socket not created. Call create_router() first.")
|
||||||
|
new_data = []
|
||||||
|
has_result_handle = False
|
||||||
|
with self.mutex:
|
||||||
|
if req_id not in self.req_dict:
|
||||||
|
self.cached_results[req_id].append(data)
|
||||||
|
else:
|
||||||
|
has_result_handle = True
|
||||||
|
if req_id in self.cached_results:
|
||||||
|
for history_data in self.cached_results[req_id]:
|
||||||
|
new_data.extend(history_data)
|
||||||
|
llm_logger.info(
|
||||||
|
f"get request {req_id} result handle after cached result, total cached length {len(self.cached_results[req_id])}"
|
||||||
|
)
|
||||||
|
del self.cached_results[req_id]
|
||||||
|
if has_result_handle:
|
||||||
|
try:
|
||||||
|
new_data.extend(data)
|
||||||
|
start_send = time.time()
|
||||||
|
if self.aggregate_send:
|
||||||
|
result = self.pack_aggregated_data(new_data)
|
||||||
|
else:
|
||||||
|
result = msgpack.packb([response.to_dict() for response in new_data])
|
||||||
|
with self.response_token_lock:
|
||||||
|
self.socket.send_multipart([self.req_dict[req_id], b"", result])
|
||||||
|
llm_logger.debug(
|
||||||
|
f"send_multipart result: {req_id} len {len(new_data)} elapse: {time.time()-start_send}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
llm_logger.error(f"Send result to zmq client failed: {e}")
|
||||||
|
|
||||||
|
if data and data[-1].finished:
|
||||||
|
with self.mutex:
|
||||||
|
if req_id in self.req_dict:
|
||||||
|
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
|
||||||
|
self.req_dict.pop(req_id, None)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
|
class ZmqIpcServer(ZmqServerBase):
|
||||||
|
"""
|
||||||
|
ZmqIpcServer, used when FD_ENABLE_INTERNAL_ADAPTER=0
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name, mode):
|
||||||
|
self.name = name
|
||||||
|
self.mode = mode
|
||||||
|
self.cached_results = defaultdict(list)
|
||||||
|
if mode == zmq.PULL:
|
||||||
|
self.file_name = f"/dev/shm/{name}.socket"
|
||||||
|
elif mode == zmq.ROUTER:
|
||||||
|
self.file_name = f"/dev/shm/router_{name}.ipc"
|
||||||
|
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
||||||
|
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
||||||
|
self.mutex = threading.Lock()
|
||||||
|
self.response_token_lock = threading.Lock()
|
||||||
|
self.req_dict = dict()
|
||||||
|
self.running = True
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self._create_socket()
|
||||||
|
|
||||||
|
def _create_socket(self):
|
||||||
|
"""create and return a ZeroMQ socket."""
|
||||||
|
self.socket = self.context.socket(self.mode)
|
||||||
|
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
||||||
|
self.socket.setsockopt(zmq.SNDTIMEO, -1)
|
||||||
|
self.socket.bind(f"ipc://{self.file_name}")
|
||||||
|
return self.socket
|
||||||
|
|
||||||
|
def _clear_ipc(self, name):
|
||||||
|
"""
|
||||||
|
Remove the IPC file with the given name.
|
||||||
|
"""
|
||||||
|
if os.path.exists(name):
|
||||||
|
try:
|
||||||
|
os.remove(name)
|
||||||
|
except OSError as e:
|
||||||
|
llm_logger.warning(f"Failed to remove IPC file {name} - {e}")
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""
|
||||||
|
Close the socket and context, and remove the IPC files.
|
||||||
|
"""
|
||||||
|
if not self.running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.running = False
|
||||||
|
llm_logger.info("ZMQ server is closing connection...")
|
||||||
|
try:
|
||||||
|
if self.socket is not None and not self.socket.closed:
|
||||||
|
self.socket.close()
|
||||||
|
if not self.context.closed:
|
||||||
|
self.context.term()
|
||||||
|
self._clear_ipc(self.file_name)
|
||||||
|
except Exception as e:
|
||||||
|
llm_logger.warning(f"ZMQ server failed to close connection - {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class ZmqTcpServer(ZmqServerBase):
    """
    ZmqTcpServer, used when FD_ENABLE_INTERNAL_ADAPTER=1
    """

    def __init__(self, port, mode):
        # Socket configuration.
        self.mode = mode
        self.port = port
        self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
        self.aggregate_send = envs.FD_USE_AGGREGATE_SEND

        # Per-request bookkeeping: results awaiting delivery and the ROUTER
        # identity of the client that issued each task.
        self.cached_results = defaultdict(list)
        self.mutex = threading.Lock()
        self.req_dict = dict()
        self.response_token_lock = threading.Lock()

        self.running = True
        self.context = zmq.Context()
        self._create_socket()

    def _create_socket(self):
        """Create, configure and bind the ZeroMQ TCP socket; return it."""
        sock = self.context.socket(self.mode)
        sock.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
        sock.setsockopt(zmq.SNDTIMEO, -1)
        sock.bind(f"tcp://*:{self.port}")
        self.socket = sock
        return self.socket

    def recv_control_cmd(self):
        """
        Receive one control command from a client.

        Non-blocking: returns the decoded task dict, or None when no message
        is pending. The sender's ROUTER identity is remembered under the
        task's id so the reply can be routed back later.
        """
        self._ensure_socket()
        try:
            client, _, raw = self.socket.recv_multipart(flags=zmq.NOBLOCK)
            task = msgpack.unpackb(raw)
            task_id_str = task["task_id"]
        except zmq.Again:
            return None
        with self.mutex:
            self.req_dict[task_id_str] = client
        return task

    def response_for_control_cmd(self, task_id, result):
        """
        Send command result back to client.

        The routing identity recorded by recv_control_cmd is consumed here;
        send failures are logged, not raised.
        """
        self._ensure_socket()
        if self.socket is None:
            raise RuntimeError("Router socket not created.")
        try:
            payload = msgpack.packb(result)
            self.socket.send_multipart([self.req_dict[task_id], b"", payload])
        except Exception as e:
            llm_logger.error(f"Send result to zmq client failed: {e}")

        with self.mutex:
            self.req_dict.pop(task_id, None)
        llm_logger.debug(f"response control cmd finished, task_id: {task_id}")

    def close(self):
        """
        Close the socket and context. Idempotent.
        """
        if not self.running:
            return

        self.running = False
        llm_logger.info("ZMQ server is closing connection...")
        try:
            sock = self.socket
            if sock is not None and not sock.closed:
                sock.close()
            if not self.context.closed:
                self.context.term()
        except Exception as e:
            llm_logger.warning(f"ZMQ server failed to close connection - {e}")
        return
|
@@ -78,7 +78,8 @@ else:
|
|||||||
update_inputs_v1,
|
update_inputs_v1,
|
||||||
)
|
)
|
||||||
|
|
||||||
from fastdeploy.inter_communicator import ZmqClient
|
|
||||||
|
from fastdeploy.inter_communicator import ZmqIpcClient
|
||||||
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
|
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
|
||||||
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
|
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
|
||||||
|
|
||||||
@@ -160,7 +161,7 @@ def pre_process(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _zmq_send_text_outputs(zmq_client: ZmqClient, output_tokens: np.ndarray, save_each_rank: bool, mp_rank: int):
|
def _zmq_send_text_outputs(zmq_client: ZmqIpcClient, output_tokens: np.ndarray, save_each_rank: bool, mp_rank: int):
|
||||||
"""Split output_tokens and output"""
|
"""Split output_tokens and output"""
|
||||||
assert zmq_client is not None, "zmq_client should not be None"
|
assert zmq_client is not None, "zmq_client should not be None"
|
||||||
output_tokens = output_tokens.reshape([-1]).numpy()
|
output_tokens = output_tokens.reshape([-1]).numpy()
|
||||||
@@ -187,7 +188,7 @@ def post_process_normal(
|
|||||||
block_size: int = 64,
|
block_size: int = 64,
|
||||||
save_each_rank: bool = False,
|
save_each_rank: bool = False,
|
||||||
skip_save_output: bool = False,
|
skip_save_output: bool = False,
|
||||||
zmq_client: ZmqClient = None,
|
zmq_client: ZmqIpcClient = None,
|
||||||
) -> ModelRunnerOutput:
|
) -> ModelRunnerOutput:
|
||||||
"""Post-processing steps after completing a single token generation."""
|
"""Post-processing steps after completing a single token generation."""
|
||||||
# handle vl:
|
# handle vl:
|
||||||
@@ -389,7 +390,7 @@ def post_process(
|
|||||||
save_each_rank: bool = False,
|
save_each_rank: bool = False,
|
||||||
speculative_decoding: bool = False,
|
speculative_decoding: bool = False,
|
||||||
skip_save_output: bool = False,
|
skip_save_output: bool = False,
|
||||||
zmq_client: ZmqClient = None,
|
zmq_client: ZmqIpcClient = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Post-processing steps after completing a single token generation."""
|
"""Post-processing steps after completing a single token generation."""
|
||||||
if speculative_decoding:
|
if speculative_decoding:
|
||||||
|
@@ -29,7 +29,7 @@ import zmq
|
|||||||
|
|
||||||
from fastdeploy import envs
|
from fastdeploy import envs
|
||||||
from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
|
from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
|
||||||
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
|
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcServer
|
||||||
from fastdeploy.metrics.metrics import main_process_metrics
|
from fastdeploy.metrics.metrics import main_process_metrics
|
||||||
from fastdeploy.platforms import current_platform
|
from fastdeploy.platforms import current_platform
|
||||||
from fastdeploy.utils import llm_logger, spec_logger
|
from fastdeploy.utils import llm_logger, spec_logger
|
||||||
@@ -58,12 +58,11 @@ class TokenProcessor:
|
|||||||
self.split_connector = split_connector
|
self.split_connector = split_connector
|
||||||
|
|
||||||
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||||
|
|
||||||
llm_logger.debug(f"create zmq get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}")
|
llm_logger.debug(f"create zmq get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}")
|
||||||
self.zmq_server = ZmqClient(
|
self.zmq_server = ZmqIpcServer(
|
||||||
name=f"get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}", mode=zmq.PULL
|
name=f"get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}", mode=zmq.PULL
|
||||||
)
|
)
|
||||||
self.zmq_server.start_server()
|
|
||||||
self.zmq_server.create_router()
|
|
||||||
|
|
||||||
self.speculative_decoding = self.cfg.speculative_config.method is not None
|
self.speculative_decoding = self.cfg.speculative_config.method is not None
|
||||||
self.use_logprobs = self.cfg.model_config.enable_logprob
|
self.use_logprobs = self.cfg.model_config.enable_logprob
|
||||||
@@ -498,6 +497,7 @@ class TokenProcessor:
|
|||||||
metrics = RequestMetrics(
|
metrics = RequestMetrics(
|
||||||
arrival_time=task.arrival_time,
|
arrival_time=task.arrival_time,
|
||||||
inference_start_time=task.inference_start_time,
|
inference_start_time=task.inference_start_time,
|
||||||
|
model_execute_time=time.time() - task.inference_start_time,
|
||||||
first_token_time=time.time() - task.inference_start_time,
|
first_token_time=time.time() - task.inference_start_time,
|
||||||
time_in_queue=task.schedule_start_time - task.preprocess_end_time,
|
time_in_queue=task.schedule_start_time - task.preprocess_end_time,
|
||||||
preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
|
preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
|
||||||
@@ -510,6 +510,7 @@ class TokenProcessor:
|
|||||||
metrics = RequestMetrics(
|
metrics = RequestMetrics(
|
||||||
arrival_time=time.time(),
|
arrival_time=time.time(),
|
||||||
request_start_time=task.arrival_time,
|
request_start_time=task.arrival_time,
|
||||||
|
model_execute_time=time.time() - task.inference_start_time,
|
||||||
)
|
)
|
||||||
self.number_of_output_tokens += len(token_ids)
|
self.number_of_output_tokens += len(token_ids)
|
||||||
self._record_metrics(task, current_time, token_ids)
|
self._record_metrics(task, current_time, token_ids)
|
||||||
|
@@ -208,6 +208,9 @@ class LocalScheduler:
|
|||||||
"""
|
"""
|
||||||
return (token_num + block_size - 1) // block_size
|
return (token_num + block_size - 1) // block_size
|
||||||
|
|
||||||
|
def get_unhandled_request_num(self):
|
||||||
|
return len(self.ids) - self.ids_read_cursor
|
||||||
|
|
||||||
def get_requests(
|
def get_requests(
|
||||||
self,
|
self,
|
||||||
available_blocks,
|
available_blocks,
|
||||||
|
118
fastdeploy/splitwise/internal_adapter_utils.py
Normal file
118
fastdeploy/splitwise/internal_adapter_utils.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
# **Note**: Just for internal use
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from fastdeploy.inter_communicator import ZmqTcpServer
|
||||||
|
from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics
|
||||||
|
from fastdeploy.utils import envs, get_logger
|
||||||
|
|
||||||
|
logger = get_logger("internal_adapter_utils", "internal_adapter_utils.log")
|
||||||
|
|
||||||
|
|
||||||
|
class InternalAdapter:
    """Bridge between an external control module and the engine over ZMQ ROUTER sockets."""

    def __init__(self, cfg, engine, dp_rank):
        self.cfg = cfg
        self.engine = engine
        self.dp_rank = dp_rank
        # One control-cmd port per data-parallel rank, comma-separated in the env.
        control_ports = envs.FD_ZMQ_CONTROL_CMD_SERVER_PORTS.split(",")
        self.response_lock = threading.Lock()  # prevent to call send_multipart in zmq concurrently
        self.recv_control_cmd_server = ZmqTcpServer(port=control_ports[dp_rank], mode=zmq.ROUTER)
        self.recv_external_instruct_thread = threading.Thread(
            target=self._recv_external_module_control_instruct, daemon=True
        )
        self.recv_external_instruct_thread.start()
        if cfg.splitwise_role != "mixed":
            # Non-mixed (disaggregated) deployments also forward RDMA-connect
            # results back to the external module asynchronously.
            self.response_external_instruct_thread = threading.Thread(
                target=self._response_external_module_control_instruct, daemon=True
            )
            self.response_external_instruct_thread.start()

    def _get_current_server_info(self):
        """
        Get resources information
        """
        available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch())
        available_block_num = self.engine.resource_manager.available_block_num()
        return {
            "splitwise_role": self.cfg.splitwise_role,
            "block_size": int(self.cfg.cache_config.block_size),
            "block_num": int(available_block_num),
            "max_block_num": int(self.cfg.cache_config.total_block_num),
            "dec_token_num": int(self.cfg.cache_config.dec_token_num),
            "available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num),
            "max_batch_size": int(available_batch_size),
            "max_input_token_num": self.cfg.max_model_len,
            "unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(),
            "available_batch": int(self.engine.resource_manager.available_batch()),
        }

    def _recv_external_module_control_instruct(self):
        """
        Receive a multipart message from the control cmd socket.

        Runs forever in a daemon thread; the response_lock serializes all
        recv/send calls on the shared ROUTER socket.
        """
        while True:
            try:
                with self.response_lock:
                    task = self.recv_control_cmd_server.recv_control_cmd()
                if task is None:
                    time.sleep(0.001)
                    continue
                logger.info(f"Recieve control task: {task}")
                task_id_str = task["task_id"]
                cmd = task["cmd"]
                if cmd == "get_payload":
                    result = {"task_id": task_id_str, "result": self._get_current_server_info()}
                    logger.debug(f"Response for task: {task_id_str}")
                    with self.response_lock:
                        self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
                elif cmd == "get_metrics":
                    metrics_text = get_filtered_metrics(
                        [],
                        extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1),
                    )
                    result = {"task_id": task_id_str, "result": metrics_text}
                    logger.debug(f"Response for task: {task_id_str}")
                    with self.response_lock:
                        self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
                elif cmd == "connect_rdma":
                    # Handled asynchronously by the worker queue; the reply is
                    # sent from _response_external_module_control_instruct.
                    self.engine.engine_worker_queue.put_connect_rdma_task(task)
            except Exception as e:
                logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}")

    def _response_external_module_control_instruct(self):
        """Forward connect_rdma results from the worker queue back to the client."""
        while True:
            try:
                result_data = self.engine.engine_worker_queue.get_connect_rdma_task_response()
                if not result_data:
                    time.sleep(0.001)
                    continue
                task_id_str = result_data["task_id"]
                result = {"task_id": task_id_str, "result": result_data}
                logger.info(f"Response for task: {task_id_str}")
                with self.response_lock:
                    self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
            except Exception as e:
                logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}")
|
@@ -75,7 +75,7 @@ import zmq
|
|||||||
|
|
||||||
from fastdeploy import envs
|
from fastdeploy import envs
|
||||||
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
|
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
|
||||||
from fastdeploy.inter_communicator import ZmqClient
|
from fastdeploy.inter_communicator import ZmqIpcClient
|
||||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||||
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
|
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
|
||||||
from fastdeploy.worker.model_runner_base import ModelRunnerBase
|
from fastdeploy.worker.model_runner_base import ModelRunnerBase
|
||||||
@@ -171,7 +171,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
self.zmq_client = None
|
self.zmq_client = None
|
||||||
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||||
logger.info(f"zmq client get_save_output_rank{local_rank}")
|
logger.info(f"zmq client get_save_output_rank{local_rank}")
|
||||||
self.zmq_client = ZmqClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
|
self.zmq_client = ZmqIpcClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
|
||||||
self.zmq_client.connect()
|
self.zmq_client.connect()
|
||||||
self.zmq_client.socket.SNDTIMEO = 3000
|
self.zmq_client.socket.SNDTIMEO = 3000
|
||||||
|
|
||||||
|
251
tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py
Normal file
251
tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
|
import shutil
|
||||||
|
import signal
|
||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, "..", ".."))
print("project_root", project_root)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from ci_use.EB_Lite_with_adapter.zmq_client import LLMControlClient, LLMReqClient

env = os.environ.copy()

# Ports come from environment variables when set; otherwise use the defaults.
FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8234))

FD_ENABLE_INTERNAL_ADAPTER = int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "1"))
FD_ZMQ_RECV_REQUEST_SERVER_PORT = int(os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8204"))
FD_ZMQ_SEND_RESPONSE_SERVER_PORT = int(os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8205"))
FD_ZMQ_CONTROL_CMD_SERVER_PORTS = int(os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8206"))
FD_ZMQ_CONTROL_CMD_SERVER_PORT = FD_ZMQ_CONTROL_CMD_SERVER_PORTS

# Propagate the adapter settings into the server subprocess environment.
for _key, _value in {
    "FD_ENABLE_INTERNAL_ADAPTER": FD_ENABLE_INTERNAL_ADAPTER,
    "FD_ZMQ_RECV_REQUEST_SERVER_PORT": FD_ZMQ_RECV_REQUEST_SERVER_PORT,
    "FD_ZMQ_SEND_RESPONSE_SERVER_PORT": FD_ZMQ_SEND_RESPONSE_SERVER_PORT,
    "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": FD_ZMQ_CONTROL_CMD_SERVER_PORTS,
    "FD_ZMQ_CONTROL_CMD_SERVER_PORT": FD_ZMQ_CONTROL_CMD_SERVER_PORT,
}.items():
    env[_key] = str(_value)

# Ports killed before and after the test session.
PORTS_TO_CLEAN = [
    FD_API_PORT,
    FD_ENGINE_QUEUE_PORT,
    FD_METRICS_PORT,
    FD_CACHE_QUEUE_PORT,
    FD_ZMQ_RECV_REQUEST_SERVER_PORT,
    FD_ZMQ_SEND_RESPONSE_SERVER_PORT,
    FD_ZMQ_CONTROL_CMD_SERVER_PORT,
]
|
@pytest.fixture
def zmq_req_client():
    """Client that sends requests to and receives responses from the adapter ports."""
    return LLMReqClient("0.0.0.0", FD_ZMQ_RECV_REQUEST_SERVER_PORT, FD_ZMQ_SEND_RESPONSE_SERVER_PORT)
|
@pytest.fixture
def zmq_control_client():
    """Client for the adapter's control-command port."""
    return LLMControlClient("0.0.0.0", FD_ZMQ_CONTROL_CMD_SERVER_PORT)
|
def is_port_open(host: str, port: int, timeout=1.0):
    """
    Check if a TCP port is open on the given host.
    Returns True if connection succeeds, False otherwise.
    """
    try:
        conn = socket.create_connection((host, port), timeout)
    except Exception:
        return False
    conn.close()
    return True
|
def kill_process_on_port(port: int):
    """
    Kill processes that are listening on the given port.
    Uses `lsof` to find process ids and sends SIGKILL.
    """
    try:
        pids = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip()
        for pid in pids.splitlines():
            os.kill(int(pid), signal.SIGKILL)
            print(f"Killed process on port {port}, pid={pid}")
    except subprocess.CalledProcessError:
        pass  # lsof exits non-zero when nothing listens on the port

    # NOTE(review): this second pass always greps FD_CACHE_QUEUE_PORT (the cache
    # manager) regardless of the `port` argument — confirm this is intentional.
    try:
        result = subprocess.run(
            f"ps -ef -ww| grep {FD_CACHE_QUEUE_PORT} | grep -v grep", shell=True, capture_output=True, text=True
        )
        for line in result.stdout.strip().split("\n"):
            if not line:
                continue
            pid = int(line.split()[1])  # the second column of `ps -ef` is the PID
            print(f"Killing PID: {pid}")
            os.kill(pid, signal.SIGKILL)
    except Exception as e:
        print(f"Failed to kill cache manager process: {e}")
|
def clean_ports():
    """
    Kill all processes occupying the ports listed in PORTS_TO_CLEAN,
    then give the OS a moment to release them.
    """
    for p in PORTS_TO_CLEAN:
        kill_process_on_port(p)
    time.sleep(2)
|
@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server():
    """
    Pytest fixture that runs once per test session:
    - Cleans ports before tests
    - Starts the API server as a subprocess (new session, so the whole
      process group can be killed at teardown)
    - Waits up to 300 seconds for the server port to open
    - Tears down the server after all tests finish
    """
    print("Pre-test port cleanup...")
    clean_ports()

    base_path = os.getenv("MODEL_PATH")
    if base_path:
        model_path = os.path.join(base_path, "ernie-4_5-21b-a3b-bf16-paddle")
    else:
        model_path = "./ernie-4_5-21b-a3b-bf16-paddle"

    log_path = "server.log"
    cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(FD_API_PORT),
        "--tensor-parallel-size",
        "1",
        "--engine-worker-queue-port",
        str(FD_ENGINE_QUEUE_PORT),
        "--metrics-port",
        str(FD_METRICS_PORT),
        "--cache-queue-port",
        str(FD_CACHE_QUEUE_PORT),
        "--max-model-len",
        "32768",
        "--max-num-seqs",
        "128",
        "--quantization",
        "wint4",
    ]

    # Remove stale logs so each session starts clean.
    if os.path.exists("log"):
        shutil.rmtree("log")
    with open(log_path, "w") as logfile:
        process = subprocess.Popen(
            cmd,
            env=env,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # enables killing the full group via os.killpg
        )

    # Poll until the API server accepts connections, up to 300 seconds.
    for _ in range(300):
        if is_port_open("127.0.0.1", FD_API_PORT):
            print(f"API server is up on port {FD_API_PORT}")
            break
        time.sleep(1)
    else:
        print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...")
        try:
            os.killpg(process.pid, signal.SIGTERM)
        except Exception as e:
            print(f"Failed to kill process group: {e}")
        raise RuntimeError(f"API server did not start on port {FD_API_PORT}")

    yield  # run the tests

    print("\n===== Post-test server cleanup... =====")
    try:
        os.killpg(process.pid, signal.SIGTERM)
        clean_ports()
        print(f"API server (pid={process.pid}) terminated")
    except Exception as e:
        print(f"Failed to terminate API server: {e}")
|
def test_request_and_response(zmq_req_client):
    """Push one request through the adapter and stream results until a finished chunk arrives."""
    prompt_token_ids = [5300, 93956, 55791]
    req_id = "test"
    request = {
        "req_id": req_id,
        "request_id": req_id,
        "min_tokens": 1,
        "dp_rank": 0,  # DP rank of the prefill instance, read from the current env vars
        "prompt_token_ids": prompt_token_ids,
        "prompt_token_ids_len": len(prompt_token_ids),
        "eos_token_ids": [2],
        "stop_token_ids": [2],
        "max_dec_len": 32 * 1024,
        "max_tokens": 32 * 1024,
        "min_dec_len": 1,
        "arrival_time": time.time(),
        "preprocess_start_time": time.time(),
        "preprocess_end_time": time.time(),
        "messages": [],
        "temperature": 0.8,
        "penalty_score": 1.0,
        "repetition_penalty": 1.0,
        "presence_penalty": 0,
        "top_p": 0.8,
        "frequency_penalty": 0.0,
    }
    result_queue = queue.Queue()
    zmq_req_client.start(result_queue)
    zmq_req_client.send_request(request)
    zmq_req_client.request_result(req_id)
    has_is_end_result = False
    while True:
        chunk = result_queue.get()
        if chunk[-1]["finished"]:
            has_is_end_result = True
            break
    assert has_is_end_result is True
|
def test_control_cmd(zmq_control_client):
    """get_payload and get_metrics round-trip through the control socket."""
    payload = zmq_control_client.get_payload()
    assert "unhandled_request_num" in payload
    metrics = zmq_control_client.get_metrics()
    assert metrics is not None
121
tests/ci_use/EB_Lite_with_adapter/zmq_client.py
Normal file
121
tests/ci_use/EB_Lite_with_adapter/zmq_client.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from threading import Event
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
|
||||||
|
class LLMReqClient:
    """
    LLM request client.

    Pushes request dicts (as JSON) to the server's request port and receives
    msgpack-encoded streamed responses on a DEALER socket from a background
    consumer thread.
    """

    def __init__(self, ip, send_req_server_port, recv_res_server_port):
        self.ZMQ_SNDHWM = 64 * 1024
        self.context = zmq.Context()
        self.send_req_client = self.context.socket(zmq.PUSH)
        self.recv_res_client = self.context.socket(zmq.DEALER)
        for sock in (self.send_req_client, self.recv_res_client):
            sock.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
            sock.setsockopt(zmq.SNDTIMEO, -1)
        self.send_req_client.connect(f"tcp://{ip}:{send_req_server_port}")
        self.recv_res_client.connect(f"tcp://{ip}:{recv_res_server_port}")
        self.need_exit = False
        self.response_socket_lock = threading.Lock()

    def send_request(self, req_data):
        """Send one request dict to the server as JSON."""
        self.send_req_client.send_json(req_data)

    def request_result(self, req_id):
        """Ask the server to start streaming results for `req_id`."""
        with self.response_socket_lock:
            print(f"request result data for {req_id}")
            self.recv_res_client.send_multipart([b"", req_id.encode("utf-8")])

    def consume_results(self, result_queue):
        """
        Poll the response socket and put each decoded response onto
        `result_queue` until exit() is called.
        """
        # Fix: `frames` was previously unbound when recv_multipart raised a
        # non-Again error on the first iteration, turning the error log into
        # a NameError that masked the real exception.
        frames = None
        while True:
            try:
                try:
                    with self.response_socket_lock:
                        frames = self.recv_res_client.recv_multipart(flags=zmq.NOBLOCK)
                except zmq.Again:
                    time.sleep(0.001)
                    continue
                response = msgpack.unpackb(frames[-1])
                print(f"get result data {response}")
                result_queue.put(response)
                if self.need_exit:
                    break
            except Exception as e:
                print(f"zmq client occured error {e} type: {type(e)} frames: {frames}")

    def start(self, result_queue):
        """Start the daemon consumer thread feeding `result_queue`."""
        threading.Thread(target=self.consume_results, args=(result_queue,), daemon=True).start()

    def exit(self):
        """Ask the consumer thread to stop after delivering its next message."""
        print("exit")
        self.need_exit = True
|
class LLMControlClient:
    """
    LLM control client.

    Sends control commands (get_payload, get_metrics) over a DEALER socket
    and blocks on a per-task Event until the matching reply arrives on the
    background receiver thread.
    """

    def __init__(self, ip, port):
        self.ZMQ_SNDHWM = 64 * 1024
        self.context = zmq.Context()
        self.control_client = self.context.socket(zmq.DEALER)
        self.control_client.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
        self.control_client.setsockopt(zmq.SNDTIMEO, -1)
        self.control_client.connect(f"tcp://{ip}:{port}")
        self.task_event = {}  # task_id -> Event set when the reply arrives
        self.result = {}  # task_id -> reply payload
        self.response_socket_lock = threading.Lock()
        threading.Thread(target=self.recv_results, daemon=True).start()

    def _execute(self, cmd):
        """
        Send control command `cmd` and block until its response arrives.

        Shared implementation for get_payload/get_metrics (previously two
        duplicated bodies). task_id keeps the original `<cmd>_<uuid>` form.
        """
        task_id = f"{cmd}_{uuid.uuid4()}"
        task = {"task_id": task_id, "cmd": cmd}
        self.task_event[task_id] = Event()
        payload = msgpack.packb(task)
        with self.response_socket_lock:
            self.control_client.send_multipart([b"", payload])
        self.task_event[task_id].wait()
        result = self.result[task_id]
        del self.result[task_id]
        del self.task_event[task_id]
        return result

    def get_payload(self):
        """Fetch the server's resource/payload info."""
        return self._execute("get_payload")

    def get_metrics(self):
        """Fetch the server's Prometheus metrics text."""
        return self._execute("get_metrics")

    def recv_results(self):
        """Receive replies and wake the waiting caller for each task_id."""
        # Fix: `frames` was previously unbound when recv_multipart raised a
        # non-Again error on the first iteration, turning the error log into
        # a NameError that masked the real exception.
        frames = None
        while True:
            try:
                try:
                    with self.response_socket_lock:
                        frames = self.control_client.recv_multipart(flags=zmq.NOBLOCK)
                except zmq.Again:
                    time.sleep(0.001)
                    continue
                result = msgpack.unpackb(frames[-1])
                task_id = result["task_id"]
                self.result[task_id] = result["result"]
                self.task_event[task_id].set()
            except Exception as e:
                print(f"zmq client occured error {e} type: {type(e)} frames: {frames}")
|
Reference in New Issue
Block a user