diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index cf1ebdd29..2fa5d8d0e 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+import os
 import time
 import traceback
 import uuid
@@ -22,6 +23,7 @@ import numpy as np
 
 from fastdeploy import envs
 from fastdeploy.engine.config import ModelConfig
+from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
 from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
 from fastdeploy.input.preprocess import InputPreprocessor
 from fastdeploy.inter_communicator import IPCSignal, ZmqClient
@@ -91,6 +93,10 @@ class EngineClient:
             suffix=pid,
             create=False,
         )
+        self.connection_manager = DealerConnectionManager(
+            pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
+        )
+        self.connection_initialized = False
 
     def create_zmq_client(self, model, mode):
         """
diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index ca0b45e7f..a0962856d 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -154,6 +154,7 @@ async def lifespan(app: FastAPI):
     yield
     # close zmq
     try:
+        await engine_client.connection_manager.close()
        engine_client.zmq_client.close()
         from prometheus_client import multiprocess
 
diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index 5f9a99958..6ab11717f 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -20,10 +20,7 @@ import traceback
 import uuid
 from typing import List, Optional
 
-import aiozmq
-import msgpack
 import numpy as np
-from aiozmq import zmq
 
 from fastdeploy.entrypoints.openai.protocol import (
     ChatCompletionRequest,
@@ -62,6 +59,12 @@ class OpenAIServingChat:
         else:
             self.master_ip = self.master_ip.split(",")[0]
 
+    async def _ensure_connection_manager(self):
+        """ensure connection manager initialized"""
+        if not self.engine_client.connection_initialized:
+            await self.engine_client.connection_manager.initialize()
+            self.engine_client.connection_initialized = True
+
     def _check_master(self):
         if self.master_ip is None:
             return True
@@ -180,14 +183,16 @@ class OpenAIServingChat:
             choices=[],
             model=model_name,
         )
+
         try:
-            dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
+            await self._ensure_connection_manager()
+            dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
             dealer.write([b"", request_id.encode("utf-8")])
             choices = []
             current_waiting_time = 0
             while num_choices > 0:
                 try:
-                    raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
+                    response = await asyncio.wait_for(response_queue.get(), timeout=10)
                     current_waiting_time = 0
                 except asyncio.TimeoutError:
                     current_waiting_time += 10
@@ -202,7 +207,6 @@ class OpenAIServingChat:
                             current_waiting_time = 0
                     await asyncio.sleep(0.01)
                     continue
-                response = msgpack.unpackb(raw_data[-1])
                 for res in response:
                     if res.get("error_code", 200) != 200:
                         raise ValueError("{}".format(res["error_msg"]))
@@ -353,9 +357,9 @@ class OpenAIServingChat:
             )
             yield f"data: {error_data}\n\n"
         finally:
-            dealer.close()
+            await self.engine_client.connection_manager.cleanup_request(request_id)
             self.engine_client.semaphore.release()
-            api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
+            api_server_logger.info(f"release {request_id} {self.engine_client.semaphore.status()}")
         yield "data: [DONE]\n\n"
 
     async def chat_completion_full_generator(
@@ -378,7 +382,8 @@ class OpenAIServingChat:
         include_stop_str_in_output = request.include_stop_str_in_output
 
         try:
-            dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
+            await self._ensure_connection_manager()
+            dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
             dealer.write([b"", request_id.encode("utf-8")])
             final_res = None
             previous_num_tokens = 0
@@ -387,7 +392,7 @@ class OpenAIServingChat:
             completion_token_ids = []
             while True:
                 try:
-                    raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
+                    response = await asyncio.wait_for(response_queue.get(), timeout=10)
                     current_waiting_time = 0
                 except asyncio.TimeoutError:
                     current_waiting_time += 10
@@ -400,7 +405,6 @@ class OpenAIServingChat:
                     await asyncio.sleep(0.1)
                     continue
 
-                response = msgpack.unpackb(raw_data[-1])
                 task_is_finished = False
                 for data in response:
                     if data.get("error_code", 200) != 200:
@@ -430,7 +434,7 @@ class OpenAIServingChat:
                 if task_is_finished:
                     break
         finally:
-            dealer.close()
+            await self.engine_client.connection_manager.cleanup_request(request_id)
             self.engine_client.semaphore.release()
             api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
 
diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index c6ee86d2f..ec2f18076 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -20,10 +20,7 @@ import traceback
 import uuid
 from typing import List, Optional
 
-import aiozmq
-import msgpack
 import numpy as np
-from aiozmq import zmq
 
 from fastdeploy.engine.request import RequestOutput
 from fastdeploy.entrypoints.openai.protocol import (
@@ -53,6 +50,12 @@ class OpenAIServingCompletion:
         else:
             self.master_ip = self.master_ip.split(",")[0]
 
+    async def _ensure_connection_manager(self):
+        """ensure connection manager initialized"""
+        if not self.engine_client.connection_initialized:
+            await self.engine_client.connection_manager.initialize()
+            self.engine_client.connection_initialized = True
+
     def _check_master(self):
         if self.master_ip is None:
             return True
@@ -185,7 +188,10 @@ class OpenAIServingCompletion:
         try:
             request_ids = [f"{request_id}-{i}" for i in range(num_choices)]
             # create dealer
-            dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
+            await self._ensure_connection_manager()
+            dealer, response_queue = await self.engine_client.connection_manager.get_connection(
+                request_id, num_choices
+            )
             for rid in request_ids:
                 dealer.write([b"", rid.encode("utf-8")])
 
@@ -198,7 +204,7 @@ class OpenAIServingCompletion:
             current_waiting_time = 0
             while num_choices > 0:
                 try:
-                    raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
+                    response = await asyncio.wait_for(response_queue.get(), timeout=10)
                     current_waiting_time = 0
                 except asyncio.TimeoutError:
                     current_waiting_time += 10
@@ -210,7 +216,7 @@ class OpenAIServingCompletion:
                             current_waiting_time = 0
                     await asyncio.sleep(0.1)
                     continue
-                response = msgpack.unpackb(raw_data[-1])
+
                 for data in response:
                     rid = int(data["request_id"].split("-")[-1])
                     if data.get("error_code", 200) != 200:
@@ -255,7 +261,7 @@ class OpenAIServingCompletion:
         finally:
             self.engine_client.semaphore.release()
             if dealer is not None:
-                dealer.close()
+                await self.engine_client.connection_manager.cleanup_request(request_id)
 
     async def _echo_back_prompt(self, request, res, idx):
         if res["outputs"].get("send_idx", -1) == 0 and request.echo:
@@ -288,7 +294,10 @@ class OpenAIServingCompletion:
         Process the stream completion request.
         """
         try:
-            dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
+            await self._ensure_connection_manager()
+            dealer, response_queue = await self.engine_client.connection_manager.get_connection(
+                request_id, num_choices
+            )
 
             for i in range(num_choices):
                 req_id = f"{request_id}-{i}"
@@ -312,7 +321,7 @@ class OpenAIServingCompletion:
             current_waiting_time = 0
             while num_choices > 0:
                 try:
-                    raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
+                    response = await asyncio.wait_for(response_queue.get(), timeout=10)
                     current_waiting_time = 0
                 except asyncio.TimeoutError:
                     current_waiting_time += 10
@@ -325,7 +334,6 @@ class OpenAIServingCompletion:
                     await asyncio.sleep(0.1)
                     continue
 
-                response = msgpack.unpackb(raw_data[-1])
                 for res in response:
                     idx = int(res["request_id"].split("-")[-1])
                     if res.get("error_code", 200) != 200:
@@ -453,9 +461,9 @@ class OpenAIServingCompletion:
             yield f"data: {ErrorResponse(message=str(e), code=400).model_dump_json(exclude_unset=True)}\n\n"
         finally:
             del request
-            self.engine_client.semaphore.release()
             if dealer is not None:
-                dealer.close()
+                await self.engine_client.connection_manager.cleanup_request(request_id)
+            self.engine_client.semaphore.release()
         yield "data: [DONE]\n\n"
 
     def request_output_to_completion_response(
diff --git a/fastdeploy/entrypoints/openai/utils.py b/fastdeploy/entrypoints/openai/utils.py
new file mode 100644
index 000000000..d33eb01c2
--- /dev/null
+++ b/fastdeploy/entrypoints/openai/utils.py
@@ -0,0 +1,159 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" + +import asyncio +import heapq +import random + +import aiozmq +import msgpack +import zmq + +from fastdeploy.utils import api_server_logger + + +class DealerConnectionManager: + """ + Manager for dealer connections, supporting multiplexing and connection reuse + """ + + def __init__(self, pid, max_connections=10): + self.pid = pid + self.max_connections = max(max_connections, 10) + self.connections = [] + self.connection_load = [] + self.connection_heap = [] + self.request_map = {} # request_id -> response_queue + self.request_num = {} # request_id -> num_choices + self.lock = asyncio.Lock() + self.connection_tasks = [] + self.running = False + + async def initialize(self): + """initialize all connections""" + self.running = True + for index in range(self.max_connections): + await self._add_connection(index) + api_server_logger.info(f"Started {self.max_connections} connections") + + async def _add_connection(self, index): + """create a new connection and start listening task""" + try: + dealer = await aiozmq.create_zmq_stream( + zmq.DEALER, + connect=f"ipc:///dev/shm/router_{self.pid}.ipc", + ) + async with self.lock: + self.connections.append(dealer) + self.connection_load.append(0) + heapq.heappush(self.connection_heap, (0, index)) + + # start listening + task = asyncio.create_task(self._listen_connection(dealer, index)) + self.connection_tasks.append(task) + return True + except Exception as e: + api_server_logger.error(f"Failed to create dealer: {str(e)}") + return False + + async def _listen_connection(self, dealer, conn_index): + """ + listen for messages from the dealer connection + """ + while self.running: + try: + raw_data = await dealer.read() + response = msgpack.unpackb(raw_data[-1]) + request_id = response[-1]["request_id"] + if "cmpl" == request_id[:4]: + request_id = request_id.rsplit("-", 1)[0] + async with self.lock: + if request_id in self.request_map: + await self.request_map[request_id].put(response) + if response[-1]["finished"]: + self.request_num[request_id] -= 1 + if self.request_num[request_id] == 0: + self._update_load(conn_index, -1) + except Exception as e: + api_server_logger.error(f"Listener error: {str(e)}") + break + + def _update_load(self, conn_index, delta): + """Update connection load and maintain the heap""" + self.connection_load[conn_index] += delta + heapq.heapify(self.connection_heap) + + # For Debugging purposes + if random.random() < 0.01: + min_load = self.connection_heap[0][0] if self.connection_heap else 0 + max_load = max(self.connection_load) if self.connection_load else 0 + api_server_logger.debug(f"Connection load update: min={min_load}, max={max_load}") + + def _get_least_loaded_connection(self): + """ + Get the least loaded connection + """ + if not self.connection_heap: + return None + + load, conn_index = self.connection_heap[0] + self._update_load(conn_index, 1) + + return self.connections[conn_index] + + async def get_connection(self, request_id, num_choices=1): + """get a connection for the request""" + + response_queue = asyncio.Queue() + + async with self.lock: + self.request_map[request_id] = response_queue + self.request_num[request_id] = num_choices + dealer = self._get_least_loaded_connection() + if not dealer: + raise RuntimeError("No available connections") + + return dealer, response_queue + + async def cleanup_request(self, request_id): + """ + clean up the request after it is finished + """ + async with self.lock: + if request_id in self.request_map: + del self.request_map[request_id] + del 
self.request_num[request_id] + + async def close(self): + """ + close all connections and tasks + """ + self.running = False + + for task in self.connection_tasks: + task.cancel() + + async with self.lock: + for dealer in self.connections: + try: + dealer.close() + except: + pass + self.connections.clear() + self.connection_load.clear() + self.request_map.clear() + + api_server_logger.info("All connections and tasks closed") diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 1c310961c..0155e260f 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -85,7 +85,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # set trace attribute job_id. "FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"), # support max connections - "FD_SUPPORT_MAX_CONNECTIONS": lambda: 768, + "FD_SUPPORT_MAX_CONNECTIONS": lambda: int(os.getenv("FD_SUPPORT_MAX_CONNECTIONS", "1024")), } diff --git a/tests/entrypoints/openai/test_dealer_connection_manager.py b/tests/entrypoints/openai/test_dealer_connection_manager.py new file mode 100644 index 000000000..4ab1e4b99 --- /dev/null +++ b/tests/entrypoints/openai/test_dealer_connection_manager.py @@ -0,0 +1,157 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import asyncio +import unittest +from unittest.mock import AsyncMock, patch + +import msgpack + +from fastdeploy.entrypoints.openai.utils import DealerConnectionManager + + +class TestDealerConnectionManager(unittest.TestCase): + """Test cases for DealerConnectionManager""" + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.manager = DealerConnectionManager(pid=1, max_connections=5) + + def tearDown(self): + self.loop.run_until_complete(self.manager.close()) + self.loop.close() + + @patch("aiozmq.create_zmq_stream") + async def test_initialization(self, mock_create): + """Test manager initialization creates connections""" + mock_stream = AsyncMock() + mock_create.return_value = mock_stream + + # Test initialization + await self.manager.initialize() + + # Verify connections were created + self.assertEqual(len(self.manager.connections), 10) + self.assertEqual(len(self.manager.connection_load), 10) + self.assertEqual(len(self.manager.connection_tasks), 10) + + # Verify connection tasks are running + for task in self.manager.connection_tasks: + self.assertFalse(task.done()) + + @patch("aiozmq.create_zmq_stream") + async def test_get_connection(self, mock_create): + """Test getting a connection with load balancing""" + mock_stream = AsyncMock() + mock_create.return_value = mock_stream + await self.manager.initialize() + + # Get a connection + dealer, queue = await self.manager.get_connection("req1") + + # Verify least loaded connection is returned + self.assertEqual(self.manager.connection_load[0], 1) + self.assertIsNotNone(dealer) + self.assertIsNotNone(queue) + self.assertIn("req1", self.manager.request_map) + + @patch("aiozmq.create_zmq_stream") + async def test_connection_listening(self, mock_create): + """Test connection listener handles responses""" + mock_stream = AsyncMock() + mock_create.return_value = mock_stream + await self.manager.initialize() + + # Setup test response + test_response = {"request_id": "req1", "finished": True} + mock_stream.read.return_value = [b"", msgpack.packb(test_response)] + + # Simulate response + dealer, queue = await self.manager.get_connection("req1") + response = await queue.get() + + # Verify response handling + self.assertEqual(response[-1]["request_id"], "req1") + self.assertEqual(self.manager.connection_load[0], 0) # Should be decremented after finish + + @patch("aiozmq.create_zmq_stream") + async def test_request_cleanup(self, mock_create): + """Test request cleanup removes request tracking""" + mock_stream = AsyncMock() + mock_create.return_value = mock_stream + await self.manager.initialize() + + await self.manager.get_connection("req1") + self.assertIn("req1", self.manager.request_map) + + await self.manager.cleanup_request("req1") + self.assertNotIn("req1", self.manager.request_map) + + @patch("aiozmq.create_zmq_stream") + async def test_multiple_requests(self, mock_create): + """Test load balancing with multiple requests""" + mock_stream = AsyncMock() + mock_create.return_value = mock_stream + await self.manager.initialize() + + # Get multiple connections + connections = [] + for i in range(1, 6): + dealer, queue = await self.manager.get_connection(f"req{i}") + connections.append((dealer, queue)) + + # Verify load is distributed + load_counts = [0] * 5 + for i in range(5): + load_counts[i] = self.manager.connection_load[i] + + self.assertEqual(sum(load_counts), 5) + self.assertTrue(all(1 <= load <= 2 for load in load_counts)) + + @patch("aiozmq.create_zmq_stream") + async def 
test_connection_failure(self, mock_create): + """Test connection failure handling""" + mock_create.side_effect = Exception("Connection failed") + + with self.assertLogs(level="ERROR") as log: + await self.manager._add_connection(0) + self.assertTrue(any("Failed to create dealer" in msg for msg in log.output)) + + self.assertEqual(len(self.manager.connections), 0) + + @patch("aiozmq.create_zmq_stream") + async def test_close_manager(self, mock_create): + """Test manager shutdown""" + mock_stream = AsyncMock() + mock_create.return_value = mock_stream + await self.manager.initialize() + + # Verify connections exist + self.assertEqual(len(self.manager.connections), 5) + + # Close manager + await self.manager.close() + + # Verify cleanup + self.assertEqual(len(self.manager.connections), 0) + self.assertEqual(len(self.manager.request_map), 0) + for task in self.manager.connection_tasks: + self.assertTrue(task.cancelled()) + + +if __name__ == "__main__": + unittest.main()
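The consumption pattern this patch introduces is the same in both serving paths: lazily initialize the shared dealer pool, borrow the least-loaded dealer together with a per-request response queue, drain the queue until the last chunk reports finished, then unregister the request. A minimal sketch of that flow follows; engine_client, get_connection, cleanup_request, and the 10-second timeout mirror the diff above, while the coroutine name and standalone wiring are illustrative only.

import asyncio


async def stream_one_request(engine_client, request_id):
    # Lazily bring up the shared dealer pool (mirrors _ensure_connection_manager above).
    if not engine_client.connection_initialized:
        await engine_client.connection_manager.initialize()
        engine_client.connection_initialized = True

    # Borrow the least-loaded dealer and a per-request response queue.
    dealer, response_queue = await engine_client.connection_manager.get_connection(request_id)
    dealer.write([b"", request_id.encode("utf-8")])

    try:
        while True:
            # Each queue item is already a msgpack-decoded list of result dicts;
            # decoding happens once, in DealerConnectionManager._listen_connection.
            response = await asyncio.wait_for(response_queue.get(), timeout=10)
            for res in response:
                if res.get("error_code", 200) != 200:
                    raise ValueError(res["error_msg"])
                yield res
            if response[-1].get("finished"):
                break
    finally:
        # Drop the request_id -> queue mapping; the dealer itself stays pooled for reuse.
        await engine_client.connection_manager.cleanup_request(request_id)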