[Feature] add dealer manager to reuse the connection (#3471)

* [BugFix] fix control signal release failed

* [BugFix] fix control signal release failed

* update

* update

* update

* [Feature] add dealer manager to reuse the connection

* fix

* fix

* fix

* fix

* fix

* fix

* Create test_dealer_connection_manager.py

* Delete test/entrypoints/openai directory

* Update test_dealer_connection_manager.py

* Update test_dealer_connection_manager.py
This commit is contained in:
ltd0924
2025-08-21 13:11:13 +08:00
committed by GitHub
parent 985b1265c3
commit 51f68ae593
7 changed files with 360 additions and 25 deletions

View File

@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
""" """
import os
import time import time
import traceback import traceback
import uuid import uuid
@@ -22,6 +23,7 @@ import numpy as np
from fastdeploy import envs from fastdeploy import envs
from fastdeploy.engine.config import ModelConfig from fastdeploy.engine.config import ModelConfig
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
from fastdeploy.input.preprocess import InputPreprocessor from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import IPCSignal, ZmqClient from fastdeploy.inter_communicator import IPCSignal, ZmqClient
@@ -91,6 +93,10 @@ class EngineClient:
suffix=pid, suffix=pid,
create=False, create=False,
) )
self.connection_manager = DealerConnectionManager(
pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
)
self.connection_initialized = False
def create_zmq_client(self, model, mode): def create_zmq_client(self, model, mode):
""" """

View File

@@ -154,6 +154,7 @@ async def lifespan(app: FastAPI):
yield yield
# close zmq # close zmq
try: try:
await engine_client.connection_manager.close()
engine_client.zmq_client.close() engine_client.zmq_client.close()
from prometheus_client import multiprocess from prometheus_client import multiprocess

View File

