Mirror of https://github.com/PaddlePaddle/FastDeploy.git
Synced 2025-09-27 21:02:24 +08:00

* [Feature] update ep * fix ci * fix ci * fix ci * fix ci * fix ci * fix ci * fix ci * fix queue ports idx * fix ci * fix ci * fix ci * fix ci * fix ci * fix ci * fix ci * fix ci * Update engine.py * fix ci * fix some bug in mixed ep * add server fix and op fix * rm some log * fix code style * ltd fix * fix * fix * fix some bug * fix bug * fix bug * fix style * Update config.py * Update splitwise_connector.py * Update cache_messager.py * Update __init__.py * merge and fix * Update engine.py * Update common_engine.py * Update run_ci_xpu.sh * Update ernie_processor.py * Update ernie_processor.py --------- Co-authored-by: ltd0924 <ltd0924@sina.com> Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
160 lines · 5.4 KiB · Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import asyncio
|
|
import heapq
|
|
import random
|
|
|
|
import aiozmq
|
|
import msgpack
|
|
import zmq
|
|
|
|
from fastdeploy.utils import api_server_logger
|
|
|
|
|
|
class DealerConnectionManager:
    """
    Manager for dealer connections, supporting multiplexing and connection reuse.

    A pool of ZMQ DEALER streams is connected to
    ``ipc:///dev/shm/router_{pid}.ipc``. Each request is assigned to the
    connection with the fewest in-flight requests, and messages read from a
    connection are routed to the per-request ``asyncio.Queue`` registered in
    ``get_connection``.
    """

    def __init__(self, pid, max_connections=10):
        """
        Args:
            pid: Identifier used to build the router IPC path.
            max_connections: Requested pool size. Note that ``max(..., 10)``
                below makes 10 the *minimum* number of connections created.
        """
        self.pid = pid
        self.max_connections = max(max_connections, 10)
        self.connections = []  # connection index -> dealer stream
        self.connection_load = []  # connection index -> in-flight request count
        self.connection_heap = []  # (load, index) min-heap, rebuilt from connection_load
        self.request_map = {}  # request_id -> response_queue
        self.request_num = {}  # request_id -> num_choices not yet finished
        self.request_conn = {}  # request_id -> index of the connection serving it
        self.lock = asyncio.Lock()
        self.connection_tasks = []
        self.running = False

    async def initialize(self):
        """Create the connection pool and start one listener task per connection."""
        self.running = True
        for index in range(self.max_connections):
            await self._add_connection(index)
        # Report what was actually started; individual connects may have failed.
        api_server_logger.info(f"Started {len(self.connections)} connections, pid {self.pid}")

    async def _add_connection(self, index):
        """
        Create a new dealer connection and start its listening task.

        Args:
            index: Requested slot number, kept for backward compatibility. The
                slot actually used is the dealer's position in
                ``self.connections`` so that ``connections``,
                ``connection_load`` and the heap stay aligned even when an
                earlier connection attempt failed.

        Returns:
            True if the dealer was created, False otherwise (the error is logged).
        """
        try:
            dealer = await aiozmq.create_zmq_stream(
                zmq.DEALER,
                connect=f"ipc:///dev/shm/router_{self.pid}.ipc",
            )
        except Exception as e:
            api_server_logger.error(f"Failed to create dealer: {str(e)}")
            return False

        async with self.lock:
            conn_index = len(self.connections)
            self.connections.append(dealer)
            self.connection_load.append(0)
            # Heap entries always mirror current loads, so pushing the new
            # (0, conn_index) entry keeps the heap valid.
            heapq.heappush(self.connection_heap, (0, conn_index))

        # start listening
        task = asyncio.create_task(self._listen_connection(dealer, conn_index))
        self.connection_tasks.append(task)
        return True

    async def _listen_connection(self, dealer, conn_index):
        """
        Read messages from one dealer and route them to their request queues.

        The last frame of each message is msgpack-decoded; the decoded value is
        indexed with ``[-1]`` and its ``"request_id"`` / ``"finished"`` keys are
        read (presumably a list of output dicts — confirm against the sender).
        Request ids starting with "cmpl" have their last "-"-separated suffix
        stripped so all of them map to the same registered request. When the
        request's remaining-choice counter reaches zero, this connection's load
        is released.
        """
        while self.running:
            try:
                raw_data = await dealer.read()
            except Exception as e:
                # The stream itself failed (e.g. it was closed): stop listening.
                api_server_logger.error(f"Listener error: {str(e)}")
                break

            try:
                response = msgpack.unpackb(raw_data[-1])
                request_id = response[-1]["request_id"]
                if request_id.startswith("cmpl"):
                    request_id = request_id.rsplit("-", 1)[0]
                async with self.lock:
                    response_queue = self.request_map.get(request_id)
                    if response_queue is None:
                        # Request already cleaned up (e.g. the client went away).
                        continue
                    await response_queue.put(response)
                    if response[-1]["finished"]:
                        self.request_num[request_id] -= 1
                        if self.request_num[request_id] == 0:
                            self._update_load(conn_index, -1)
            except Exception as e:
                # A single malformed message must not stop the listener; otherwise
                # this connection stays in the pool but nobody reads from it.
                api_server_logger.error(f"Listener error: {str(e)}")

    def _update_load(self, conn_index, delta):
        """Adjust one connection's load and rebuild the (load, index) heap."""
        self.connection_load[conn_index] += delta
        # The heap holds (load, index) snapshots; re-heapifying stale tuples
        # would never reorder anything, so rebuild it from the current loads
        # (in place, to keep the list object identity).
        self.connection_heap[:] = [(load, idx) for idx, load in enumerate(self.connection_load)]
        heapq.heapify(self.connection_heap)

        # For Debugging purposes
        if random.random() < 0.01:
            min_load = self.connection_heap[0][0] if self.connection_heap else 0
            max_load = max(self.connection_load) if self.connection_load else 0
            api_server_logger.debug(f"Connection load update: min={min_load}, max={max_load}")

    def _acquire_least_loaded_index(self):
        """Charge one request to the least-loaded connection and return its index (None if the pool is empty)."""
        if not self.connection_heap:
            return None
        _, conn_index = self.connection_heap[0]
        self._update_load(conn_index, 1)
        return conn_index

    def _get_least_loaded_connection(self):
        """
        Get the least loaded connection, counting one more request against it.

        Returns:
            The dealer stream, or None if no connection exists.
        """
        conn_index = self._acquire_least_loaded_index()
        if conn_index is None:
            return None
        return self.connections[conn_index]

    async def get_connection(self, request_id, num_choices=1):
        """
        Reserve a connection for a request and register its response queue.

        Args:
            request_id: Key under which incoming messages are routed.
            num_choices: Number of ``finished`` messages expected before the
                connection's load is released.

        Returns:
            ``(dealer, response_queue)``.

        Raises:
            RuntimeError: If no connection is available. Nothing is registered
                in that case.
        """
        response_queue = asyncio.Queue()

        async with self.lock:
            conn_index = self._acquire_least_loaded_index()
            if conn_index is None:
                raise RuntimeError("No available connections")
            self.request_map[request_id] = response_queue
            self.request_num[request_id] = num_choices
            self.request_conn[request_id] = conn_index
            dealer = self.connections[conn_index]

        return dealer, response_queue

    async def cleanup_request(self, request_id):
        """
        Clean up the request after it is finished (or abandoned).

        If the request is removed before all of its choices reported
        ``finished``, the load it holds on its connection is released here so
        that it does not leak.
        """
        async with self.lock:
            if request_id not in self.request_map:
                return
            del self.request_map[request_id]
            remaining = self.request_num.pop(request_id, 0)
            conn_index = self.request_conn.pop(request_id, None)
            if remaining > 0 and conn_index is not None:
                self._update_load(conn_index, -1)

    async def close(self):
        """
        Close all connections and tasks and reset the manager's bookkeeping.
        """
        self.running = False

        tasks = list(self.connection_tasks)
        self.connection_tasks.clear()
        for task in tasks:
            task.cancel()
        # Wait until the listeners have really stopped before closing their dealers.
        await asyncio.gather(*tasks, return_exceptions=True)

        async with self.lock:
            for dealer in self.connections:
                try:
                    dealer.close()
                except Exception as e:
                    # Best-effort shutdown: keep closing the remaining dealers.
                    api_server_logger.warning(f"Failed to close dealer: {str(e)}")
            self.connections.clear()
            self.connection_load.clear()
            self.connection_heap.clear()
            self.request_map.clear()
            self.request_num.clear()
            self.request_conn.clear()

        api_server_logger.info("All connections and tasks closed")
|