mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
[Feature] add dealer manager to reuse the connection (#3471)
* [BugFix] fix control signal release failed * [BugFix] fix control signal release failed * update * update * update * [Feature] add dealer manager to reuse the connection * fix * fix * fix * fix * fix * fix * Create test_dealer_connection_manager.py * Delete test/entrypoints/openai directory * Update test_dealer_connection_manager.py * Update test_dealer_connection_manager.py
This commit is contained in:
159
fastdeploy/entrypoints/openai/utils.py
Normal file
159
fastdeploy/entrypoints/openai/utils.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import heapq
|
||||
import random
|
||||
|
||||
import aiozmq
|
||||
import msgpack
|
||||
import zmq
|
||||
|
||||
from fastdeploy.utils import api_server_logger
|
||||
|
||||
|
||||
class DealerConnectionManager:
    """
    Manager for dealer connections, supporting multiplexing and connection reuse.

    A fixed pool of ZMQ DEALER streams is opened towards
    ``ipc:///dev/shm/router_{pid}.ipc``. Each request is assigned to the
    connection that currently carries the fewest in-flight requests, and one
    background task per connection reads responses and routes them to the
    per-request ``asyncio.Queue``.

    Bookkeeping (every structure is mutated only while ``self.lock`` is held):
        connections      -- dealer streams, addressed by position
        connection_load  -- in-flight request count, parallel to ``connections``
        connection_heap  -- min-heap of ``(load, position)`` tuples
        request_map      -- request_id -> response queue
        request_num      -- request_id -> finished responses still expected
        request_conn     -- request_id -> position of the connection charged
                            for the request (used to release the load once)
    """

    def __init__(self, pid, max_connections=10):
        """
        Args:
            pid: identifier used to build the router IPC endpoint path.
            max_connections: requested pool size.
        """
        self.pid = pid
        # NOTE(review): values below 10 are raised to 10 here, so the argument
        # acts as a floor rather than a cap -- confirm this is intended.
        self.max_connections = max(max_connections, 10)
        self.connections = []
        self.connection_load = []
        self.connection_heap = []
        self.request_map = {}  # request_id -> response_queue
        self.request_num = {}  # request_id -> num_choices
        self.request_conn = {}  # request_id -> connection index
        self.lock = asyncio.Lock()
        self.connection_tasks = []
        self.running = False

    async def initialize(self):
        """Open the connection pool and start one listener task per connection."""
        self.running = True
        started = 0
        for index in range(self.max_connections):
            if await self._add_connection(index):
                started += 1
        # Report what actually came up; individual connections may have failed.
        api_server_logger.info(f"Started {started} connections")

    async def _add_connection(self, index):
        """
        Create a new dealer connection and start its listening task.

        Args:
            index: ordinal supplied by the caller. It is kept for interface
                compatibility; the connection is registered under its real
                position in ``self.connections`` so the heap entry and the
                listener stay valid even if an earlier connection failed.

        Returns:
            True when the connection was created, False otherwise (the error
            is logged, not raised).
        """
        try:
            dealer = await aiozmq.create_zmq_stream(
                zmq.DEALER,
                connect=f"ipc:///dev/shm/router_{self.pid}.ipc",
            )
            async with self.lock:
                conn_index = len(self.connections)
                self.connections.append(dealer)
                self.connection_load.append(0)
                heapq.heappush(self.connection_heap, (0, conn_index))

            # start listening
            task = asyncio.create_task(self._listen_connection(dealer, conn_index))
            self.connection_tasks.append(task)
            return True
        except Exception as e:
            api_server_logger.error(f"Failed to create dealer: {str(e)}")
            return False

    async def _listen_connection(self, dealer, conn_index):
        """
        Listen for messages from the dealer connection until the manager stops.

        A failure while reading from the stream ends the loop. A failure while
        decoding or routing a single message is logged and skipped, so one bad
        frame does not silence a connection that is still part of the pool.
        """
        while self.running:
            try:
                raw_data = await dealer.read()
            except Exception as e:
                api_server_logger.error(f"Listener error: {str(e)}")
                break

            try:
                await self._dispatch_response(raw_data, conn_index)
            except Exception as e:
                api_server_logger.error(f"Listener error: {str(e)}")

    async def _dispatch_response(self, raw_data, conn_index):
        """Decode one message and hand it to the queue of the owning request."""
        response = msgpack.unpackb(raw_data[-1])
        request_id = response[-1]["request_id"]
        if "cmpl" == request_id[:4]:
            # Ids starting with "cmpl" carry a trailing "-<suffix>"; strip it
            # to recover the key the request was registered under.
            request_id = request_id.rsplit("-", 1)[0]

        async with self.lock:
            if request_id not in self.request_map:
                return
            await self.request_map[request_id].put(response)
            if response[-1]["finished"]:
                self.request_num[request_id] -= 1
                if self.request_num[request_id] == 0:
                    # Release the load exactly once; fall back to this
                    # listener's own connection if no owner was recorded.
                    owner = self.request_conn.pop(request_id, conn_index)
                    self._update_load(owner, -1)

    def _update_load(self, conn_index, delta):
        """
        Update connection load and maintain the heap.

        Heap entries are immutable ``(load, index)`` tuples, so they are
        rebuilt from ``connection_load``; re-heapifying the old tuples would
        leave the recorded loads at their initial value.
        """
        self.connection_load[conn_index] += delta
        self.connection_heap[:] = [(load, idx) for idx, load in enumerate(self.connection_load)]
        heapq.heapify(self.connection_heap)

        # For Debugging purposes
        if random.random() < 0.01:
            min_load = self.connection_heap[0][0] if self.connection_heap else 0
            max_load = max(self.connection_load) if self.connection_load else 0
            api_server_logger.debug(f"Connection load update: min={min_load}, max={max_load}")

    def _acquire_least_loaded_index(self):
        """Charge one request to the least loaded connection; return its index or None."""
        if not self.connection_heap:
            return None
        _, conn_index = self.connection_heap[0]
        self._update_load(conn_index, 1)
        return conn_index

    def _get_least_loaded_connection(self):
        """
        Get the least loaded connection and count one more request against it.

        Returns:
            The dealer stream, or None when the pool is empty.
        """
        conn_index = self._acquire_least_loaded_index()
        if conn_index is None:
            return None
        return self.connections[conn_index]

    async def get_connection(self, request_id, num_choices=1):
        """
        Get a connection for the request.

        Args:
            request_id: key under which responses will be routed.
            num_choices: number of finished responses expected before the
                request stops counting towards the connection's load.

        Returns:
            ``(dealer, response_queue)``.

        Raises:
            RuntimeError: if no connection is available. Nothing is registered
                for ``request_id`` in that case.
        """
        response_queue = asyncio.Queue()

        async with self.lock:
            # Pick the connection first so a failure leaves no stale entries.
            conn_index = self._acquire_least_loaded_index()
            if conn_index is None:
                raise RuntimeError("No available connections")
            self.request_map[request_id] = response_queue
            self.request_num[request_id] = num_choices
            self.request_conn[request_id] = conn_index
            dealer = self.connections[conn_index]

        return dealer, response_queue

    async def cleanup_request(self, request_id):
        """
        Clean up the request after it is finished.

        Unknown ids are ignored. If the request is removed before all of its
        finished responses were seen, the load it still holds on its
        connection is released here.
        """
        async with self.lock:
            self.request_map.pop(request_id, None)
            self.request_num.pop(request_id, None)
            conn_index = self.request_conn.pop(request_id, None)
            if conn_index is not None:
                self._update_load(conn_index, -1)

    async def close(self):
        """
        Close all connections and tasks and reset the bookkeeping.
        """
        self.running = False

        for task in self.connection_tasks:
            task.cancel()
        self.connection_tasks.clear()

        async with self.lock:
            for dealer in self.connections:
                try:
                    dealer.close()
                except Exception as e:
                    # Best effort: keep closing the remaining connections.
                    api_server_logger.warning(f"Failed to close dealer: {str(e)}")
            self.connections.clear()
            self.connection_load.clear()
            # Drop the heap too, otherwise a later lookup would index into the
            # now-empty connection list.
            self.connection_heap.clear()
            self.request_map.clear()
            self.request_num.clear()
            self.request_conn.clear()

        api_server_logger.info("All connections and tasks closed")
|
Reference in New Issue
Block a user