@@ -20,10 +20,7 @@ import traceback
import uuid import uuid
from typing import List, Optional from typing import List, Optional
import aiozmq
import msgpack
import numpy as np import numpy as np
from aiozmq import zmq
from fastdeploy.entrypoints.openai.protocol import ( from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
@@ -62,6 +59,12 @@ class OpenAIServingChat:
else: else:
self.master_ip = self.master_ip.split(",")[0] self.master_ip = self.master_ip.split(",")[0]
async def _ensure_connection_manager(self):
"""ensure connection manager initialized"""
if not self.engine_client.connection_initialized:
await self.engine_client.connection_manager.initialize()
self.engine_client.connection_initialized = True
def _check_master(self): def _check_master(self):
if self.master_ip is None: if self.master_ip is None:
return True return True
@@ -180,14 +183,16 @@ class OpenAIServingChat:
choices=[], choices=[],
model=model_name, model=model_name,
) )
try: try:
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc") await self._ensure_connection_manager()
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
dealer.write([b"", request_id.encode("utf-8")]) dealer.write([b"", request_id.encode("utf-8")])
choices = [] choices = []
current_waiting_time = 0 current_waiting_time = 0
while num_choices > 0: while num_choices > 0:
try: try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=10) response = await asyncio.wait_for(response_queue.get(), timeout=10)
current_waiting_time = 0 current_waiting_time = 0
except asyncio.TimeoutError: except asyncio.TimeoutError:
current_waiting_time += 10 current_waiting_time += 10
@@ -202,7 +207,6 @@ class OpenAIServingChat:
current_waiting_time = 0 current_waiting_time = 0
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
continue continue
response = msgpack.unpackb(raw_data[-1])
for res in response: for res in response:
if res.get("error_code", 200) != 200: if res.get("error_code", 200) != 200:
raise ValueError("{}".format(res["error_msg"])) raise ValueError("{}".format(res["error_msg"]))
@@ -353,9 +357,9 @@ class OpenAIServingChat:
) )
yield f"data: {error_data}\n\n" yield f"data: {error_data}\n\n"
finally: finally:
dealer.close() await self.engine_client.connection_manager.cleanup_request(request_id)
self.engine_client.semaphore.release() self.engine_client.semaphore.release()
api_server_logger.info(f"release {self.engine_client.semaphore.status()}") api_server_logger.info(f"release {request_id} {self.engine_client.semaphore.status()}")
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
async def chat_completion_full_generator( async def chat_completion_full_generator(
@@ -378,7 +382,8 @@ class OpenAIServingChat:
include_stop_str_in_output = request.include_stop_str_in_output include_stop_str_in_output = request.include_stop_str_in_output
try: try:
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc") await self._ensure_connection_manager()
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
dealer.write([b"", request_id.encode("utf-8")]) dealer.write([b"", request_id.encode("utf-8")])
final_res = None final_res = None
previous_num_tokens = 0 previous_num_tokens = 0
@@ -387,7 +392,7 @@ class OpenAIServingChat:
completion_token_ids = [] completion_token_ids = []
while True: while True:
try: try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=10) response = await asyncio.wait_for(response_queue.get(), timeout=10)
current_waiting_time = 0 current_waiting_time = 0
except asyncio.TimeoutError: except asyncio.TimeoutError:
current_waiting_time += 10 current_waiting_time += 10
@@ -400,7 +405,6 @@ class OpenAIServingChat:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
continue continue
response = msgpack.unpackb(raw_data[-1])
task_is_finished = False task_is_finished = False
for data in response: for data in response:
if data.get("error_code", 200) != 200: if data.get("error_code", 200) != 200:
@@ -430,7 +434,7 @@ class OpenAIServingChat:
if task_is_finished: if task_is_finished:
break break
finally: finally:
dealer.close() await self.engine_client.connection_manager.cleanup_request(request_id)
self.engine_client.semaphore.release() self.engine_client.semaphore.release()
api_server_logger.info(f"release {self.engine_client.semaphore.status()}") api_server_logger.info(f"release {self.engine_client.semaphore.status()}")

View File

@@ -20,10 +20,7 @@ import traceback
import uuid import uuid
from typing import List, Optional from typing import List, Optional
import aiozmq
import msgpack
import numpy as np import numpy as np
from aiozmq import zmq
from fastdeploy.engine.request import RequestOutput from fastdeploy.engine.request import RequestOutput
from fastdeploy.entrypoints.openai.protocol import ( from fastdeploy.entrypoints.openai.protocol import (
@@ -53,6 +50,12 @@ class OpenAIServingCompletion:
else: else:
self.master_ip = self.master_ip.split(",")[0] self.master_ip = self.master_ip.split(",")[0]
async def _ensure_connection_manager(self):
"""ensure connection manager initialized"""
if not self.engine_client.connection_initialized:
await self.engine_client.connection_manager.initialize()
self.engine_client.connection_initialized = True
def _check_master(self): def _check_master(self):
if self.master_ip is None: if self.master_ip is None:
return True return True
@@ -185,7 +188,10 @@ class OpenAIServingCompletion:
try: try:
request_ids = [f"{request_id}-{i}" for i in range(num_choices)] request_ids = [f"{request_id}-{i}" for i in range(num_choices)]
# create dealer # create dealer
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc") await self._ensure_connection_manager()
dealer, response_queue = await self.engine_client.connection_manager.get_connection(
request_id, num_choices
)
for rid in request_ids: for rid in request_ids:
dealer.write([b"", rid.encode("utf-8")]) dealer.write([b"", rid.encode("utf-8")])
@@ -198,7 +204,7 @@ class OpenAIServingCompletion:
current_waiting_time = 0 current_waiting_time = 0
while num_choices > 0: while num_choices > 0:
try: try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=10) response = await asyncio.wait_for(response_queue.get(), timeout=10)
current_waiting_time = 0 current_waiting_time = 0
except asyncio.TimeoutError: except asyncio.TimeoutError:
current_waiting_time += 10 current_waiting_time += 10
@@ -210,7 +216,7 @@ class OpenAIServingCompletion:
current_waiting_time = 0 current_waiting_time = 0
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
continue continue
response = msgpack.unpackb(raw_data[-1])
for data in response: for data in response:
rid = int(data["request_id"].split("-")[-1]) rid = int(data["request_id"].split("-")[-1])
if data.get("error_code", 200) != 200: if data.get("error_code", 200) != 200:
@@ -255,7 +261,7 @@ class OpenAIServingCompletion:
finally: finally:
self.engine_client.semaphore.release() self.engine_client.semaphore.release()
if dealer is not None: if dealer is not None:
dealer.close() await self.engine_client.connection_manager.cleanup_request(request_id)
async def _echo_back_prompt(self, request, res, idx): async def _echo_back_prompt(self, request, res, idx):
if res["outputs"].get("send_idx", -1) == 0 and request.echo: if res["outputs"].get("send_idx", -1) == 0 and request.echo:
@@ -288,7 +294,10 @@ class OpenAIServingCompletion:
Process the stream completion request. Process the stream completion request.
""" """
try: try:
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc") await self._ensure_connection_manager()
dealer, response_queue = await self.engine_client.connection_manager.get_connection(
request_id, num_choices
)
for i in range(num_choices): for i in range(num_choices):
req_id = f"{request_id}-{i}" req_id = f"{request_id}-{i}"
@@ -312,7 +321,7 @@ class OpenAIServingCompletion:
current_waiting_time = 0 current_waiting_time = 0
while num_choices > 0: while num_choices > 0:
try: try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=10) response = await asyncio.wait_for(response_queue.get(), timeout=10)
current_waiting_time = 0 current_waiting_time = 0
except asyncio.TimeoutError: except asyncio.TimeoutError:
current_waiting_time += 10 current_waiting_time += 10
@@ -325,7 +334,6 @@ class OpenAIServingCompletion:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
continue continue
response = msgpack.unpackb(raw_data[-1])
for res in response: for res in response:
idx = int(res["request_id"].split("-")[-1]) idx = int(res["request_id"].split("-")[-1])
if res.get("error_code", 200) != 200: if res.get("error_code", 200) != 200:
@@ -453,9 +461,9 @@ class OpenAIServingCompletion:
yield f"data: {ErrorResponse(message=str(e), code=400).model_dump_json(exclude_unset=True)}\n\n" yield f"data: {ErrorResponse(message=str(e), code=400).model_dump_json(exclude_unset=True)}\n\n"
finally: finally:
del request del request
self.engine_client.semaphore.release()
if dealer is not None: if dealer is not None:
dealer.close() await self.engine_client.connection_manager.cleanup_request(request_id)
self.engine_client.semaphore.release()
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
def request_output_to_completion_response( def request_output_to_completion_response(

View File

@@ -0,0 +1,159 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
import heapq
import random
import aiozmq
import msgpack
import zmq
from fastdeploy.utils import api_server_logger
class DealerConnectionManager:
"""
Manager for dealer connections, supporting multiplexing and connection reuse
"""
def __init__(self, pid, max_connections=10):
self.pid = pid
self.max_connections = max(max_connections, 10)
self.connections = []
self.connection_load = []
self.connection_heap = []
self.request_map = {} # request_id -> response_queue
self.request_num = {} # request_id -> num_choices
self.lock = asyncio.Lock()
self.connection_tasks = []
self.running = False
async def initialize(self):
"""initialize all connections"""
self.running = True
for index in range(self.max_connections):
await self._add_connection(index)
api_server_logger.info(f"Started {self.max_connections} connections")
async def _add_connection(self, index):
"""create a new connection and start listening task"""
try:
dealer = await aiozmq.create_zmq_stream(
zmq.DEALER,
connect=f"ipc:///dev/shm/router_{self.pid}.ipc",
)
async with self.lock:
self.connections.append(dealer)
self.connection_load.append(0)
heapq.heappush(self.connection_heap, (0, index))
# start listening
task = asyncio.create_task(self._listen_connection(dealer, index))
self.connection_tasks.append(task)
return True
except Exception as e:
api_server_logger.error(f"Failed to create dealer: {str(e)}")
return False
async def _listen_connection(self, dealer, conn_index):
"""
listen for messages from the dealer connection
"""
while self.running:
try:
raw_data = await dealer.read()
response = msgpack.unpackb(raw_data[-1])
request_id = response[-1]["request_id"]
if "cmpl" == request_id[:4]:
request_id = request_id.rsplit("-", 1)[0]
async with self.lock:
if request_id in self.request_map:
await self.request_map[request_id].put(response)
if response[-1]["finished"]:
self.request_num[request_id] -= 1
if self.request_num[request_id] == 0:
self._update_load(conn_index, -1)
except Exception as e:
api_server_logger.error(f"Listener error: {str(e)}")
break
def _update_load(self, conn_index, delta):
"""Update connection load and maintain the heap"""
self.connection_load[conn_index] += delta
heapq.heapify(self.connection_heap)
# For Debugging purposes
if random.random() < 0.01:
min_load = self.connection_heap[0][0] if self.connection_heap else 0
max_load = max(self.connection_load) if self.connection_load else 0
api_server_logger.debug(f"Connection load update: min={min_load}, max={max_load}")
def _get_least_loaded_connection(self):
"""
Get the least loaded connection
"""
if not self.connection_heap:
return None
load, conn_index = self.connection_heap[0]
self._update_load(conn_index, 1)
return self.connections[conn_index]
async def get_connection(self, request_id, num_choices=1):
"""get a connection for the request"""
response_queue = asyncio.Queue()
async with self.lock:
self.request_map[request_id] = response_queue
self.request_num[request_id] = num_choices
dealer = self._get_least_loaded_connection()
if not dealer:
raise RuntimeError("No available connections")
return dealer, response_queue
async def cleanup_request(self, request_id):
"""
clean up the request after it is finished
"""
async with self.lock:
if request_id in self.request_map:
del self.request_map[request_id]
del self.request_num[request_id]
async def close(self):
"""
close all connections and tasks
"""
self.running = False
for task in self.connection_tasks:
task.cancel()
async with self.lock:
for dealer in self.connections:
try:
dealer.close()
except:
pass
self.connections.clear()
self.connection_load.clear()
self.request_map.clear()
api_server_logger.info("All connections and tasks closed")

View File

@@ -85,7 +85,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# set trace attribute job_id. # set trace attribute job_id.
"FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"), "FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"),
# support max connections # support max connections
"FD_SUPPORT_MAX_CONNECTIONS": lambda: 768, "FD_SUPPORT_MAX_CONNECTIONS": lambda: int(os.getenv("FD_SUPPORT_MAX_CONNECTIONS", "1024")),
} }

