[Feature] add dealer manager to reuse the connection (#3471)

* [BugFix] fix control signal release failed

* [BugFix] fix control signal release failed

* update

* update

* update

* [Feature] add dealer manager to reuse the connection

* fix

* fix

* fix

* fix

* fix

* fix

* Create test_dealer_connection_manager.py

* Delete test/entrypoints/openai directory

* Update test_dealer_connection_manager.py

* Update test_dealer_connection_manager.py
ltd0924 authored 2025-08-21 13:11:13 +08:00, committed by GitHub
parent 985b1265c3
commit 51f68ae593
7 changed files with 360 additions and 25 deletions

View File

@@ -154,6 +154,7 @@ async def lifespan(app: FastAPI):
yield
# close zmq
try:
await engine_client.connection_manager.close()
engine_client.zmq_client.close()
from prometheus_client import multiprocess

View File

@@ -20,10 +20,7 @@ import traceback
import uuid
from typing import List, Optional
import aiozmq
import msgpack
import numpy as np
from aiozmq import zmq
from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest,
@@ -62,6 +59,12 @@ class OpenAIServingChat:
else:
self.master_ip = self.master_ip.split(",")[0]
async def _ensure_connection_manager(self):
"""ensure connection manager initialized"""
if not self.engine_client.connection_initialized:
await self.engine_client.connection_manager.initialize()
self.engine_client.connection_initialized = True
def _check_master(self):
if self.master_ip is None:
return True
@@ -180,14 +183,16 @@ class OpenAIServingChat:
choices=[],
model=model_name,
)
try:
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
await self._ensure_connection_manager()
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
dealer.write([b"", request_id.encode("utf-8")])
choices = []
current_waiting_time = 0
while num_choices > 0:
try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
response = await asyncio.wait_for(response_queue.get(), timeout=10)
current_waiting_time = 0
except asyncio.TimeoutError:
current_waiting_time += 10
@@ -202,7 +207,6 @@ class OpenAIServingChat:
current_waiting_time = 0
await asyncio.sleep(0.01)
continue
response = msgpack.unpackb(raw_data[-1])
for res in response:
if res.get("error_code", 200) != 200:
raise ValueError("{}".format(res["error_msg"]))
@@ -353,9 +357,9 @@ class OpenAIServingChat:
)
yield f"data: {error_data}\n\n"
finally:
dealer.close()
await self.engine_client.connection_manager.cleanup_request(request_id)
self.engine_client.semaphore.release()
api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
api_server_logger.info(f"release {request_id} {self.engine_client.semaphore.status()}")
yield "data: [DONE]\n\n"
async def chat_completion_full_generator(
@@ -378,7 +382,8 @@ class OpenAIServingChat:
include_stop_str_in_output = request.include_stop_str_in_output
try:
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
await self._ensure_connection_manager()
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
dealer.write([b"", request_id.encode("utf-8")])
final_res = None
previous_num_tokens = 0
@@ -387,7 +392,7 @@ class OpenAIServingChat:
completion_token_ids = []
while True:
try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
response = await asyncio.wait_for(response_queue.get(), timeout=10)
current_waiting_time = 0
except asyncio.TimeoutError:
current_waiting_time += 10
@@ -400,7 +405,6 @@ class OpenAIServingChat:
await asyncio.sleep(0.1)
continue
response = msgpack.unpackb(raw_data[-1])
task_is_finished = False
for data in response:
if data.get("error_code", 200) != 200:
@@ -430,7 +434,7 @@ class OpenAIServingChat:
if task_is_finished:
break
finally:
dealer.close()
await self.engine_client.connection_manager.cleanup_request(request_id)
self.engine_client.semaphore.release()
api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
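
Taken together, the serving_chat hunks converge on one pattern: initialize the shared manager once, write the request over a pooled DEALER socket, consume already-decoded responses from a per-request queue, and release only the queue (not the socket) when done. A condensed, hedged sketch of that pattern follows; it is illustrative rather than the committed code. `engine_client`, the frame layout, and the 10-second timeout come from the diff, while the helper name and the simplified error handling do not.

import asyncio


async def stream_from_manager(engine_client, request_id: str):
    """Illustrative sketch of the serving_chat flow around the connection manager."""
    # One-time lazy initialization, mirroring _ensure_connection_manager().
    if not engine_client.connection_initialized:
        await engine_client.connection_manager.initialize()
        engine_client.connection_initialized = True

    dealer, response_queue = await engine_client.connection_manager.get_connection(request_id)
    dealer.write([b"", request_id.encode("utf-8")])  # route the request over the pooled DEALER socket
    try:
        while True:
            # Frames arrive already msgpack-decoded; no dealer.read()/unpackb in the handler anymore.
            response = await asyncio.wait_for(response_queue.get(), timeout=10)
            for res in response:
                if res.get("error_code", 200) != 200:
                    raise ValueError(res["error_msg"])
                yield res
                if res.get("finished"):
                    return
    finally:
        # The socket stays pooled; only the per-request queue and the semaphore slot are released.
        await engine_client.connection_manager.cleanup_request(request_id)
        engine_client.semaphore.release()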

View File

@@ -20,10 +20,7 @@ import traceback
import uuid
from typing import List, Optional
import aiozmq
import msgpack
import numpy as np
from aiozmq import zmq
from fastdeploy.engine.request import RequestOutput
from fastdeploy.entrypoints.openai.protocol import (
@@ -53,6 +50,12 @@ class OpenAIServingCompletion:
else:
self.master_ip = self.master_ip.split(",")[0]
async def _ensure_connection_manager(self):
"""ensure connection manager initialized"""
if not self.engine_client.connection_initialized:
await self.engine_client.connection_manager.initialize()
self.engine_client.connection_initialized = True
def _check_master(self):
if self.master_ip is None:
return True
@@ -185,7 +188,10 @@ class OpenAIServingCompletion:
try:
request_ids = [f"{request_id}-{i}" for i in range(num_choices)]
# create dealer
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
await self._ensure_connection_manager()
dealer, response_queue = await self.engine_client.connection_manager.get_connection(
request_id, num_choices
)
for rid in request_ids:
dealer.write([b"", rid.encode("utf-8")])
@@ -198,7 +204,7 @@ class OpenAIServingCompletion:
current_waiting_time = 0
while num_choices > 0:
try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
response = await asyncio.wait_for(response_queue.get(), timeout=10)
current_waiting_time = 0
except asyncio.TimeoutError:
current_waiting_time += 10
@@ -210,7 +216,7 @@ class OpenAIServingCompletion:
current_waiting_time = 0
await asyncio.sleep(0.1)
continue
response = msgpack.unpackb(raw_data[-1])
for data in response:
rid = int(data["request_id"].split("-")[-1])
if data.get("error_code", 200) != 200:
@@ -255,7 +261,7 @@ class OpenAIServingCompletion:
finally:
self.engine_client.semaphore.release()
if dealer is not None:
dealer.close()
await self.engine_client.connection_manager.cleanup_request(request_id)
async def _echo_back_prompt(self, request, res, idx):
if res["outputs"].get("send_idx", -1) == 0 and request.echo:
@@ -288,7 +294,10 @@ class OpenAIServingCompletion:
Process the stream completion request.
"""
try:
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
await self._ensure_connection_manager()
dealer, response_queue = await self.engine_client.connection_manager.get_connection(
request_id, num_choices
)
for i in range(num_choices):
req_id = f"{request_id}-{i}"
@@ -312,7 +321,7 @@ class OpenAIServingCompletion:
current_waiting_time = 0
while num_choices > 0:
try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
response = await asyncio.wait_for(response_queue.get(), timeout=10)
current_waiting_time = 0
except asyncio.TimeoutError:
current_waiting_time += 10
@@ -325,7 +334,6 @@ class OpenAIServingCompletion:
await asyncio.sleep(0.1)
continue
response = msgpack.unpackb(raw_data[-1])
for res in response:
idx = int(res["request_id"].split("-")[-1])
if res.get("error_code", 200) != 200:
@@ -453,9 +461,9 @@ class OpenAIServingCompletion:
yield f"data: {ErrorResponse(message=str(e), code=400).model_dump_json(exclude_unset=True)}\n\n"
finally:
del request
self.engine_client.semaphore.release()
if dealer is not None:
dealer.close()
await self.engine_client.connection_manager.cleanup_request(request_id)
self.engine_client.semaphore.release()
yield "data: [DONE]\n\n"
def request_output_to_completion_response(
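
The completion path above differs mainly in fan-out: n choices share one pooled connection, each sub-request id gets a `-{i}` suffix, and the manager's listener strips that suffix (for ids starting with `cmpl`) so every choice lands in the same response queue. A hedged sketch of that accounting, illustrative only; the helper name and the result handling are not from the diff.

import asyncio


async def collect_completions(engine_client, request_id: str, num_choices: int):
    """Illustrative sketch: fan num_choices sub-requests over one pooled dealer connection."""
    # request_id is assumed to start with "cmpl", matching the listener's suffix-stripping rule.
    dealer, response_queue = await engine_client.connection_manager.get_connection(request_id, num_choices)
    for i in range(num_choices):
        dealer.write([b"", f"{request_id}-{i}".encode("utf-8")])  # one frame per choice

    latest = [None] * num_choices
    remaining = num_choices
    try:
        while remaining > 0:
            response = await asyncio.wait_for(response_queue.get(), timeout=10)
            for data in response:
                idx = int(data["request_id"].split("-")[-1])  # recover the choice index
                latest[idx] = data
                if data.get("finished"):
                    remaining -= 1
        return latest
    finally:
        engine_client.semaphore.release()
        await engine_client.connection_manager.cleanup_request(request_id)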

View File

@@ -0,0 +1,159 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
import heapq
import random

import aiozmq
import msgpack
import zmq

from fastdeploy.utils import api_server_logger


class DealerConnectionManager:
    """
    Manager for dealer connections, supporting multiplexing and connection reuse
    """

    def __init__(self, pid, max_connections=10):
        self.pid = pid
        self.max_connections = max(max_connections, 10)
        self.connections = []
        self.connection_load = []
        self.connection_heap = []
        self.request_map = {}  # request_id -> response_queue
        self.request_num = {}  # request_id -> num_choices
        self.lock = asyncio.Lock()
        self.connection_tasks = []
        self.running = False

    async def initialize(self):
        """initialize all connections"""
        self.running = True
        for index in range(self.max_connections):
            await self._add_connection(index)
        api_server_logger.info(f"Started {self.max_connections} connections")

    async def _add_connection(self, index):
        """create a new connection and start listening task"""
        try:
            dealer = await aiozmq.create_zmq_stream(
                zmq.DEALER,
                connect=f"ipc:///dev/shm/router_{self.pid}.ipc",
            )
            async with self.lock:
                self.connections.append(dealer)
                self.connection_load.append(0)
                heapq.heappush(self.connection_heap, (0, index))

            # start listening
            task = asyncio.create_task(self._listen_connection(dealer, index))
            self.connection_tasks.append(task)
            return True
        except Exception as e:
            api_server_logger.error(f"Failed to create dealer: {str(e)}")
            return False

    async def _listen_connection(self, dealer, conn_index):
        """
        listen for messages from the dealer connection
        """
        while self.running:
            try:
                raw_data = await dealer.read()
                response = msgpack.unpackb(raw_data[-1])
                request_id = response[-1]["request_id"]
                if "cmpl" == request_id[:4]:
                    request_id = request_id.rsplit("-", 1)[0]
                async with self.lock:
                    if request_id in self.request_map:
                        await self.request_map[request_id].put(response)
                        if response[-1]["finished"]:
                            self.request_num[request_id] -= 1
                            if self.request_num[request_id] == 0:
                                self._update_load(conn_index, -1)
            except Exception as e:
                api_server_logger.error(f"Listener error: {str(e)}")
                break

    def _update_load(self, conn_index, delta):
        """Update connection load and maintain the heap"""
        self.connection_load[conn_index] += delta
        heapq.heapify(self.connection_heap)

        # For Debugging purposes
        if random.random() < 0.01:
            min_load = self.connection_heap[0][0] if self.connection_heap else 0
            max_load = max(self.connection_load) if self.connection_load else 0
            api_server_logger.debug(f"Connection load update: min={min_load}, max={max_load}")

    def _get_least_loaded_connection(self):
        """
        Get the least loaded connection
        """
        if not self.connection_heap:
            return None

        load, conn_index = self.connection_heap[0]
        self._update_load(conn_index, 1)
        return self.connections[conn_index]

    async def get_connection(self, request_id, num_choices=1):
        """get a connection for the request"""
        response_queue = asyncio.Queue()

        async with self.lock:
            self.request_map[request_id] = response_queue
            self.request_num[request_id] = num_choices
            dealer = self._get_least_loaded_connection()
            if not dealer:
                raise RuntimeError("No available connections")

        return dealer, response_queue

    async def cleanup_request(self, request_id):
        """
        clean up the request after it is finished
        """
        async with self.lock:
            if request_id in self.request_map:
                del self.request_map[request_id]
                del self.request_num[request_id]

    async def close(self):
        """
        close all connections and tasks
        """
        self.running = False
        for task in self.connection_tasks:
            task.cancel()

        async with self.lock:
            for dealer in self.connections:
                try:
                    dealer.close()
                except:
                    pass
            self.connections.clear()
            self.connection_load.clear()
            self.request_map.clear()

        api_server_logger.info("All connections and tasks closed")
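
For reference, a hedged end-to-end sketch of driving DealerConnectionManager directly, outside the serving classes. The diff does not name the new module, so the import path below is a placeholder, and the sketch assumes an api_server ZMQ router is already listening on the matching ipc:///dev/shm/router_<pid>.ipc socket.

import asyncio
import sys

# Placeholder import path: the diff does not show which module the new class lives in.
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager


async def demo(router_pid: int) -> None:
    """Illustrative only: drive the manager directly against a running api_server router."""
    manager = DealerConnectionManager(pid=router_pid, max_connections=10)
    await manager.initialize()  # open the DEALER pool, one listener task per socket

    request_id = "chatcmpl-demo-1"
    dealer, response_queue = await manager.get_connection(request_id)
    dealer.write([b"", request_id.encode("utf-8")])

    # The listener decodes each msgpack frame and routes it here by request_id.
    response = await asyncio.wait_for(response_queue.get(), timeout=10)
    print(response)

    await manager.cleanup_request(request_id)  # drop the per-request queue; the socket stays pooled
    await manager.close()


if __name__ == "__main__":
    # The pid must match the one embedded in the router's ipc path: ipc:///dev/shm/router_<pid>.ipc
    asyncio.run(demo(int(sys.argv[1])))

The point of the pool is that get_connection/cleanup_request touch only per-request bookkeeping; the DEALER sockets and their listener tasks live for the whole server lifetime and are torn down once in the FastAPI lifespan via connection_manager.close(), as the first hunk of this commit shows.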