View File

@@ -0,0 +1,157 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
import unittest
from unittest.mock import AsyncMock, patch
import msgpack
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
class TestDealerConnectionManager(unittest.TestCase):
"""Test cases for DealerConnectionManager"""
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.manager = DealerConnectionManager(pid=1, max_connections=5)
def tearDown(self):
self.loop.run_until_complete(self.manager.close())
self.loop.close()
@patch("aiozmq.create_zmq_stream")
async def test_initialization(self, mock_create):
"""Test manager initialization creates connections"""
mock_stream = AsyncMock()
mock_create.return_value = mock_stream
# Test initialization
await self.manager.initialize()
# Verify connections were created
self.assertEqual(len(self.manager.connections), 10)
self.assertEqual(len(self.manager.connection_load), 10)
self.assertEqual(len(self.manager.connection_tasks), 10)
# Verify connection tasks are running
for task in self.manager.connection_tasks:
self.assertFalse(task.done())
@patch("aiozmq.create_zmq_stream")
async def test_get_connection(self, mock_create):
"""Test getting a connection with load balancing"""
mock_stream = AsyncMock()
mock_create.return_value = mock_stream
await self.manager.initialize()
# Get a connection
dealer, queue = await self.manager.get_connection("req1")
# Verify least loaded connection is returned
self.assertEqual(self.manager.connection_load[0], 1)
self.assertIsNotNone(dealer)
self.assertIsNotNone(queue)
self.assertIn("req1", self.manager.request_map)
@patch("aiozmq.create_zmq_stream")
async def test_connection_listening(self, mock_create):
"""Test connection listener handles responses"""
mock_stream = AsyncMock()
mock_create.return_value = mock_stream
await self.manager.initialize()
# Setup test response
test_response = {"request_id": "req1", "finished": True}
mock_stream.read.return_value = [b"", msgpack.packb(test_response)]
# Simulate response
dealer, queue = await self.manager.get_connection("req1")
response = await queue.get()
# Verify response handling
self.assertEqual(response[-1]["request_id"], "req1")
self.assertEqual(self.manager.connection_load[0], 0) # Should be decremented after finish
@patch("aiozmq.create_zmq_stream")
async def test_request_cleanup(self, mock_create):
"""Test request cleanup removes request tracking"""
mock_stream = AsyncMock()
mock_create.return_value = mock_stream
await self.manager.initialize()
await self.manager.get_connection("req1")
self.assertIn("req1", self.manager.request_map)
await self.manager.cleanup_request("req1")
self.assertNotIn("req1", self.manager.request_map)
@patch("aiozmq.create_zmq_stream")
async def test_multiple_requests(self, mock_create):
"""Test load balancing with multiple requests"""
mock_stream = AsyncMock()
mock_create.return_value = mock_stream
await self.manager.initialize()
# Get multiple connections
connections = []
for i in range(1, 6):
dealer, queue = await self.manager.get_connection(f"req{i}")
connections.append((dealer, queue))
# Verify load is distributed
load_counts = [0] * 5
for i in range(5):
load_counts[i] = self.manager.connection_load[i]
self.assertEqual(sum(load_counts), 5)
self.assertTrue(all(1 <= load <= 2 for load in load_counts))
@patch("aiozmq.create_zmq_stream")
async def test_connection_failure(self, mock_create):
"""Test connection failure handling"""
mock_create.side_effect = Exception("Connection failed")
with self.assertLogs(level="ERROR") as log:
await self.manager._add_connection(0)
self.assertTrue(any("Failed to create dealer" in msg for msg in log.output))
self.assertEqual(len(self.manager.connections), 0)
@patch("aiozmq.create_zmq_stream")
async def test_close_manager(self, mock_create):
"""Test manager shutdown"""
mock_stream = AsyncMock()
mock_create.return_value = mock_stream
await self.manager.initialize()
# Verify connections exist
self.assertEqual(len(self.manager.connections), 5)
# Close manager
await self.manager.close()
# Verify cleanup
self.assertEqual(len(self.manager.connections), 0)
self.assertEqual(len(self.manager.request_map), 0)
for task in self.manager.connection_tasks:
self.assertTrue(task.cancelled())
if __name__ == "__main__":
unittest.main()